xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_custom_operators.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/passes/dead_code_elimination.h>
8 #include <torch/csrc/jit/runtime/custom_operator.h>
9 #include <torch/csrc/jit/runtime/register_ops_utils.h>
10 #include <torch/jit.h>
11 
12 namespace torch {
13 namespace jit {
14 
// Registers a lambda without an explicit schema and checks the inferred
// schema: positional argument names (_0, _1), double -> float, at::Tensor ->
// Tensor, a single Tensor return, and that running the op computes a + b.
TEST(CustomOperatorTest, InferredSchema) {
  torch::RegisterOperators reg(
      "foo::bar", [](double a, at::Tensor b) { return a + b; });
  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foo::bar");

  ASSERT_EQ(op->schema().arguments().size(), 2);
  ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
  ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

  // Check the arity before indexing returns()[0]: without this, a schema with
  // no returns would be indexed out of bounds instead of failing cleanly.
  ASSERT_EQ(op->schema().returns().size(), 1);
  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

  Stack stack;
  push(stack, 2.0f, at::ones(5));
  op->getOperation()(stack);
  at::Tensor output;
  pop(stack, output);

  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
}
40 
// Registers a lambda together with a hand-written schema and verifies that
// the registered operator reports exactly that schema (names, types, single
// Tensor return) and that invoking it computes a + b.
TEST(CustomOperatorTest, ExplicitSchema) {
  torch::RegisterOperators reg(
      "foo::bar_with_schema(float a, Tensor b) -> Tensor",
      [](double a, at::Tensor b) { return a + b; });

  const auto& candidates =
      getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
  ASSERT_EQ(candidates.size(), 1);

  const auto& registered = candidates.front();
  const auto& schema = registered->schema();
  ASSERT_EQ(schema.name(), "foo::bar_with_schema");

  const auto& args = schema.arguments();
  ASSERT_EQ(args.size(), 2);
  ASSERT_EQ(args[0].name(), "a");
  ASSERT_EQ(args[0].type()->kind(), TypeKind::FloatType);
  ASSERT_EQ(args[1].name(), "b");
  ASSERT_EQ(args[1].type()->kind(), TypeKind::TensorType);

  const auto& rets = schema.returns();
  ASSERT_EQ(rets.size(), 1);
  ASSERT_EQ(rets[0].type()->kind(), TypeKind::TensorType);

  // Run the operator on (2.0, ones(5)); the result should be all 3s.
  Stack stack;
  push(stack, 2.0f, at::ones(5));
  auto run = registered->getOperation();
  run(stack);
  const at::Tensor result = pop(stack).toTensor();

  ASSERT_TRUE(result.allclose(at::full(5, 3.0f)));
}
70 
// Registers an operator that takes one list of every supported element type
// (int, float, complex, Tensor) and returns the float list unchanged, then
// checks the parsed schema and the round-trip through the stack.
TEST(CustomOperatorTest, ListParameters) {
  torch::RegisterOperators reg(
      "foo::lists(int[] ints, float[] floats, complex[] complexdoubles, Tensor[] tensors) -> float[]",
      [](torch::List<int64_t> /*ints*/,
         torch::List<double> floats,
         torch::List<c10::complex<double>> /*complexdoubles*/,
         torch::List<at::Tensor> /*tensors*/) { return floats; });

  const auto& candidates =
      getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
  ASSERT_EQ(candidates.size(), 1);

  const auto& registered = candidates.front();
  const auto& schema = registered->schema();
  ASSERT_EQ(schema.name(), "foo::lists");

  const auto& args = schema.arguments();
  ASSERT_EQ(args.size(), 4);
  ASSERT_EQ(args[0].name(), "ints");
  ASSERT_TRUE(args[0].type()->isSubtypeOf(*ListType::ofInts()));
  ASSERT_EQ(args[1].name(), "floats");
  ASSERT_TRUE(args[1].type()->isSubtypeOf(*ListType::ofFloats()));
  ASSERT_EQ(args[2].name(), "complexdoubles");
  ASSERT_TRUE(args[2].type()->isSubtypeOf(*ListType::ofComplexDoubles()));
  ASSERT_EQ(args[3].name(), "tensors");
  ASSERT_TRUE(args[3].type()->isSubtypeOf(*ListType::ofTensors()));

  const auto& rets = schema.returns();
  ASSERT_EQ(rets.size(), 1);
  ASSERT_TRUE(rets[0].type()->isSubtypeOf(*ListType::ofFloats()));

  // Arguments are pushed in schema order: ints, floats, complexdoubles,
  // tensors.
  Stack stack;
  push(
      stack,
      c10::List<int64_t>({1, 2}),
      c10::List<double>({1.0, 2.0}),
      c10::List<c10::complex<double>>(
          {c10::complex<double>(2.4, -5.5), c10::complex<double>(-1.3, 2)}),
      c10::List<at::Tensor>({at::ones(5)}));
  auto run = registered->getOperation();
  run(stack);

  c10::List<double> echoed;
  pop(stack, echoed);

  ASSERT_EQ(echoed.size(), 2);
  ASSERT_EQ(echoed.get(0), 1.0);
  ASSERT_EQ(echoed.get(1), 2.0);
}
120 
// Registers an identity operator over Tensor[] and checks both the schema
// (one Tensor[] argument, one Tensor[] return) and that the list comes back
// intact.
TEST(CustomOperatorTest, ListParameters2) {
  torch::RegisterOperators reg(
      "foo::lists2(Tensor[] tensors) -> Tensor[]",
      [](torch::List<at::Tensor> tensors) { return tensors; });

  const auto& candidates =
      getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
  ASSERT_EQ(candidates.size(), 1);

  const auto& registered = candidates.front();
  const auto& schema = registered->schema();
  ASSERT_EQ(schema.name(), "foo::lists2");

  const auto& args = schema.arguments();
  ASSERT_EQ(args.size(), 1);
  ASSERT_EQ(args[0].name(), "tensors");
  ASSERT_TRUE(args[0].type()->isSubtypeOf(*ListType::ofTensors()));

  const auto& rets = schema.returns();
  ASSERT_EQ(rets.size(), 1);
  ASSERT_TRUE(rets[0].type()->isSubtypeOf(*ListType::ofTensors()));

  Stack stack;
  push(stack, c10::List<at::Tensor>({at::ones(5)}));
  auto run = registered->getOperation();
  run(stack);

  c10::List<at::Tensor> echoed;
  pop(stack, echoed);

  ASSERT_EQ(echoed.size(), 1);
  ASSERT_TRUE(echoed.get(0).allclose(at::ones(5)));
}
150 
// Checks how alias analysis and dead-code elimination treat a custom operator
// registered without a schema: it is assumed to write to all of its inputs,
// its output may alias every input, and DCE must not remove it even when its
// result is unused.
TEST(CustomOperatorTest, Aliasing) {
  torch::RegisterOperators reg(
      "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
        a.add_(b);
        return a;
      });
  // NOTE(review): the original discarded this lookup's result; presumably the
  // call is there so the registration is visible before the IR is parsed.
  // Keep the call, and fail early here if the operator did not register
  // rather than later in alias analysis.
  ASSERT_EQ(
      getAllOperatorsFor(Symbol::fromQualString("foo::aliasing")).size(), 1);

  {
    auto graph = std::make_shared<Graph>();
    parseIR(
        R"IR(
graph(%x: Tensor, %y: Tensor):
  %ret : Tensor = foo::aliasing(%x, %y)
  return (%ret)
  )IR",
        graph.get());

    auto opNode = *graph->block()->nodes().begin();

    AliasDb aliasDb(graph);
    for (Value* input : opNode->inputs()) {
      // The custom op writes to all its inputs
      ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input}));
      // The output should be a wildcard and thus alias all inputs
      ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input));
    }
  }
  {
    // DCE should not remove a custom op, even though %ret is never used.
    auto graph = std::make_shared<Graph>();
    const auto text = R"IR(
graph(%x: Tensor, %y: Tensor):
  # CHECK: foo::aliasing
  %ret : Tensor = foo::aliasing(%x, %y)
  return (%x)
  )IR";
    parseIR(text, graph.get());
    EliminateDeadCode(graph);

    testing::FileCheck().run(text, *graph);
  }
}
194 
// Operator allowlist consulted by the selective-registration tests below.
// Entries appear to be ';'-separated, and "foofoo::bar.template" shows that
// an entry may include an overload name.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static constexpr char op_list[] = "foofoo::bar.template;foo::another";

