xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/doubler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/extension.h>
2 
3 struct Doubler {
DoublerDoubler4   Doubler(int A, int B) {
5     tensor_ =
6         torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
7   }
forwardDoubler8   torch::Tensor forward() {
9     return tensor_ * 2;
10   }
getDoubler11   torch::Tensor get() const {
12     return tensor_;
13   }
14 
15  private:
16   torch::Tensor tensor_;
17 };
18