// xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/maia_extension.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/extension.h>
2 #include <torch/library.h>
3 
4 using namespace at;
5 
6 static int test_int;
7 
get_tensor(caffe2::TypeMeta dtype,IntArrayRef size)8 Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
9   auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
10       Storage(
11           Storage::use_byte_size_t(),
12           0,
13           at::DataPtr(nullptr, Device(DeviceType::MAIA, 0)),
14           nullptr,
15           false),
16       DispatchKey::MAIA,
17       dtype);
18   // This is a hack to workaround the shape checks in _convolution.
19   tensor_impl->set_sizes_contiguous(size);
20   return Tensor(std::move(tensor_impl));
21 }
22 
empty_override(IntArrayRef size,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,std::optional<c10::MemoryFormat> optional_memory_format)23 Tensor empty_override(IntArrayRef size, std::optional<ScalarType> dtype, std::optional<Layout> layout, std::optional<Device> device,
24                       std::optional<bool> pin_memory, std::optional<c10::MemoryFormat> optional_memory_format) {
25   test_int = 0;
26   return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
27 }
28 
// Override for aten::add.out on MAIA: performs no arithmetic, simply records
// the dispatch in test_int and hands back the out tensor unchanged.
Tensor& add_out_override(const Tensor& a, const Tensor& b, const Scalar& c, Tensor& out) {
  test_int = 1;
  return out;
}
33 
// Override for aten::convolution_overrideable on MAIA. Returns a shape-only
// placeholder; records test_int = 2.
Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;
  // Only the first 2 dimensions of the output shape are correct — the
  // spatial dims are just copied from the input instead of being computed.
  return get_tensor(
      input.dtype(),
      {input.size(0), weight.size(0), input.size(2), input.size(3)});
}
42 
fake_convolution_backward(const Tensor & grad_output,const Tensor & input,const Tensor & weight,IntArrayRef stride,IntArrayRef padding,IntArrayRef dilation,bool transposed,IntArrayRef output_padding,int64_t groups,std::array<bool,3> output_mask)43 std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
44         const Tensor & grad_output, const Tensor & input, const Tensor & weight,
45         IntArrayRef stride, IntArrayRef padding,
46         IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
47         int64_t groups, std::array<bool,3> output_mask) {
48     test_int = 3;
49     return std::tuple<Tensor, Tensor, Tensor>(
50             get_tensor(input.dtype(), input.sizes()),
51             get_tensor(weight.dtype(), weight.sizes()),
52             get_tensor(input.dtype(), {}));
53 }
54 
// Register the fake kernels above as the MAIA implementations of the
// corresponding aten operators.
TORCH_LIBRARY_IMPL(aten, MAIA, m) {
  m.impl("empty.memory_format", empty_override);
  m.impl("add.out", add_out_override);
  m.impl("convolution_overrideable", fake_convolution);
  m.impl("convolution_backward_overrideable", fake_convolution_backward);
}
61 
62 // TODO: Extend this to exercise multi-device setting.  In that case,
63 // we need to add a thread local variable to track the current device.
64 struct MAIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
65   static constexpr DeviceType static_type = DeviceType::MAIA;
MAIAGuardImplMAIAGuardImpl66   MAIAGuardImpl() {}
MAIAGuardImplMAIAGuardImpl67   MAIAGuardImpl(DeviceType t) {
68     AT_ASSERT(t == DeviceType::MAIA);
69   }
typeMAIAGuardImpl70   DeviceType type() const override {
71     return DeviceType::MAIA;
72   }
exchangeDeviceMAIAGuardImpl73   Device exchangeDevice(Device d) const override {
74     AT_ASSERT(d.type() == DeviceType::MAIA);
75     AT_ASSERT(d.index() == 0);
76     return d;
77   }
getDeviceMAIAGuardImpl78   Device getDevice() const override {
79     return Device(DeviceType::MAIA, 0);
80   }
setDeviceMAIAGuardImpl81   void setDevice(Device d) const override {
82     AT_ASSERT(d.type() == DeviceType::MAIA);
83     AT_ASSERT(d.index() == 0);
84   }
uncheckedSetDeviceMAIAGuardImpl85   void uncheckedSetDevice(Device d) const noexcept override {
86   }
getStreamMAIAGuardImpl87   Stream getStream(Device d) const noexcept override {
88     return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
89   }
exchangeStreamMAIAGuardImpl90   Stream exchangeStream(Stream s) const noexcept override {
91     return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
92   }
deviceCountMAIAGuardImpl93   DeviceIndex deviceCount() const noexcept override {
94     return 1;
95   }
96 
97   // Event-related functions
recordMAIAGuardImpl98   void record(void** event,
99     const Stream& stream,
100     const DeviceIndex device_index,
101     const EventFlag flag) const override {
102     TORCH_CHECK(false, "MAIA backend doesn't support events.");
103   }
blockMAIAGuardImpl104   void block(
105     void* event,
106     const Stream& stream) const override {
107     TORCH_CHECK(false, "MAIA backend doesn't support events.");
108   }
queryEventMAIAGuardImpl109   bool queryEvent(void* event) const override {
110     TORCH_CHECK(false, "MAIA backend doesn't support events.");
111   }
destroyEventMAIAGuardImpl112   void destroyEvent(
113     void* event,
114     const DeviceIndex device_index) const noexcept override { }
115 };
116 
117 constexpr DeviceType MAIAGuardImpl::static_type;
118 C10_REGISTER_GUARD_IMPL(MAIA, MAIAGuardImpl);
119 
get_test_int()120 int get_test_int() {
121   return test_int;
122 }
123 
// Python bindings: expose only the test hook used to observe dispatch.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_test_int", &get_test_int);
}
127