xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_schema_info.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <torch/csrc/autograd/generated/variable_factories.h>
3 #include <torch/csrc/utils/schema_info.h>
4 
5 namespace torch {
6 namespace utils {
7 using c10::SchemaArgType;
8 
TEST(FunctionSchemaIsAliasingTest,Basic)9 TEST(FunctionSchemaIsAliasingTest, Basic) {
10   c10::FunctionSchema schema = torch::jit::parseSchema(
11       "aten::test.Tensor(Tensor(a) self, Tensor(b!) other, Tensor more_other) -> (Tensor(a), Tensor(b!))");
12   ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 0}));
13   ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 1}));
14   ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 0}));
15   ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 1}));
16   ASSERT_FALSE(schema.is_aliasing({SchemaArgType::input, 2}));
17 }
18 
TEST(FunctionSchemaIsAliasingTest,InvalidArgument)19 TEST(FunctionSchemaIsAliasingTest, InvalidArgument) {
20   c10::FunctionSchema schema = torch::jit::parseSchema(
21       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
22   ASSERT_THROW(schema.is_aliasing({SchemaArgType::input, 4}), c10::Error);
23   ASSERT_THROW(schema.is_aliasing({SchemaArgType::output, 4}), c10::Error);
24 }
25 
TEST(FunctionSchemaIsMutableTest,Basic)26 TEST(FunctionSchemaIsMutableTest, Basic) {
27   c10::FunctionSchema schema = torch::jit::parseSchema(
28       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
29   ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0}));
30   ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
31   ASSERT_TRUE(schema.is_mutable("self"));
32   ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
33   ASSERT_FALSE(schema.is_mutable("other"));
34   ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2}));
35   ASSERT_FALSE(schema.is_mutable("alpha"));
36 }
37 
TEST(FunctionSchemaIsMutableTest,InvalidArgument)38 TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
39   c10::FunctionSchema schema = torch::jit::parseSchema(
40       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
41   ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error);
42   ASSERT_THROW(schema.is_mutable({SchemaArgType::output, 4}), c10::Error);
43   ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
44 }
45 
TEST(SchemaInfoIsMutableTest,Basic)46 TEST(SchemaInfoIsMutableTest, Basic) {
47   SchemaInfo schema(
48       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
49   ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
50   ASSERT_TRUE(schema.is_mutable("self"));
51   ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
52   ASSERT_FALSE(schema.is_mutable("other"));
53   ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2}));
54   ASSERT_FALSE(schema.is_mutable("alpha"));
55 }
56 
TEST(SchemaInfoIsMutableTest,InvalidArgument)57 TEST(SchemaInfoIsMutableTest, InvalidArgument) {
58   SchemaInfo schema(
59       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
60   ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error);
61   ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
62 }
63 
TEST(SchemaInfoIsMutableTest,AliasingInputs)64 TEST(SchemaInfoIsMutableTest, AliasingInputs) {
65   SchemaInfo schema(
66       "aten::test.Tensor(Tensor(a!) self, Tensor(b) other, *, Scalar alpha=1) -> (Tensor(a!), Tensor(b))");
67   ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
68   ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0}));
69   ASSERT_TRUE(schema.is_mutable("self"));
70   ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
71   ASSERT_FALSE(schema.is_mutable({SchemaArgType::output, 1}));
72   ASSERT_FALSE(schema.is_mutable("other"));
73   at::Tensor input = at::randn({3, 3});
74   schema.addArgumentValue("self", input);
75   schema.addArgumentValue("other", input);
76   ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 1}));
77   ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 1}));
78   ASSERT_TRUE(schema.is_mutable("other"));
79 }
80 
TEST(SchemaInfoIsMutableTest,InstanceNorm)81 TEST(SchemaInfoIsMutableTest, InstanceNorm) {
82   SchemaInfo schema_info(
83       "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor");
84   ASSERT_TRUE(schema_info.is_mutable("running_mean"));
85   ASSERT_TRUE(schema_info.is_mutable("running_var"));
86   schema_info.addArgumentValue("use_input_stats", false);
87   ASSERT_FALSE(schema_info.is_mutable("running_mean"));
88   ASSERT_FALSE(schema_info.is_mutable("running_var"));
89 }
90 
TEST(SchemaInfoIsMutableTest,BatchNorm)91 TEST(SchemaInfoIsMutableTest, BatchNorm) {
92   SchemaInfo schema_info(
93       "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor");
94   ASSERT_TRUE(schema_info.is_mutable("running_mean"));
95   ASSERT_TRUE(schema_info.is_mutable("running_var"));
96   schema_info.addArgumentValue("training", false);
97   ASSERT_FALSE(schema_info.is_mutable("running_mean"));
98   ASSERT_FALSE(schema_info.is_mutable("running_var"));
99 }
100 
TEST(SchemaInfoIsNonDeterministicTest,Basic)101 TEST(SchemaInfoIsNonDeterministicTest, Basic) {
102   SchemaInfo deterministic_schema_info(
103       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
104   SchemaInfo nondeterministic_schema_info(
105       "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor");
106   ASSERT_FALSE(deterministic_schema_info.is_nondeterministic());
107   ASSERT_TRUE(nondeterministic_schema_info.is_nondeterministic());
108 }
109 
TEST(SchemaInfoIsNonDeterministicTest,Dropout)110 TEST(SchemaInfoIsNonDeterministicTest, Dropout) {
111   SchemaInfo droupout_schema_info(
112       "aten::dropout(Tensor input, float p, bool train) -> Tensor");
113   ASSERT_TRUE(droupout_schema_info.is_nondeterministic());
114   droupout_schema_info.addArgumentValue("train", false);
115   ASSERT_FALSE(droupout_schema_info.is_nondeterministic());
116 }
117 
TEST(FunctionSchemaMayAliasTest,Basic)118 TEST(FunctionSchemaMayAliasTest, Basic) {
119   c10::FunctionSchema schema = torch::jit::parseSchema(
120       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
121   ASSERT_TRUE(
122       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
123   ASSERT_FALSE(
124       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
125   ASSERT_FALSE(
126       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
127 }
128 
TEST(FunctionSchemaMayAliasTest,InvalidArgument)129 TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
130   c10::FunctionSchema schema = torch::jit::parseSchema(
131       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
132   ASSERT_THROW(
133       schema.may_alias({SchemaArgType::input, 15}, {SchemaArgType::output, 0}),
134       c10::Error);
135   ASSERT_THROW(
136       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 15}),
137       c10::Error);
138 }
139 
TEST(FunctionSchemaMayAliasTest,Wildcard)140 TEST(FunctionSchemaMayAliasTest, Wildcard) {
141   c10::FunctionSchema schema = torch::jit::parseSchema(
142       "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)");
143   ASSERT_TRUE(
144       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
145   ASSERT_FALSE(
146       schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
147 }
148 
TEST(SchemaInfoMayAliasTest,AliasingInputs)149 TEST(SchemaInfoMayAliasTest, AliasingInputs) {
150   SchemaInfo schema(
151       "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
152   ASSERT_FALSE(
153       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
154   at::Tensor input = at::randn({3, 3});
155   schema.addArgumentValue("self", input);
156   schema.addArgumentValue("other", input);
157   ASSERT_TRUE(
158       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
159 }
160 
TEST(SchemaInfoMayAliasTest,AliasingOutputs)161 TEST(SchemaInfoMayAliasTest, AliasingOutputs) {
162   SchemaInfo schema(
163       "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
164   ASSERT_FALSE(
165       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
166   at::Tensor input = at::randn({3, 3});
167   schema.addArgumentValue("min", input);
168   schema.addArgumentValue("max", input);
169   ASSERT_TRUE(
170       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
171 }
172 
TEST(SchemaInfoMayAliasTest,AliasingInputOutput)173 TEST(SchemaInfoMayAliasTest, AliasingInputOutput) {
174   SchemaInfo schema(
175       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
176   ASSERT_TRUE(
177       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
178   ASSERT_FALSE(
179       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
180   at::Tensor input = at::randn({3, 3});
181   schema.addArgumentValue("self", input);
182   schema.addArgumentValue("other", input);
183   ASSERT_TRUE(
184       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
185   ASSERT_TRUE(
186       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
187 }
188 
TEST(SchemaInfoMayAliasTest,MultipleWildcardInputs)189 TEST(SchemaInfoMayAliasTest, MultipleWildcardInputs) {
190   SchemaInfo schema(
191       "aten::test.Tensor(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> (Tensor(a), Tensor(*))");
192   ASSERT_TRUE(
193       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
194   ASSERT_TRUE(
195       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
196   ASSERT_TRUE(
197       schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
198   ASSERT_FALSE(
199       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
200   ASSERT_FALSE(
201       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
202   ASSERT_FALSE(
203       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
204   ASSERT_FALSE(
205       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
206   at::Tensor input = at::randn({3, 3});
207   schema.addArgumentValue("a", input);
208   schema.addArgumentValue("b", input);
209   ASSERT_TRUE(
210       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
211   ASSERT_TRUE(
212       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
213   ASSERT_TRUE(
214       schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
215   ASSERT_TRUE(
216       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
217   ASSERT_TRUE(
218       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
219   ASSERT_TRUE(
220       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
221   ASSERT_TRUE(
222       schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
223 }
224 
TEST(SchemaInfoMayAliasTest,MultipleNonWildcardInputs)225 TEST(SchemaInfoMayAliasTest, MultipleNonWildcardInputs) {
226   SchemaInfo schema(
227       "aten::test.Tensor(Tensor(a) a, Tensor(a) b, Tensor(*) c, Tensor(b) d) -> (Tensor(a), Tensor(*))");
228   ASSERT_TRUE(
229       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
230   ASSERT_TRUE(
231       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
232   ASSERT_TRUE(
233       schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
234   ASSERT_TRUE(
235       schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 0}));
236 }
237 
TEST(SchemaInfoMayAliasTest,MultipleNonWildcardOutputs)238 TEST(SchemaInfoMayAliasTest, MultipleNonWildcardOutputs) {
239   SchemaInfo schema(
240       "aten::test.Tensor(Tensor(a) a, Tensor(*) b) -> (Tensor(a), Tensor(a))");
241   ASSERT_TRUE(
242       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
243   ASSERT_TRUE(
244       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
245   ASSERT_TRUE(
246       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
247 }
248 
TEST(SchemaInfoMayAliasTest,MismatchingTypes)249 TEST(SchemaInfoMayAliasTest, MismatchingTypes) {
250   SchemaInfo schema("aten::test.Tensor(Tensor(a) a) -> int(a)");
251   ASSERT_FALSE(
252       schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
253 }
254 
TEST(FunctionSchemaMayContainAliasTest,Basic)255 TEST(FunctionSchemaMayContainAliasTest, Basic) {
256   c10::FunctionSchema schema = torch::jit::parseSchema(
257       "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
258   ASSERT_TRUE(schema.may_contain_alias(
259       {SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
260   ASSERT_FALSE(schema.may_contain_alias(
261       {SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
262   ASSERT_FALSE(schema.may_contain_alias(
263       {SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
264 }
265 
TEST(FunctionSchemaMayContainAliasTest,Wildcard)266 TEST(FunctionSchemaMayContainAliasTest, Wildcard) {
267   c10::FunctionSchema schema = torch::jit::parseSchema(
268       "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)");
269   ASSERT_FALSE(
270       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
271   ASSERT_TRUE(schema.may_contain_alias(
272       {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
273   ASSERT_TRUE(schema.may_contain_alias(
274       {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
275   ASSERT_FALSE(schema.may_contain_alias(
276       {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
277   ASSERT_FALSE(
278       schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
279 }
280 
TEST(FunctionSchemaMayContainAliasTest,InputAndOutputContainers)281 TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) {
282   c10::FunctionSchema schema =
283       torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]");
284   ASSERT_FALSE(
285       schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
286   ASSERT_TRUE(schema.may_contain_alias(
287       {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
288   ASSERT_TRUE(schema.may_contain_alias(
289       {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
290   ASSERT_TRUE(schema.may_contain_alias(
291       {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
292 }
293 
TEST(SchemaInfoMayContainAliasTest,ContainAliasInputsEqual)294 TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsEqual) {
295   SchemaInfo schema(
296       "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
297   ASSERT_FALSE(schema.may_contain_alias(
298       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
299   at::Tensor input = at::randn({3, 3});
300   schema.addArgumentValue("self", input);
301   schema.addArgumentValue("other", input);
302   ASSERT_TRUE(schema.may_contain_alias(
303       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
304   ASSERT_TRUE(schema.may_contain_alias(
305       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false));
306   ASSERT_TRUE(schema.may_contain_alias(
307       {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false));
308 }
309 
TEST(SchemaInfoMayContainAliasTest,ContainAliasInputsContained)310 TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsContained) {
311   SchemaInfo schema(
312       "aten::test.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor");
313   ASSERT_FALSE(schema.may_contain_alias(
314       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
315   at::Tensor input = at::randn({3, 3});
316   schema.addArgumentValue("self", c10::List<at::Tensor>({input}));
317   schema.addArgumentValue("other", input);
318   ASSERT_TRUE(schema.may_contain_alias(
319       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
320   ASSERT_TRUE(schema.may_contain_alias(
321       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false));
322   ASSERT_FALSE(schema.may_contain_alias(
323       {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false));
324 }
325 
TEST(SchemaInfoMayContainAliasTest,ContainAliasOutputs)326 TEST(SchemaInfoMayContainAliasTest, ContainAliasOutputs) {
327   SchemaInfo schema(
328       "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
329   ASSERT_FALSE(schema.may_contain_alias(
330       {SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
331   at::Tensor input = at::randn({3, 3});
332   schema.addArgumentValue("min", input);
333   schema.addArgumentValue("max", input);
334   ASSERT_TRUE(schema.may_contain_alias(
335       {SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
336 }
337 
TEST(SchemaInfoMayContainAliasTest,ContainAliasInputOutput)338 TEST(SchemaInfoMayContainAliasTest, ContainAliasInputOutput) {
339   SchemaInfo schema(
340       "aten::test.tensor(Tensor(a) self, Tensor[] other) -> Tensor(a)");
341   ASSERT_FALSE(schema.may_contain_alias(
342       {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
343   at::Tensor input = at::randn({3, 3});
344   schema.addArgumentValue("other", c10::List<at::Tensor>({input}));
345   schema.addArgumentValue("self", input);
346   ASSERT_TRUE(schema.may_contain_alias(
347       {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
348   ASSERT_FALSE(schema.may_contain_alias(
349       {SchemaArgType::output, 0}, {SchemaArgType::input, 1}, false));
350   ASSERT_TRUE(schema.may_contain_alias(
351       {SchemaArgType::input, 1}, {SchemaArgType::output, 0}, false));
352 }
353 
TEST(SchemaInfoMayContainAliasTest,InputAndOutputContainers)354 TEST(SchemaInfoMayContainAliasTest, InputAndOutputContainers) {
355   SchemaInfo schema(
356       "aten::test.tensor(Tensor self, Tensor[] other) -> Tensor[]");
357   ASSERT_TRUE(schema.may_contain_alias(
358       {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
359   ASSERT_FALSE(schema.may_contain_alias(
360       {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
361   ASSERT_FALSE(schema.may_contain_alias(
362       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
363   at::Tensor input = at::randn({3, 3});
364   schema.addArgumentValue("other", c10::List<at::Tensor>({input}));
365   schema.addArgumentValue("self", input);
366   ASSERT_TRUE(schema.may_contain_alias(
367       {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
368   ASSERT_TRUE(schema.may_contain_alias(
369       {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
370   ASSERT_TRUE(schema.may_contain_alias(
371       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
372 }
373 
TEST(SchemaInfoMayContainAliasTest,Wildcard)374 TEST(SchemaInfoMayContainAliasTest, Wildcard) {
375   SchemaInfo schema(
376       "aten::test.tensor(Tensor a, Tensor[] b, Tensor(*) c) -> Tensor[]");
377   ASSERT_FALSE(schema.may_contain_alias(
378       {SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
379   ASSERT_FALSE(schema.may_contain_alias(
380       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
381   ASSERT_TRUE(schema.may_contain_alias(
382       {SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
383   at::Tensor input = at::randn({3, 3});
384   schema.addArgumentValue("b", c10::List<at::Tensor>({input}));
385   schema.addArgumentValue("a", input);
386   ASSERT_TRUE(schema.may_contain_alias(
387       {SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
388   ASSERT_TRUE(schema.may_contain_alias(
389       {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
390   ASSERT_TRUE(schema.may_contain_alias(
391       {SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
392 }
393 } // namespace utils
394 } // namespace torch
395