1*da0073e9SAndroid Build Coastguard Worker #include <torch/extension.h> 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker struct Doubler { DoublerDoubler4*da0073e9SAndroid Build Coastguard Worker Doubler(int A, int B) { 5*da0073e9SAndroid Build Coastguard Worker tensor_ = 6*da0073e9SAndroid Build Coastguard Worker torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true)); 7*da0073e9SAndroid Build Coastguard Worker } forwardDoubler8*da0073e9SAndroid Build Coastguard Worker torch::Tensor forward() { 9*da0073e9SAndroid Build Coastguard Worker return tensor_ * 2; 10*da0073e9SAndroid Build Coastguard Worker } getDoubler11*da0073e9SAndroid Build Coastguard Worker torch::Tensor get() const { 12*da0073e9SAndroid Build Coastguard Worker return tensor_; 13*da0073e9SAndroid Build Coastguard Worker } 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker private: 16*da0073e9SAndroid Build Coastguard Worker torch::Tensor tensor_; 17*da0073e9SAndroid Build Coastguard Worker }; 18