xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/doubler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <torch/extension.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker struct Doubler {
DoublerDoubler4*da0073e9SAndroid Build Coastguard Worker   Doubler(int A, int B) {
5*da0073e9SAndroid Build Coastguard Worker     tensor_ =
6*da0073e9SAndroid Build Coastguard Worker         torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
7*da0073e9SAndroid Build Coastguard Worker   }
forwardDoubler8*da0073e9SAndroid Build Coastguard Worker   torch::Tensor forward() {
9*da0073e9SAndroid Build Coastguard Worker     return tensor_ * 2;
10*da0073e9SAndroid Build Coastguard Worker   }
getDoubler11*da0073e9SAndroid Build Coastguard Worker   torch::Tensor get() const {
12*da0073e9SAndroid Build Coastguard Worker     return tensor_;
13*da0073e9SAndroid Build Coastguard Worker   }
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker  private:
16*da0073e9SAndroid Build Coastguard Worker   torch::Tensor tensor_;
17*da0073e9SAndroid Build Coastguard Worker };
18