1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/NativeFunctions.h>
5 #include <torch/library.h>
6
7 #include <torch/csrc/jit/runtime/operator.h>
8
9 // NB. These tests use the MAIA dispatch key to test backend dispatching
10 // machinery, but these tests are not specific to MAIA at all. The MAIA
11 // backend is fully out-of-tree, so it's safe to use this key for
12 // in-tree tests.
13
14 using namespace at;
15
16 static int test_int;
17
empty_override(SymIntArrayRef size,std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device,std::optional<bool> pin_memory,std::optional<MemoryFormat> optional_memory_format)18 Tensor empty_override(SymIntArrayRef size, std::optional<ScalarType> dtype, std::optional<Layout> layout,
19 std::optional<Device> device, std::optional<bool> pin_memory, std::optional<MemoryFormat> optional_memory_format) {
20 test_int = 1;
21 auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
22 Storage(
23 Storage::use_byte_size_t(),
24 0,
25 at::DataPtr(nullptr, Device(DeviceType::MAIA, 1)),
26 nullptr,
27 false),
28 DispatchKey::MAIA,
29 caffe2::TypeMeta::Make<float>());
30 return Tensor(std::move(tensor_impl));
31 }
32
// MAIA kernel for aten::add.Tensor.
//
// Returns a freshly allocated 5x5 MAIA tensor instead of one of its inputs,
// and leaves test_int == 2 as evidence that this kernel ran. The inputs and
// the scalar are intentionally ignored.
Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
  // Allocate first: this dispatches to the MAIA empty kernel, which sets
  // test_int to 1, so the marker below must be written afterwards.
  Tensor fresh_result = empty({5, 5}, at::kMAIA); // Don't return self as-is
  test_int = 2;
  return fresh_result;
}
38
empty_strided_override(IntArrayRef size,IntArrayRef stride,std::optional<c10::ScalarType> dtype,std::optional<c10::Layout> layout,std::optional<c10::Device> device,std::optional<bool> pin_memory)39 Tensor empty_strided_override(
40 IntArrayRef size,
41 IntArrayRef stride,
42 std::optional<c10::ScalarType> dtype,
43 std::optional<c10::Layout> layout,
44 std::optional<c10::Device> device,
45 std::optional<bool> pin_memory) {
46
47 return empty_override(fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, std::nullopt);
48 }
49
// Register the test kernels above for the MAIA dispatch key of the aten
// namespace, so that MAIA tensors route to them.
TORCH_LIBRARY_IMPL(aten, MAIA, lib) {
  // Factory functions: both end up in empty_override.
  lib.impl("aten::empty.memory_format", empty_override);
  lib.impl("aten::empty_strided", empty_strided_override);
  // Binary op used to check that a non-factory op also dispatches to MAIA.
  lib.impl("aten::add.Tensor", add_override);
}
55
// Verifies that ops on MAIA tensors dispatch to the kernels registered above,
// and that CPU ops are unaffected by those registrations.
//
// test_int is reset to 0 before each dispatched call so that the following
// ASSERT_EQ(test_int, ...) proves that *this* call reached the expected MAIA
// kernel, instead of passing on a value left behind by an earlier call or an
// earlier iteration under --gtest_repeat.
TEST(BackendExtensionTest, TestRegisterOp) {
  test_int = 0;
  Tensor a = empty({5, 5}, at::kMAIA);
  ASSERT_EQ(a.device().type(), at::kMAIA);
  ASSERT_EQ(a.device().index(), 1);
  ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
  ASSERT_EQ(test_int, 1);

  // empty_like has no MAIA kernel of its own; it must reach the MAIA
  // factory kernels, which are the only source of a MAIA:1 float tensor.
  test_int = 0;
  Tensor b = empty_like(a, at::kMAIA);
  ASSERT_EQ(b.device().type(), at::kMAIA);
  ASSERT_EQ(b.device().index(), 1);
  ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());
  ASSERT_EQ(test_int, 1);

  // Keep the result so we also check that the tensor produced by
  // add_override is what the caller actually receives.
  test_int = 0;
  Tensor c = add(a, b);
  ASSERT_EQ(test_int, 2);
  ASSERT_EQ(c.device().type(), at::kMAIA);

  // Ensure that non-MAIA operator still works and does not route through
  // any of the MAIA overrides (test_int must be left untouched).
  Tensor d = empty({5, 5}, at::kCPU);
  ASSERT_EQ(d.device().type(), at::kCPU);
  ASSERT_EQ(test_int, 2);
}
75