#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/cuda.h>
#include <unordered_map>

namespace torch {
namespace jit {

namespace {

Node* findNode(std::shared_ptr<Graph>& g, Symbol k) {
  DepthFirstGraphNodeIterator graph_it(g);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() == k) {
      return node;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Couldn't find node");
}
} // namespace

TEST(ShapeAnalysisTest, DynamicShapesFusion) {
  // Test generalizing shapes to symbolic dimensions, guarding those symbolic
  // dimensions, and passing in runtime-computed symbolic dimensions via
  // inlined shape functions.
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor, %y.1 : Tensor, %z: Tensor):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x.1)
        %out1.1 : Tensor = aten::erf(%3)
        %out2.1 : Tensor = aten::relu(%y.1)
        %10 : Tensor[] = prim::ListConstruct(%out1.1, %out2.1)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%28))IR";
  torch::jit::parseIR(graph_string, subgraph.get());

  /*
  set up fused TensorExprGroup
  */

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto y_inp = g->addInput("y_inp");
  auto z_inp = g->addInput("z_inp");
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  x_inp->setType(x_type);
  y_inp->setType(y_type);
  z_inp->setType(z_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->inputs().at(1)->setType(y_type);
  subgraph->inputs().at(2)->setType(z_type);
  subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
  output->node()->addInput(x_inp);
  output->node()->addInput(y_inp);
  output->node()->addInput(z_inp);
  output->node()->g_(attr::Subgraph, subgraph);

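  // GenerateGuard is expected to generalize the subgraph's tensor types to
  // symbolic dimensions and wrap the TensorExprGroup in a
  // prim::TensorExprDynamicGuard + prim::If, with a fallback graph on the
  // false branch (see the expected graph below).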
  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check("aten::add")
      ->check("TensorExprGroup")
      ->check_same("symbolic_shape_inputs")
      ->check("block1")
      ->check("aten::cat")
      ->run(*g);

  // clang-format off
  /* The graph should look something like this (note: strides not yet handled):
  graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
  %5 : Tensor = prim::If(%4)
    block0():
      %15 : int[] = aten::size(%x_inp)
      %16 : int[] = aten::size(%y_inp)
      %17 : int = prim::Constant[value=1]()
      %18 : int = prim::Constant[value=0]()
      %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
      %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
      %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
      %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
      %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
      -> (%3)
    block1():
      // FallbackGraph is inlined
      %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
      -> (%14)
  return ()
  with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
        %SS_5 : int,
        %SS_4 : int,
        %SS_3 : int,
        %SS_2 : int):
    %3 : int = prim::Constant[value=0]()
    %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
    %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
    %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
    %7 : Tensor[] = prim::ListConstruct(%5, %6)
    %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
    %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
    %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
    return (%9)
  */
  // clang-format on

  DepthFirstGraphNodeIterator graph_it(g);
  Node* te_group = findNode(g, prim::TensorExprGroup);

  /*
  Test that the inputs to the kernel - (10, 5), (4, 5), (1, 1) - are correctly
  generalized to symbolic dimensions, and that the output - (10 + 4, 5) -
  correctly preserves the non-catted dim as an existing symbolic shape and the
  catted dim as a new symbolic shape.
  */

  auto tensorexpr_graph = te_group->g(attr::Subgraph);
  auto inp1 = tensorexpr_graph->inputs().at(0)->type()->expect<TensorType>();
  auto inp2 = tensorexpr_graph->inputs().at(1)->type()->expect<TensorType>();
  auto inp3 = tensorexpr_graph->inputs().at(2)->type()->expect<TensorType>();
  auto out = tensorexpr_graph->outputs().at(0)->type()->expect<TensorType>();

  // 1 dims are preserved
  auto inp3_sizes = inp3->sizes().concrete_sizes();
  TORCH_INTERNAL_ASSERT(inp3_sizes);
  TORCH_INTERNAL_ASSERT(
      inp3_sizes->size() == 2 && inp3_sizes->at(0) == 1 &&
      inp3_sizes->at(1) == 1);

  // 5 made into sym shape
  ASSERT_EQ(
      inp1->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());
  ASSERT_EQ(
      out->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());

  // 4, 10, 14 are different sym shapes
  ASSERT_NE(
      inp1->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp1->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());

  /*
    Test that the guard behaves correctly at runtime and that symbolic shapes
    are computed correctly. As we don't have TE kernel support for dynamic
    shapes, we return all of the computed runtime symbolic dimensions as
    outputs of the graph on guard success, and return None on guard failure.
  */

  // Setting up guard to return sym shapes on guard success and None on failure
  Node* if_node = findNode(g, prim::If);
  IfView if_v(if_node);
  if_node->eraseOutput(0);
  if_v.thenBlock()->eraseOutput(0);
  if_v.elseBlock()->eraseOutput(0);
  WithInsertPoint guard(if_node);
  auto none_val = g->insertConstant(IValue());

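  // The runtime-computed symbolic dimension values are passed as the trailing
  // inputs of the TensorExprGroup, so `offset` locates the first of them.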
  auto sym_shapes = te_group->is(Symbol::attr("symbolic_shape_inputs"));
  auto offset = te_group->inputs().size() - sym_shapes.size();
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    if_v.thenBlock()->insertOutput(i, te_group->inputs().at(offset + i));
    if_v.elseBlock()->insertOutput(i, none_val);
    if_node->insertOutput(i)->setType(OptionalType::create(IntType::get()));
  }

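  // Replace the fused group and fallback with a tuple of the If outputs so the
  // graph simply returns the computed symbolic dimensions (or Nones).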
  auto new_outputs = g->createTuple(if_node->outputs())->insertAfter(if_node);

  g->registerOutput(new_outputs->output());
  te_group->destroy();
  findNode(g, prim::FallbackGraph)->destroy();

  // Testing bad inputs

  auto first_inp = at::rand({2, 5});
  std::vector<std::vector<at::Tensor>> second_inps = {
      {at::rand({3, 4}), at::rand({1, 1})}, // sym shape mismatch
      {at::rand({5, 2}).transpose(0, 1), at::rand({1, 1})}, // discontiguous
      {at::zeros({2, 5}).to(at::ScalarType::Int),
       at::rand({1, 1})}, // wrong dtype
      {at::rand({2, 5, 1}), at::rand({1, 1})}, // wrong # dims
      {at::rand({2, 5}).requires_grad_(true),
       at::rand({1, 1})}, // requires grad
      {at::rand({2, 5}), at::rand({1, 12})}, // concrete dim mismatch (1)
  };
  if (torch::cuda::is_available()) {
    second_inps.push_back({at::rand({2, 5}).cuda(), at::rand({1, 1})});
  }
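  // Each bad input should fail the guard and take the fallback block, so every
  // symbolic dimension output should be None.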
  for (const auto& last_inps : second_inps) {
    // TODO: reusing the interpreter across iterations gave an error
    Code code(g, "");
    InterpreterState interp(code);
    auto stack = createStack({first_inp, last_inps[0], last_inps[1]});
    interp.run(stack);
    TORCH_INTERNAL_ASSERT(pop(stack).toTuple()->elements().at(0).isNone());
  }

  // Test good inputs
  Code code(g, "");
  InterpreterState interp(code);
  std::vector<at::Tensor> inps = {
      at::rand({2, 5}), at::rand({4, 5}), at::rand({1, 1})};
  Stack stack(inps.begin(), inps.end());
  interp.run(stack);
  auto tuple = pop(stack).toTuple();
  TORCH_INTERNAL_ASSERT(tuple->elements().at(0).isInt());

  // Testing that the sym shape calculation was correct
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    auto sym_shape = sym_shapes[i];
    auto computed_value = tuple->elements().at(i).toInt();
    if (sym_shape == inp1->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 2);
    } else if (sym_shape == inp1->symbolic_sizes().at(1).value()) {
      ASSERT_EQ(computed_value, 5);
    } else if (sym_shape == inp2->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 4);
    } else if (sym_shape == out->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 6);
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
}

TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor):
        %none : NoneType = prim::Constant()
        %size1 : int = prim::Constant[value=1]()
        %size10 : int = prim::Constant[value=10]()
        %sizes : int[] = prim::ListConstruct(%size10, %size1)
        %device : Device = prim::Constant[value="cpu"]()
        %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
        %3 : Tensor = aten::tanh(%x.1)
        %29 : Tensor = aten::mul(%3, %10)
        return (%29))IR";
  torch::jit::parseIR(graph_string, subgraph.get());
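  // Constant-fold aten::ones (all of its arguments are constants) so the
  // resulting constant tensor can be moved out of the fusion group below.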
  ConstantPropagation(subgraph);

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto x_type = TensorType::create(at::rand({10, 5}));
  x_inp->setType(x_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->outputs().at(0)->setType(x_type);
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  output->node()->addInput(x_inp);
  output->node()->g_(attr::Subgraph, subgraph);

  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);

  // Check that the constants have been moved out of the fused graph.
  // This should result in there not being any conditionals other than the one
  // checking the result of TensorExprDynamicGuard.
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check_not("prim::If") // no other IFs due to constants.
      ->check("TensorExprGroup")
      ->check("block1")
      ->check("FallbackGraph")
      ->run(*g);
}

namespace {

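// Stands in for an unknown (symbolic) dimension when building SymbolicShapes.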
std::optional<int64_t> sym_dim = std::nullopt;

// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
  auto a_canonical = CanonicalizedSymbolicShape(a);
  auto e_canonical = CanonicalizedSymbolicShape(e);
  EXPECT_EQ(a_canonical, e_canonical);
}

void assertShapeEqual(
    std::optional<std::vector<c10::SymbolicShape>>& actual,
    std::vector<std::optional<int64_t>> expected) {
  ASSERT_TRUE(actual.has_value());
  ASSERT_EQ(actual->size(), 1);

  auto symb_expected = c10::SymbolicShape(expected);
  assertShapeEqual(actual->at(0), symb_expected);
}

const FunctionSchema* getSchema(const char* name) {
  return &(getOperatorForLiteral(name)->schema());
}
} // namespace

TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
  // Fetch the function schema for an existing operator (aten::sub.Tensor)
  // from the operator registry.
  auto schema = getSchema(
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};

  // Check vector initializer list syntax
  c10::SymbolicShape ss_concrete =
      std::vector<std::optional<int64_t>>{1, 56, 56};
  c10::SymbolicShape ss1 = std::vector<std::optional<int64_t>>{sym_dim, 56, 56};
  c10::SymbolicShape ss2 =
      std::vector<std::optional<int64_t>>{64, sym_dim, sym_dim};
  c10::SymbolicShape ss3 =
      std::vector<std::optional<int64_t>>{sym_dim, sym_dim, sym_dim, sym_dim};

  auto res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, const_size_1});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, const_size_2});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, ss1});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_2, ss1});
  assertShapeEqual(res, {sym_dim, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{ss_concrete, ss2});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{ss2, ss3});
  assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}

TEST(ShapeAnalysisTest, BoundedSymbolicShapes) {
  auto schema = getSchema("aten::nonzero(Tensor self) -> (Tensor)");

  // Test that we generate symbolic shapes for the output of a nonzero op
  c10::IValue const_size_1 = std::vector<int64_t>{5, 10};
  auto res =
      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_1});
  assertShapeEqual(res, {sym_dim, 2});

  // Test that nonzero can also create concrete shapes
  c10::IValue const_size_2 = std::vector<int64_t>({1, 0});
  res =
      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_2});
  assertShapeEqual(res, {0, 2});
}

TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
  clear_shape_cache();
  auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_3 = std::vector<int64_t>{64, 20};

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});

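  // Queries whose shapes differ only in which symbolic symbols they use should
  // share a single cache entry; see the get_shape_cache_size() checks below.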
  auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  assertShapeEqual(res, {sym_dim, 56});
  auto res1_val = res->at(0);

  // The exact same arguments should return the exact same result
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  auto res2_val = res->at(0);
  EXPECT_EQ(res1_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // Same shape but different symbols should return same shape
  // but different symbolic indices
  res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2});
  auto res3_val = res->at(0);

  assertShapeEqual(res3_val, res2_val);
  EXPECT_NE(res3_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // Different concrete shape should be cached separately
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 3);

  res = calculateSymbolicShapesOnOp(schema, {ss3, ss3});
  assertShapeEqual(res, {sym_dim, sym_dim});
  EXPECT_EQ(get_shape_cache_size(), 4);
}

TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) {
  clear_shape_cache();

  auto squeeze_op =
      getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)");
  auto mul_tensor =
      getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
  auto mul_scalar =
      getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor");
  auto div_tensor =
      getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
  auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_int = 1;

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});

  auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});

  // Show that cache can handle multiple functions
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 3);

  // Even when the expected outcome is the same, different ops should not
  // collide in the cache
  res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  // Don't lose cached objects
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1});
  // SSA can infer that sym_dim is 64 as both tensors
  // use the same sym_dim
  assertShapeEqual(res, {64, 64});
  EXPECT_EQ(get_shape_cache_size(), 5);
}

TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {
  clear_shape_cache();

  auto max_dim_op = getSchema(
      "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
  c10::IValue const_int = 1;
  c10::IValue false_ival = false;

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});

  auto res =
      calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
  c10::SymbolicShape expected_res =
      c10::SymbolicShape(std::vector<std::optional<int64_t>>{sym_dim});
  assertShapeEqual(res->at(0), expected_res);
  // res0 and res1 should share the same symbolic symbol
  EXPECT_EQ(res->at(0), res->at(1));

  // Also test that the shape cache returns consistent result shapes
  res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival});
  assertShapeEqual(res->at(0), expected_res);
  EXPECT_EQ(res->at(0), res->at(1));
  EXPECT_EQ(get_shape_cache_size(), 1);
}
} // namespace jit
} // namespace torch