#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/cuda.h>
#include <unordered_map>

namespace torch {
namespace jit {

namespace {

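// Returns the first node of kind `k` in depth-first order; asserts if no such
// node exists in the graph.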
Node* findNode(std::shared_ptr<Graph>& g, Symbol k) {
  DepthFirstGraphNodeIterator graph_it(g);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() == k) {
      return node;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Couldn't find node");
}
} // namespace

TEST(ShapeAnalysisTest, DynamicShapesFusion) {
  // Test generalizing shapes to symbolic dimensions, guarding those symbolic
  // dimensions, and passing in runtime-computed symbolic dimensions via
  // inlined shape functions.
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor, %y.1 : Tensor, %z: Tensor):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x.1)
        %out1.1 : Tensor = aten::erf(%3)
        %out2.1 : Tensor = aten::relu(%y.1)
        %10 : Tensor[] = prim::ListConstruct(%out1.1, %out2.1)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%28))IR";
  torch::jit::parseIR(graph_string, subgraph.get());

  /*
    set up fused TensorExprGroup
  */

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto y_inp = g->addInput("y_inp");
  auto z_inp = g->addInput("z_inp");
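  // Give each input a concrete example shape ((10, 5), (4, 5), (1, 1));
  // GenerateGuard below is expected to generalize these to symbolic
  // dimensions, which the assertions further down verify.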
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  x_inp->setType(x_type);
  y_inp->setType(y_type);
  z_inp->setType(z_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->inputs().at(1)->setType(y_type);
  subgraph->inputs().at(2)->setType(z_type);
  subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
  output->node()->addInput(x_inp);
  output->node()->addInput(y_inp);
  output->node()->addInput(z_inp);
  output->node()->g_(attr::Subgraph, subgraph);

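  // GenerateGuard rewrites the graph so the fusion group is only reached
  // through a prim::TensorExprDynamicGuard check, with a fallback path taken
  // when the guard fails (see the expected graph sketched below).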
  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check("aten::add")
      ->check("TensorExprGroup")
      ->check_same("symbolic_shape_inputs")
      ->check("block1")
      ->check("aten::cat")
      ->run(*g);

  // clang-format off
  /* Graph Should Look Something like: (note: strides not yet handled)
  graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
        %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
        %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
    %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
    %5 : Tensor = prim::If(%4)
      block0():
        %15 : int[] = aten::size(%x_inp)
        %16 : int[] = aten::size(%y_inp)
        %17 : int = prim::Constant[value=1]()
        %18 : int = prim::Constant[value=0]()
        %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
        %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
        %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
        %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
        %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
        -> (%3)
      block1():
        // FallbackGraph is inlined
        %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
        -> (%14)
    return ()
  with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
        %SS_5 : int,
        %SS_4 : int,
        %SS_3 : int,
        %SS_2 : int):
    %3 : int = prim::Constant[value=0]()
    %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
    %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
    %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
    %7 : Tensor[] = prim::ListConstruct(%5, %6)
    %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
    %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
    %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
    return (%9)
  */
  // clang-format on

  DepthFirstGraphNodeIterator graph_it(g);
  Node* te_group = findNode(g, prim::TensorExprGroup);

  /*
    Test that the inputs to the kernel - (10, 5), (4, 5), (1, 1) - are
    correctly generalized to symbolic dimensions, and that the output -
    (10 + 4, 5) - correctly preserves the non-catted dim as an existing
    symbolic shape and the catted dim as a new symbolic shape.
  */

  auto tensorexpr_graph = te_group->g(attr::Subgraph);
  auto inp1 = tensorexpr_graph->inputs().at(0)->type()->expect<TensorType>();
  auto inp2 = tensorexpr_graph->inputs().at(1)->type()->expect<TensorType>();
  auto inp3 = tensorexpr_graph->inputs().at(2)->type()->expect<TensorType>();
  auto out = tensorexpr_graph->outputs().at(0)->type()->expect<TensorType>();

  // Size-1 dims are preserved as concrete sizes
  auto inp3_sizes = inp3->sizes().concrete_sizes();
  TORCH_INTERNAL_ASSERT(inp3_sizes);
  TORCH_INTERNAL_ASSERT(
      inp3_sizes->size() == 2 && inp3_sizes->at(0) == 1 &&
      inp3_sizes->at(1) == 1);

  // The shared size 5 is generalized to a single symbolic shape
  ASSERT_EQ(
      inp1->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());
  ASSERT_EQ(
      out->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());

  // 10, 4, and 14 become distinct symbolic shapes
  ASSERT_NE(
      inp1->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp1->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());

  /*
    Test that the guard behaves correctly at runtime and that the symbolic
    shapes are computed correctly. Since there is no TE kernel support for
    dynamic shapes, we return all of the computed runtime symbolic dimensions
    as outputs of the graph on guard success, and return None on guard
    failure.
  */

  // Setting up guard to return sym shapes on guard success and None on failure
  Node* if_node = findNode(g, prim::If);
  IfView if_v(if_node);
  if_node->eraseOutput(0);
  if_v.thenBlock()->eraseOutput(0);
  if_v.elseBlock()->eraseOutput(0);
  WithInsertPoint guard(if_node);
  auto none_val = g->insertConstant(IValue());

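  // The symbolic_shape_inputs attribute lists the symbolic dimension ids used
  // by the fusion group; their runtime values are passed as the trailing
  // integer inputs after the tensor inputs, which is what `offset` skips over.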
  auto sym_shapes = te_group->is(Symbol::attr("symbolic_shape_inputs"));
  auto offset = te_group->inputs().size() - sym_shapes.size();
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    if_v.thenBlock()->insertOutput(i, te_group->inputs().at(offset + i));
    if_v.elseBlock()->insertOutput(i, none_val);
    if_node->insertOutput(i)->setType(OptionalType::create(IntType::get()));
  }

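  // Return the computed symbolic dimensions through a tuple output, then drop
  // the fusion group and fallback graph so that only the guard and the inlined
  // shape computation remain to be interpreted.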
  auto new_outputs = g->createTuple(if_node->outputs())->insertAfter(if_node);

  g->registerOutput(new_outputs->output());
  te_group->destroy();
  findNode(g, prim::FallbackGraph)->destroy();

  // Testing bad inputs

  auto first_inp = at::rand({2, 5});
  std::vector<std::vector<at::Tensor>> second_inps = {
      {at::rand({3, 4}), at::rand({1, 1})}, // sym shape mismatch
      {at::rand({5, 2}).transpose(0, 1), at::rand({1, 1})}, // discontiguous
      {at::zeros({2, 5}).to(at::ScalarType::Int),
       at::rand({1, 1})}, // wrong dtype
      {at::rand({2, 5, 1}), at::rand({1, 1})}, // wrong # dims
      {at::rand({2, 5}).requires_grad_(true),
       at::rand({1, 1})}, // requires grad
      {at::rand({2, 5}), at::rand({1, 12})}, // concrete dim mismatch (1)
  };
  if (torch::cuda::is_available()) {
    second_inps.push_back({at::rand({2, 5}).cuda(), at::rand({1, 1})});
  }
  for (const auto& last_inps : second_inps) {
    // TODO: reusing the interpreter across iterations gave an error
    Code code(g, "");
    InterpreterState interp(code);
    auto stack = createStack({at::rand({2, 5}), last_inps[0], last_inps[1]});
    interp.run(stack);
    TORCH_INTERNAL_ASSERT(pop(stack).toTuple()->elements().at(0).isNone());
  }

  // Test good inputs
  Code code(g, "");
  InterpreterState interp(code);
  std::vector<at::Tensor> inps = {
      at::rand({2, 5}), at::rand({4, 5}), at::rand({1, 1})};
  Stack stack(inps.begin(), inps.end());
  interp.run(stack);
  auto tuple = pop(stack).toTuple();
  TORCH_INTERNAL_ASSERT(tuple->elements().at(0).isInt());

  // Testing that the sym shape calculation was correct
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    auto sym_shape = sym_shapes[i];
    auto computed_value = tuple->elements().at(i).toInt();
    if (sym_shape == inp1->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 2);
    } else if (sym_shape == inp1->symbolic_sizes().at(1).value()) {
      ASSERT_EQ(computed_value, 5);
    } else if (sym_shape == inp2->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 4);
    } else if (sym_shape == out->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 6);
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
}

TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor):
        %none : NoneType = prim::Constant()
        %size1 : int = prim::Constant[value=1]()
        %size10 : int = prim::Constant[value=10]()
        %sizes : int[] = prim::ListConstruct(%size10, %size1)
        %device : Device = prim::Constant[value="cpu"]()
        %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
        %3 : Tensor = aten::tanh(%x.1)
        %29 : Tensor = aten::mul(%3, %10)
        return (%29))IR";
  torch::jit::parseIR(graph_string, subgraph.get());
  ConstantPropagation(subgraph);

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto x_type = TensorType::create(at::rand({10, 5}));
  x_inp->setType(x_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->outputs().at(0)->setType(x_type);
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  output->node()->addInput(x_inp);
  output->node()->g_(attr::Subgraph, subgraph);

  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);

  // Check that the constants have been moved out of the fused graph.
  // This should result in there being no conditionals other than the one
  // checking the result of TensorExprDynamicGuard.
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check_not("prim::If") // no other IFs due to constants.
      ->check("TensorExprGroup")
      ->check("block1")
      ->check("FallbackGraph")
      ->run(*g);
}

namespace {

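// Convenience marker for a symbolic (unknown) dimension when constructing
// SymbolicShape values in the tests below.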
std::optional<int64_t> sym_dim = std::nullopt;

// NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
void assertShapeEqual(c10::SymbolicShape& a, c10::SymbolicShape& e) {
  auto a_canonical = CanonicalizedSymbolicShape(a);
  auto e_canonical = CanonicalizedSymbolicShape(e);
  EXPECT_EQ(a_canonical, e_canonical);
}

void assertShapeEqual(
    std::optional<std::vector<c10::SymbolicShape>>& actual,
    std::vector<std::optional<int64_t>> expected) {
  ASSERT_TRUE(actual.has_value());
  ASSERT_EQ(actual->size(), 1);

  auto symb_expected = c10::SymbolicShape(expected);
  assertShapeEqual(actual->at(0), symb_expected);
}

const FunctionSchema* getSchema(const char* name) {
  return &(getOperatorForLiteral(name)->schema());
}
} // namespace

TEST(ShapeAnalysisTest, SymbolicShapeAPI) {
  // Fetch the function schema for an existing operator from its literal
  // signature.
  auto schema = getSchema(
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{1, 56, 56};

  // Check vector initializer list syntax
  c10::SymbolicShape ss_concrete =
      std::vector<std::optional<int64_t>>{1, 56, 56};
  c10::SymbolicShape ss1 = std::vector<std::optional<int64_t>>{sym_dim, 56, 56};
  c10::SymbolicShape ss2 =
      std::vector<std::optional<int64_t>>{64, sym_dim, sym_dim};
  c10::SymbolicShape ss3 =
      std::vector<std::optional<int64_t>>{sym_dim, sym_dim, sym_dim, sym_dim};

  auto res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, const_size_1});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, const_size_2});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_1, ss1});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{const_size_2, ss1});
  assertShapeEqual(res, {sym_dim, 56, 56});

  res = calculateSymbolicShapesOnOp(
      schema, std::vector<SSAInput>{ss_concrete, ss2});
  assertShapeEqual(res, {64, 56, 56});

  res = calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{ss2, ss3});
  assertShapeEqual(res, {sym_dim, 64, sym_dim, sym_dim});
}

TEST(ShapeAnalysisTest, BoundedSymbolicShapes) {
  auto schema = getSchema("aten::nonzero(Tensor self) -> (Tensor)");

  // Test that we generate symbolic shapes for the output of a nonzero op
  c10::IValue const_size_1 = std::vector<int64_t>{5, 10};
  auto res =
      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_1});
  assertShapeEqual(res, {sym_dim, 2});

  // Test that nonzero can also create concrete shapes
  c10::IValue const_size_2 = std::vector<int64_t>({1, 0});
  res =
      calculateSymbolicShapesOnOp(schema, std::vector<SSAInput>{const_size_2});
  assertShapeEqual(res, {0, 2});
}

TEST(ShapeAnalysisTest, SymbolicShapeCaching) {
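  // Start from an empty shape cache so the cache-size expectations below are
  // deterministic.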
  clear_shape_cache();
  auto schema = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_size_1 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_2 = std::vector<int64_t>{64, 56};
  c10::IValue const_size_3 = std::vector<int64_t>{64, 20};

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss3 = c10::SymbolicShape({sym_dim, sym_dim});

  auto res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  assertShapeEqual(res, {sym_dim, 56});
  auto res1_val = res->at(0);

  // The exact same arguments should return the exact same result
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_1});
  auto res2_val = res->at(0);
  EXPECT_EQ(res1_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // The same shape but different symbols should return the same shape
  // but different symbolic indices
  res = calculateSymbolicShapesOnOp(schema, {ss2, const_size_2});
  auto res3_val = res->at(0);

  assertShapeEqual(res3_val, res2_val);
  EXPECT_NE(res3_val, res2_val);
  EXPECT_EQ(get_shape_cache_size(), 1);

  // A different concrete shape should be cached separately
  res = calculateSymbolicShapesOnOp(schema, {ss1, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(schema, {ss3, const_size_3});
  assertShapeEqual(res, {sym_dim, 20});
  EXPECT_EQ(get_shape_cache_size(), 3);

  res = calculateSymbolicShapesOnOp(schema, {ss3, ss3});
  assertShapeEqual(res, {sym_dim, sym_dim});
  EXPECT_EQ(get_shape_cache_size(), 4);
}

TEST(ShapeAnalysisTest, ShapeCacheMultipleFns) {
  clear_shape_cache();

  auto squeeze_op =
      getSchema("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)");
  auto mul_tensor =
      getSchema("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
  auto mul_scalar =
      getSchema("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor");
  auto div_tensor =
      getSchema("aten::div.Tensor(Tensor self, Tensor other) -> Tensor");
  auto matmul = getSchema("aten::mm(Tensor self, Tensor mat2) -> Tensor");

  c10::IValue const_int = 1;

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});

  auto res = calculateSymbolicShapesOnOp(squeeze_op, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});

  // Show that the cache can handle multiple functions
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 2);

  res = calculateSymbolicShapesOnOp(mul_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 3);

  // Even when the expected outcome is the same, entries should not collide
  res = calculateSymbolicShapesOnOp(div_tensor, {ss1, ss1});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  // Don't lose cached objects
  res = calculateSymbolicShapesOnOp(mul_scalar, {ss1, const_int});
  assertShapeEqual(res, {sym_dim, 64});
  EXPECT_EQ(get_shape_cache_size(), 4);

  res = calculateSymbolicShapesOnOp(matmul, {ss1, ss1});
  // Symbolic shape analysis can infer that sym_dim is 64, as both tensors
  // use the same sym_dim
  assertShapeEqual(res, {64, 64});
  EXPECT_EQ(get_shape_cache_size(), 5);
}

TEST(ShapeAnalysisTest, TestShapeMultipleReturns) {
  clear_shape_cache();

  auto max_dim_op = getSchema(
      "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)");
  c10::IValue const_int = 1;
  c10::IValue false_ival = false;

  c10::SymbolicShape ss1 = c10::SymbolicShape({sym_dim, 64});
  c10::SymbolicShape ss2 = c10::SymbolicShape({sym_dim, 64});

  auto res =
      calculateSymbolicShapesOnOp(max_dim_op, {ss1, const_int, false_ival});
  c10::SymbolicShape expected_res =
      c10::SymbolicShape(std::vector<std::optional<int64_t>>{sym_dim});
  assertShapeEqual(res->at(0), expected_res);
  // res0 and res1 should share the same symbolic dimension
  EXPECT_EQ(res->at(0), res->at(1));

  // Also test that the shape cache returns consistent result shapes
  res = calculateSymbolicShapesOnOp(max_dim_op, {ss2, const_int, false_ival});
  assertShapeEqual(res->at(0), expected_res);
  EXPECT_EQ(res->at(0), res->at(1));
  EXPECT_EQ(get_shape_cache_size(), 1);
}
} // namespace jit
} // namespace torch