#include <gtest/gtest.h>

#include <ATen/Functions.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/concat_opt.h>
#include <torch/csrc/jit/passes/variadic_ops.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

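// Tests for the JIT concat optimization passes: EliminateConcatCommonInputs,
// ExpandConcatAndEliminateRedundancy, UseVariadicCat,
// RemoveListMutationAndUseVariadicCat, and CombineConcats. Each test parses a
// graph from IR text, runs it on random inputs before and after the pass,
// asserts that both runs produce exactly equal outputs, and uses FileCheck to
// verify the structure of the transformed graph.
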
TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Graph after EliminateConcatCommonInputs:
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor = prim::VarConcat(%0, %1, %3)
  //    %7 : Tensor = prim::VarConcat(%4, %2, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //    return (%8)

  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%4, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %2, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // Graph after EliminateConcatCommonInputs:
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor = prim::VarConcat(%1, %2, %3)
  //    %7 : Tensor = prim::VarConcat(%0, %4, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //    return (%8)

  testing::FileCheck()
      .check_count("= prim::VarConcat(%1, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%0, %4, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()

          #CHECK: prim::VarConcat
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)

          #CHECK: prim::VarConcat
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %0, %2, %5)

          #CHECK: prim::ListConstruct
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_FALSE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the inputs
  // to the `cat` are in a different order.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, MoreCommonInputsElimination) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %5)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %4, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2, %concat.3, %concat.4)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%6, %2, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%11, %3, %5)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%12, %4, %5)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, ExpandConcat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          %input : Tensor[] = prim::ListConstruct(%4, %5)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full concat optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    ...
  //    %4 : Tensor = aten::clamp_max(...)
  //    %5 : Tensor = aten::clamp_max(...)
  //    %13 : int[] = prim::ListConstruct(...)
  //    %14 : Tensor = aten::empty(%13, ...)    // concat buffer
  //    %17 : Tensor = aten::slice(%14, ...)    // slice for %4
  //    %18 : Tensor = aten::copy_(%17, %4)
  //    %20 : Tensor = aten::slice(%14, ...)    // slice for %5
  //    %21 : Tensor = aten::copy_(%20, %5)
  //    return (%14)
  testing::FileCheck()
      .check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= aten::clamp_max(", 2, /*exactly*/ true)
      ->check_count("= aten::empty(", 1, /*exactly*/ true)
      ->check_count("= aten::slice(", 1, /*exactly*/ true)
      ->check_count("= aten::copy_(", 1, /*exactly*/ true)
      ->check_count("= aten::slice(", 1, /*exactly*/ true)
      ->check_count("= aten::copy_(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, ConcatWithoutResultShape) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Tensor = aten::cat(%6, %2)
          return (%7)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the output
  // shape of `aten::cat` is not known.
  testing::FileCheck().run(input, *graph);
}

TEST(ConcatOptTest, ConcatWithoutInputShape) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Tensor = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%6, %2)
          return (%7)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the shape of %5,
  // which is an input to `aten::cat`, is not known.
  testing::FileCheck().run(input, *graph);
}

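// The tests below exercise UseVariadicCat, which rewrites an aten::cat whose
// input list comes from a prim::ListConstruct into a single prim::VarConcat
// that takes the tensors (plus the dim) directly. As the following cases show,
// the rewrite still fires when the list has other uses or is mutated after the
// cat, but not when the list is mutated before the cat; that case needs
// RemoveListMutationAndUseVariadicCat.
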
TEST(ConcatOptTest, UseVariadicCat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %5: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %concat : Float(224, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::cat` with `prim::VarConcat` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varcat : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varcat)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %concat1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %concat2 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input2, %10)
          return (%concat1, %concat2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full concat optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varcat1 : Tensor = prim::VarConcat(%0, %1, %zero)
  //    %varcat2 : Tensor = prim::VarConcat(%2, %3, %zero)
  //    return (%varcat1, %varcat2)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 2, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat, %input)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::cat` with `prim::VarConcat` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varcat : Tensor = prim::VarConcat(%0, %1, %zero)
  //    return (%varcat, %input)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%concat, %input)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list to `aten::cat` is mutated only after the `aten::cat` op,
  // so the `aten::cat` should have been replaced with `prim::VarConcat`. The
  // transformed graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarConcat(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %concat : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  {
    ASSERT_FALSE(UseVariadicCat(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // No transformation should have happened since the `prim::ListConstruct`
    // is mutated before `aten::cat`.
    testing::FileCheck()
        .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
        ->check_count("= aten::cat(", 1, /*exactly*/ true)
        ->check_count("= prim::VarConcat(", 0, /*exactly*/ true)
        ->run(*graph);
  }

  {
    ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
    graph->lint();
    auto opt_outputs = runGraph(graph, inputs);
    ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

    // The mutation of the list must be removed and the `aten::cat` op must
    // be replaced with the `prim::VarConcat` op in the graph. The transformed
    // graph should look like the following:
    //
    //  graph(%0 : ...,
    //        %1 : ...,
    //        %2 : ...):
    //    %3 : int = prim::Constant[value=0]()
    //    %7 : Tensor = prim::VarConcat(%0, %1, %2, %3)
    //    return (%7)
    testing::FileCheck()
        .check_count("= prim::VarConcat(", 1, /*exactly*/ true)
        ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
        ->check_count("= aten::cat(", 0, /*exactly*/ true)
        ->run(*graph);
  }
}

TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat.1, %concat.2, %concat.3, %concat.4)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::cat` ops must
  // be replaced with `prim::VarConcat` ops in the graph. The transformed graph
  // should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarConcat(%0, %1, %10)
  //    %6 : Tensor = prim::VarConcat(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarConcat(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->run(*graph);
}

TEST(
    ConcatOptTest,
    RemoveListMutationUseVariadicCatAndCommonInputsElimination) {
  auto graph = std::make_shared<Graph>();

  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()

          %features.2 : Tensor[] = prim::ListConstruct(%0, %1)
          %6 : Tensor [] = aten::append(%features.2, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)

          %7 : Tensor [] = aten::append(%features.2, %0)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)

          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing:
  //     * Remove list mutation
  //     * Use variadic cat
  //     * Eliminate common inputs
  // we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %10 : Tensor = prim::VarConcat(%0, %1, %2, %3)
  //    %12 : Tensor = prim::VarConcat(%10, %0, %3) // UPDATED
  //    %8 : Tensor[] = prim::ListConstruct(%10, %12)
  //    return (%8)
  testing::FileCheck()
      .check_count("= prim::VarConcat(%0, %1, %2, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::VarConcat(%10, %0, %3)", 1, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(%10, %12)", 1, /*exactly*/ true)
      ->check_count("= aten::cat(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}

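// The CombineConcats tests below exercise a pass that collapses a chain of
// aten::cat ops, where one cat's result feeds directly into the input list of
// the next cat along the same dimension, into a single prim::ListConstruct
// plus aten::cat, as long as the intermediate lists are not mutated.
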
TEST(ConcatOpt, CombineConcatsSimpleCase) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  //  graph(%0 : Tensor):
  //    %dim : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %0, %0)
  //    %concat : Tensor = aten::cat(%input, %dim)
  //    return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsLongChain) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0)
          %concat.3 : Tensor = aten::cat(%input.3, %dim)
          return (%concat.3)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After performing CombineConcats:
  //  graph(%0 : Tensor, %1 : Tensor):
  //    %dim : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0)
  //    %concat : Tensor = aten::cat(%input, %dim)
  //    return (%concat)
  testing::FileCheck()
      .check_count("prim::ListConstruct", 1, /*exactly*/ true)
      ->check_count("aten::cat", 1, /*exactly*/ true)
      ->run(*graph);
}

TEST(ConcatOpt, CombineConcatsMutation) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %input.3 : Tensor[] = aten::append(%input.2, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {at::rand({1}), at::randn({1})};
  // No modifications due to aten::append
  ASSERT_FALSE(CombineConcats(graph));
}

} // namespace jit
} // namespace torch