1 #include <gtest/gtest.h>
2 #include <torch/csrc/autograd/generated/variable_factories.h>
3 #include <torch/csrc/utils/schema_info.h>
4
5 namespace torch {
6 namespace utils {
7 using c10::SchemaArgType;
8
TEST(FunctionSchemaIsAliasingTest,Basic)9 TEST(FunctionSchemaIsAliasingTest, Basic) {
10 c10::FunctionSchema schema = torch::jit::parseSchema(
11 "aten::test.Tensor(Tensor(a) self, Tensor(b!) other, Tensor more_other) -> (Tensor(a), Tensor(b!))");
12 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 0}));
13 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 1}));
14 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 0}));
15 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 1}));
16 ASSERT_FALSE(schema.is_aliasing({SchemaArgType::input, 2}));
17 }
18
TEST(FunctionSchemaIsAliasingTest,InvalidArgument)19 TEST(FunctionSchemaIsAliasingTest, InvalidArgument) {
20 c10::FunctionSchema schema = torch::jit::parseSchema(
21 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
22 ASSERT_THROW(schema.is_aliasing({SchemaArgType::input, 4}), c10::Error);
23 ASSERT_THROW(schema.is_aliasing({SchemaArgType::output, 4}), c10::Error);
24 }
25
TEST(FunctionSchemaIsMutableTest,Basic)26 TEST(FunctionSchemaIsMutableTest, Basic) {
27 c10::FunctionSchema schema = torch::jit::parseSchema(
28 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
29 ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0}));
30 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
31 ASSERT_TRUE(schema.is_mutable("self"));
32 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
33 ASSERT_FALSE(schema.is_mutable("other"));
34 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2}));
35 ASSERT_FALSE(schema.is_mutable("alpha"));
36 }
37
TEST(FunctionSchemaIsMutableTest,InvalidArgument)38 TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
39 c10::FunctionSchema schema = torch::jit::parseSchema(
40 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
41 ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error);
42 ASSERT_THROW(schema.is_mutable({SchemaArgType::output, 4}), c10::Error);
43 ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
44 }
45
TEST(SchemaInfoIsMutableTest,Basic)46 TEST(SchemaInfoIsMutableTest, Basic) {
47 SchemaInfo schema(
48 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
49 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
50 ASSERT_TRUE(schema.is_mutable("self"));
51 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
52 ASSERT_FALSE(schema.is_mutable("other"));
53 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2}));
54 ASSERT_FALSE(schema.is_mutable("alpha"));
55 }
56
TEST(SchemaInfoIsMutableTest,InvalidArgument)57 TEST(SchemaInfoIsMutableTest, InvalidArgument) {
58 SchemaInfo schema(
59 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
60 ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error);
61 ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
62 }
63
TEST(SchemaInfoIsMutableTest,AliasingInputs)64 TEST(SchemaInfoIsMutableTest, AliasingInputs) {
65 SchemaInfo schema(
66 "aten::test.Tensor(Tensor(a!) self, Tensor(b) other, *, Scalar alpha=1) -> (Tensor(a!), Tensor(b))");
67 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
68 ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0}));
69 ASSERT_TRUE(schema.is_mutable("self"));
70 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
71 ASSERT_FALSE(schema.is_mutable({SchemaArgType::output, 1}));
72 ASSERT_FALSE(schema.is_mutable("other"));
73 at::Tensor input = at::randn({3, 3});
74 schema.addArgumentValue("self", input);
75 schema.addArgumentValue("other", input);
76 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 1}));
77 ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 1}));
78 ASSERT_TRUE(schema.is_mutable("other"));
79 }
80
TEST(SchemaInfoIsMutableTest,InstanceNorm)81 TEST(SchemaInfoIsMutableTest, InstanceNorm) {
82 SchemaInfo schema_info(
83 "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor");
84 ASSERT_TRUE(schema_info.is_mutable("running_mean"));
85 ASSERT_TRUE(schema_info.is_mutable("running_var"));
86 schema_info.addArgumentValue("use_input_stats", false);
87 ASSERT_FALSE(schema_info.is_mutable("running_mean"));
88 ASSERT_FALSE(schema_info.is_mutable("running_var"));
89 }
90
TEST(SchemaInfoIsMutableTest,BatchNorm)91 TEST(SchemaInfoIsMutableTest, BatchNorm) {
92 SchemaInfo schema_info(
93 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor");
94 ASSERT_TRUE(schema_info.is_mutable("running_mean"));
95 ASSERT_TRUE(schema_info.is_mutable("running_var"));
96 schema_info.addArgumentValue("training", false);
97 ASSERT_FALSE(schema_info.is_mutable("running_mean"));
98 ASSERT_FALSE(schema_info.is_mutable("running_var"));
99 }
100
TEST(SchemaInfoIsNonDeterministicTest,Basic)101 TEST(SchemaInfoIsNonDeterministicTest, Basic) {
102 SchemaInfo deterministic_schema_info(
103 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
104 SchemaInfo nondeterministic_schema_info(
105 "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor");
106 ASSERT_FALSE(deterministic_schema_info.is_nondeterministic());
107 ASSERT_TRUE(nondeterministic_schema_info.is_nondeterministic());
108 }
109
TEST(SchemaInfoIsNonDeterministicTest,Dropout)110 TEST(SchemaInfoIsNonDeterministicTest, Dropout) {
111 SchemaInfo droupout_schema_info(
112 "aten::dropout(Tensor input, float p, bool train) -> Tensor");
113 ASSERT_TRUE(droupout_schema_info.is_nondeterministic());
114 droupout_schema_info.addArgumentValue("train", false);
115 ASSERT_FALSE(droupout_schema_info.is_nondeterministic());
116 }
117
TEST(FunctionSchemaMayAliasTest,Basic)118 TEST(FunctionSchemaMayAliasTest, Basic) {
119 c10::FunctionSchema schema = torch::jit::parseSchema(
120 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
121 ASSERT_TRUE(
122 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
123 ASSERT_FALSE(
124 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
125 ASSERT_FALSE(
126 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
127 }
128
TEST(FunctionSchemaMayAliasTest,InvalidArgument)129 TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
130 c10::FunctionSchema schema = torch::jit::parseSchema(
131 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
132 ASSERT_THROW(
133 schema.may_alias({SchemaArgType::input, 15}, {SchemaArgType::output, 0}),
134 c10::Error);
135 ASSERT_THROW(
136 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 15}),
137 c10::Error);
138 }
139
TEST(FunctionSchemaMayAliasTest,Wildcard)140 TEST(FunctionSchemaMayAliasTest, Wildcard) {
141 c10::FunctionSchema schema = torch::jit::parseSchema(
142 "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)");
143 ASSERT_TRUE(
144 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
145 ASSERT_FALSE(
146 schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
147 }
148
TEST(SchemaInfoMayAliasTest,AliasingInputs)149 TEST(SchemaInfoMayAliasTest, AliasingInputs) {
150 SchemaInfo schema(
151 "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
152 ASSERT_FALSE(
153 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
154 at::Tensor input = at::randn({3, 3});
155 schema.addArgumentValue("self", input);
156 schema.addArgumentValue("other", input);
157 ASSERT_TRUE(
158 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
159 }
160
TEST(SchemaInfoMayAliasTest,AliasingOutputs)161 TEST(SchemaInfoMayAliasTest, AliasingOutputs) {
162 SchemaInfo schema(
163 "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
164 ASSERT_FALSE(
165 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
166 at::Tensor input = at::randn({3, 3});
167 schema.addArgumentValue("min", input);
168 schema.addArgumentValue("max", input);
169 ASSERT_TRUE(
170 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
171 }
172
TEST(SchemaInfoMayAliasTest,AliasingInputOutput)173 TEST(SchemaInfoMayAliasTest, AliasingInputOutput) {
174 SchemaInfo schema(
175 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
176 ASSERT_TRUE(
177 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
178 ASSERT_FALSE(
179 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
180 at::Tensor input = at::randn({3, 3});
181 schema.addArgumentValue("self", input);
182 schema.addArgumentValue("other", input);
183 ASSERT_TRUE(
184 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
185 ASSERT_TRUE(
186 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
187 }
188
TEST(SchemaInfoMayAliasTest,MultipleWildcardInputs)189 TEST(SchemaInfoMayAliasTest, MultipleWildcardInputs) {
190 SchemaInfo schema(
191 "aten::test.Tensor(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> (Tensor(a), Tensor(*))");
192 ASSERT_TRUE(
193 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
194 ASSERT_TRUE(
195 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
196 ASSERT_TRUE(
197 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
198 ASSERT_FALSE(
199 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
200 ASSERT_FALSE(
201 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
202 ASSERT_FALSE(
203 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
204 ASSERT_FALSE(
205 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
206 at::Tensor input = at::randn({3, 3});
207 schema.addArgumentValue("a", input);
208 schema.addArgumentValue("b", input);
209 ASSERT_TRUE(
210 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
211 ASSERT_TRUE(
212 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
213 ASSERT_TRUE(
214 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
215 ASSERT_TRUE(
216 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
217 ASSERT_TRUE(
218 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
219 ASSERT_TRUE(
220 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
221 ASSERT_TRUE(
222 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
223 }
224
TEST(SchemaInfoMayAliasTest,MultipleNonWildcardInputs)225 TEST(SchemaInfoMayAliasTest, MultipleNonWildcardInputs) {
226 SchemaInfo schema(
227 "aten::test.Tensor(Tensor(a) a, Tensor(a) b, Tensor(*) c, Tensor(b) d) -> (Tensor(a), Tensor(*))");
228 ASSERT_TRUE(
229 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
230 ASSERT_TRUE(
231 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
232 ASSERT_TRUE(
233 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
234 ASSERT_TRUE(
235 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 0}));
236 }
237
TEST(SchemaInfoMayAliasTest,MultipleNonWildcardOutputs)238 TEST(SchemaInfoMayAliasTest, MultipleNonWildcardOutputs) {
239 SchemaInfo schema(
240 "aten::test.Tensor(Tensor(a) a, Tensor(*) b) -> (Tensor(a), Tensor(a))");
241 ASSERT_TRUE(
242 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
243 ASSERT_TRUE(
244 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
245 ASSERT_TRUE(
246 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
247 }
248
TEST(SchemaInfoMayAliasTest,MismatchingTypes)249 TEST(SchemaInfoMayAliasTest, MismatchingTypes) {
250 SchemaInfo schema("aten::test.Tensor(Tensor(a) a) -> int(a)");
251 ASSERT_FALSE(
252 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
253 }
254
TEST(FunctionSchemaMayContainAliasTest,Basic)255 TEST(FunctionSchemaMayContainAliasTest, Basic) {
256 c10::FunctionSchema schema = torch::jit::parseSchema(
257 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
258 ASSERT_TRUE(schema.may_contain_alias(
259 {SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
260 ASSERT_FALSE(schema.may_contain_alias(
261 {SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
262 ASSERT_FALSE(schema.may_contain_alias(
263 {SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
264 }
265
TEST(FunctionSchemaMayContainAliasTest,Wildcard)266 TEST(FunctionSchemaMayContainAliasTest, Wildcard) {
267 c10::FunctionSchema schema = torch::jit::parseSchema(
268 "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)");
269 ASSERT_FALSE(
270 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
271 ASSERT_TRUE(schema.may_contain_alias(
272 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
273 ASSERT_TRUE(schema.may_contain_alias(
274 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
275 ASSERT_FALSE(schema.may_contain_alias(
276 {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
277 ASSERT_FALSE(
278 schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
279 }
280
TEST(FunctionSchemaMayContainAliasTest,InputAndOutputContainers)281 TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) {
282 c10::FunctionSchema schema =
283 torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]");
284 ASSERT_FALSE(
285 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
286 ASSERT_TRUE(schema.may_contain_alias(
287 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
288 ASSERT_TRUE(schema.may_contain_alias(
289 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
290 ASSERT_TRUE(schema.may_contain_alias(
291 {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
292 }
293
TEST(SchemaInfoMayContainAliasTest,ContainAliasInputsEqual)294 TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsEqual) {
295 SchemaInfo schema(
296 "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
297 ASSERT_FALSE(schema.may_contain_alias(
298 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
299 at::Tensor input = at::randn({3, 3});
300 schema.addArgumentValue("self", input);
301 schema.addArgumentValue("other", input);
302 ASSERT_TRUE(schema.may_contain_alias(
303 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
304 ASSERT_TRUE(schema.may_contain_alias(
305 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false));
306 ASSERT_TRUE(schema.may_contain_alias(
307 {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false));
308 }
309
TEST(SchemaInfoMayContainAliasTest,ContainAliasInputsContained)310 TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsContained) {
311 SchemaInfo schema(
312 "aten::test.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor");
313 ASSERT_FALSE(schema.may_contain_alias(
314 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
315 at::Tensor input = at::randn({3, 3});
316 schema.addArgumentValue("self", c10::List<at::Tensor>({input}));
317 schema.addArgumentValue("other", input);
318 ASSERT_TRUE(schema.may_contain_alias(
319 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
320 ASSERT_TRUE(schema.may_contain_alias(
321 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false));
322 ASSERT_FALSE(schema.may_contain_alias(
323 {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false));
324 }
325
TEST(SchemaInfoMayContainAliasTest,ContainAliasOutputs)326 TEST(SchemaInfoMayContainAliasTest, ContainAliasOutputs) {
327 SchemaInfo schema(
328 "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
329 ASSERT_FALSE(schema.may_contain_alias(
330 {SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
331 at::Tensor input = at::randn({3, 3});
332 schema.addArgumentValue("min", input);
333 schema.addArgumentValue("max", input);
334 ASSERT_TRUE(schema.may_contain_alias(
335 {SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
336 }
337
TEST(SchemaInfoMayContainAliasTest,ContainAliasInputOutput)338 TEST(SchemaInfoMayContainAliasTest, ContainAliasInputOutput) {
339 SchemaInfo schema(
340 "aten::test.tensor(Tensor(a) self, Tensor[] other) -> Tensor(a)");
341 ASSERT_FALSE(schema.may_contain_alias(
342 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
343 at::Tensor input = at::randn({3, 3});
344 schema.addArgumentValue("other", c10::List<at::Tensor>({input}));
345 schema.addArgumentValue("self", input);
346 ASSERT_TRUE(schema.may_contain_alias(
347 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
348 ASSERT_FALSE(schema.may_contain_alias(
349 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}, false));
350 ASSERT_TRUE(schema.may_contain_alias(
351 {SchemaArgType::input, 1}, {SchemaArgType::output, 0}, false));
352 }
353
TEST(SchemaInfoMayContainAliasTest,InputAndOutputContainers)354 TEST(SchemaInfoMayContainAliasTest, InputAndOutputContainers) {
355 SchemaInfo schema(
356 "aten::test.tensor(Tensor self, Tensor[] other) -> Tensor[]");
357 ASSERT_TRUE(schema.may_contain_alias(
358 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
359 ASSERT_FALSE(schema.may_contain_alias(
360 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
361 ASSERT_FALSE(schema.may_contain_alias(
362 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
363 at::Tensor input = at::randn({3, 3});
364 schema.addArgumentValue("other", c10::List<at::Tensor>({input}));
365 schema.addArgumentValue("self", input);
366 ASSERT_TRUE(schema.may_contain_alias(
367 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
368 ASSERT_TRUE(schema.may_contain_alias(
369 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
370 ASSERT_TRUE(schema.may_contain_alias(
371 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
372 }
373
TEST(SchemaInfoMayContainAliasTest,Wildcard)374 TEST(SchemaInfoMayContainAliasTest, Wildcard) {
375 SchemaInfo schema(
376 "aten::test.tensor(Tensor a, Tensor[] b, Tensor(*) c) -> Tensor[]");
377 ASSERT_FALSE(schema.may_contain_alias(
378 {SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
379 ASSERT_FALSE(schema.may_contain_alias(
380 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
381 ASSERT_TRUE(schema.may_contain_alias(
382 {SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
383 at::Tensor input = at::randn({3, 3});
384 schema.addArgumentValue("b", c10::List<at::Tensor>({input}));
385 schema.addArgumentValue("a", input);
386 ASSERT_TRUE(schema.may_contain_alias(
387 {SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
388 ASSERT_TRUE(schema.may_contain_alias(
389 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
390 ASSERT_TRUE(schema.may_contain_alias(
391 {SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
392 }
393 } // namespace utils
394 } // namespace torch
395