// Wraps `schema_str` in a torch::detail::SelectiveStr. The boolean template
// argument is computed at compile time by
// c10::impl::op_allowlist_contains_name_in_schema(allowlist_str, schema_str).
// Judging by the tests below, it is true when the operator named in the
// schema appears in the allowlist.
#define TORCH_SELECTIVE_NAME_IN_SCHEMA(allowlist_str, schema_str) \
  torch::detail::SelectiveStr<                                    \
      c10::impl::op_allowlist_contains_name_in_schema(            \
          allowlist_str, schema_str)>(schema_str)
200 
// "foofoo::not_exist" is absent from op_list, so selective registration must
// drop it: after constructing the registrar, no operator of that name exists.
TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
  torch::jit::RegisterOperators reg({OperatorGenerator(
      TORCH_SELECTIVE_NAME_IN_SCHEMA(
          op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
      [](Stack& stack) {
        // Initialized only to avoid an uninitialized local; pop() overwrites.
        double lhs = 0.0;
        at::Tensor rhs;
        pop(stack, lhs, rhs);
        push(stack, lhs + rhs);
      },
      aliasAnalysisFromSchema())});

  const auto& found =
      getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
  ASSERT_EQ(found.size(), 0);
}
219 
// "foofoo::bar.template" is listed in op_list, so selective registration must
// keep it. The test checks the registered name, the overload name, the schema
// and that the operator computes a + b.
TEST(TestCustomOperator, OperatorGeneratorBasic) {
  // The operator should be successfully registered since its name is in the
  // allowlist.
  torch::jit::RegisterOperators reg({OperatorGenerator(
      TORCH_SELECTIVE_NAME_IN_SCHEMA(
          op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
      [](Stack& stack) {
        // Initialized so the local is never read uninitialized; pop()
        // overwrites both values.
        double a = 0.0;
        at::Tensor b;
        pop(stack, a, b);
        push(stack, a + b);
      },
      aliasAnalysisFromSchema())});

  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foofoo::bar");
  // The allowlist entry includes the ".template" overload suffix; make sure
  // it survived schema parsing as the overload name.
  ASSERT_EQ(op->schema().overload_name(), "template");

  ASSERT_EQ(op->schema().arguments().size(), 2);
  ASSERT_EQ(op->schema().arguments()[0].name(), "a");
  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
  ASSERT_EQ(op->schema().arguments()[1].name(), "b");
  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

  ASSERT_EQ(op->schema().returns().size(), 1);
  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

  Stack stack;
  push(stack, 2.0f, at::ones(5));
  op->getOperation()(stack);
  at::Tensor output;
  pop(stack, output);

  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
}
258 
259 } // namespace jit
260 } // namespace torch
261