// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/extension_backend_test.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <torch/library.h>

#include <torch/csrc/jit/runtime/operator.h>

// NB. These tests use the MAIA dispatch key to test backend dispatching
// machinery, but these tests are not specific to MAIA at all. The MAIA
// backend is fully out-of-tree, so it's safe to use this key for
// in-tree tests.

using namespace at;
// Side channel written by the overrides below so the test can verify which
// override actually ran (1 == empty_override, 2 == add_override).
static int test_int;
17 
empty_override(SymIntArrayRef size,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,std::optional<MemoryFormat> optional_memory_format)18 Tensor empty_override(SymIntArrayRef size, std::optional<ScalarType> dtype, std::optional<Layout> layout,
19                       std::optional<Device> device, std::optional<bool> pin_memory, std::optional<MemoryFormat> optional_memory_format) {
20   test_int = 1;
21   auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
22       Storage(
23           Storage::use_byte_size_t(),
24           0,
25           at::DataPtr(nullptr, Device(DeviceType::MAIA, 1)),
26           nullptr,
27           false),
28       DispatchKey::MAIA,
29       caffe2::TypeMeta::Make<float>());
30   return Tensor(std::move(tensor_impl));
31 }
32 
// Test-only implementation of aten::add.Tensor for the MAIA key. Returns a
// brand-new MAIA tensor (never one of the operands) and records that it ran.
Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
  // NB: empty() re-enters empty_override, which sets test_int = 1, so the
  // assignment to 2 must come *after* the empty() call.
  Tensor result = empty({5, 5}, at::kMAIA);  // Don't return self as-is
  test_int = 2;
  return result;
}
38 
empty_strided_override(IntArrayRef size,IntArrayRef stride,std::optional<c10::ScalarType> dtype,std::optional<c10::Layout> layout,std::optional<c10::Device> device,std::optional<bool> pin_memory)39 Tensor empty_strided_override(
40   IntArrayRef size,
41   IntArrayRef stride,
42   std::optional<c10::ScalarType> dtype,
43   std::optional<c10::Layout> layout,
44   std::optional<c10::Device> device,
45   std::optional<bool> pin_memory) {
46 
47   return empty_override(fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, std::nullopt);
48 }
49 
// Route the three ops exercised by the test below to the overrides above
// whenever dispatch lands on the MAIA backend key.
TORCH_LIBRARY_IMPL(aten, MAIA, m) {
  m.impl("aten::empty.memory_format", empty_override);
  m.impl("aten::empty_strided", empty_strided_override);
  m.impl("aten::add.Tensor", add_override);
}
55 
// End-to-end check of the out-of-tree backend registration machinery: ops
// registered above for MAIA must be reachable through the normal dispatcher,
// and built-in backends (CPU) must keep working alongside them.
TEST(BackendExtensionTest, TestRegisterOp) {
  Tensor maia_a = empty({5, 5}, at::kMAIA);
  ASSERT_EQ(maia_a.device().type(), at::kMAIA);
  ASSERT_EQ(maia_a.device().index(), 1);
  ASSERT_EQ(maia_a.dtype(), caffe2::TypeMeta::Make<float>());
  ASSERT_EQ(test_int, 1);

  // empty_like should also land in empty_override via the MAIA key.
  Tensor maia_b = empty_like(maia_a, at::kMAIA);
  ASSERT_EQ(maia_b.device().type(), at::kMAIA);
  ASSERT_EQ(maia_b.device().index(), 1);
  ASSERT_EQ(maia_b.dtype(), caffe2::TypeMeta::Make<float>());

  add(maia_a, maia_b);
  ASSERT_EQ(test_int, 2);

  // Ensure that non-MAIA operator still works
  Tensor cpu_tensor = empty({5, 5}, at::kCPU);
  ASSERT_EQ(cpu_tensor.device().type(), at::kCPU);
}
75