1 #include <gtest/gtest.h>
2
3 #include <ATen/Functions.h>
4 #include <test/cpp/jit/test_utils.h>
5 #include <torch/csrc/jit/ir/irparser.h>
6 #include <torch/csrc/jit/passes/concat_opt.h>
7 #include <torch/csrc/jit/passes/variadic_ops.h>
8 #include <torch/csrc/jit/runtime/interpreter.h>
9 #include <torch/csrc/jit/testing/file_check.h>
10
11 namespace torch {
12 namespace jit {
13
TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) {
  // The inputs of %concat.2 (%0, %1) form a prefix of the inputs of
  // %concat.3, so the pass is expected to reuse %concat.2 inside %concat.3.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();

  // The rewritten graph must compute bit-identical results.
  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once the pass has run:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...):
  //     %3 : int = prim::Constant[value=0]()
  //     %4 : Tensor = prim::VarConcat(%0, %1, %3)
  //     %7 : Tensor = prim::VarConcat(%4, %2, %3)   // rewritten
  //     %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //     return (%8)
  testing::FileCheck checker;
  checker.check_count("= prim::VarConcat(%0, %1, %3)", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(%4, %2, %3)", 1, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(", 0, /*exactly=*/true);
  checker.run(*graph);
}
58
TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) {
  // The inputs of %concat.2 (%1, %2) form a suffix of the inputs of
  // %concat.3, so the pass is expected to reuse %concat.2 inside %concat.3.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.2 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %2, %5)
          %concat.3 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();

  // The rewritten graph must compute bit-identical results.
  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once the pass has run:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...):
  //     %3 : int = prim::Constant[value=0]()
  //     %4 : Tensor = prim::VarConcat(%1, %2, %3)
  //     %7 : Tensor = prim::VarConcat(%0, %4, %3)   // rewritten
  //     %8 : Tensor[] = prim::ListConstruct(%4, %7)
  //     return (%8)
  testing::FileCheck checker;
  checker.check_count("= prim::VarConcat(%1, %2, %3)", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(%0, %4, %3)", 1, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(%4, %7)", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(", 0, /*exactly=*/true);
  checker.run(*graph);
}
103
TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) {
  auto graph = std::make_shared<Graph>();

  // Both concats use %0 and %1, but in a different order, so neither input
  // list is a prefix/suffix of the other. The `#CHECK` lines embedded in the
  // IR double as the FileCheck expectations for the graph after the pass.
  //
  // %concat.2 joins 64 + 32 + 32 rows along dim 0, hence its annotated shape
  // is (128, 56, 56).
  const std::string input =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()

          #CHECK: prim::VarConcat
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)

          #CHECK: prim::VarConcat
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%1, %0, %2, %5)

          #CHECK: prim::ListConstruct
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2)
          return (%res)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};
  auto orig_outputs = runGraph(graph, inputs);

  // The pass must report that it changed nothing.
  ASSERT_FALSE(EliminateConcatCommonInputs(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // No optimizations should have happened in this case since the inputs
  // to the `cat` are in different order.
  testing::FileCheck().run(input, *graph);
}
141
TEST(ConcatOptTest, MoreCommonInputsElimination) {
  // Four concats where each one's inputs extend the previous one's by a
  // single tensor; every concat after the first should end up consuming the
  // result of its predecessor.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %5)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %5)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %5)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = prim::VarConcat(%0, %1, %2, %3, %4, %5)
          %res : Tensor[] = prim::ListConstruct(%concat.1, %concat.2, %concat.3, %concat.4)
          return (%res)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Each concat after the first takes the previous concat's value as its
  // leading input; no `aten::cat` may be present.
  testing::FileCheck checker;
  checker.check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(%6, %2, %5)", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(%11, %3, %5)", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(%12, %4, %5)", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.run(*graph);
}
182
TEST(ConcatOptTest, ExpandConcat) {
  // An `aten::cat` whose inputs and output all carry full shape information.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          %input : Tensor[] = prim::ListConstruct(%4, %5)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph after the pass: the concat is replaced by an output
  // buffer plus one slice/copy_ pair per concat input.
  //
  //   graph(%0 : ...,
  //         %1 : ...):
  //     ...
  //     %4 : Tensor = aten::clamp_max(...)
  //     %5 : Tensor = aten::clamp_max(...)
  //     %13 : int[] = prim::ListConstruct(...)
  //     %14 : Tensor = aten::empty(%13, ...)    // concat buffer
  //     %17 : Tensor = aten::slice(%14, ...)    // slice for %4
  //     %18 : Tensor = aten::copy_(%17, %4)
  //     %20 : Tensor = aten::slice(%14, ...)    // slice for %5
  //     %21 : Tensor = aten::copy_(%20, %5)
  //     return (%14)
  testing::FileCheck checker;
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.check_count("= aten::clamp_max(", 2, /*exactly=*/true);
  checker.check_count("= aten::empty(", 1, /*exactly=*/true);
  checker.check_count("= aten::slice(", 1, /*exactly=*/true);
  checker.check_count("= aten::copy_(", 1, /*exactly=*/true);
  checker.check_count("= aten::slice(", 1, /*exactly=*/true);
  checker.check_count("= aten::copy_(", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.run(*graph);
}
234
TEST(ConcatOptTest, ConcatWithoutResultShape) {
  // The result of `aten::cat` (%7) is a plain `Tensor` with no shape info.
  // The `# CHECK` lines embedded in the IR are reused below as the FileCheck
  // expectations for the graph after the pass.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Tensor = aten::cat(%6, %2)
          return (%7)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // The output shape of `aten::cat` is unknown, so the pass is expected to
  // leave the graph as it was.
  testing::FileCheck checker;
  checker.run(graph_ir, *graph);
}
269
TEST(ConcatOptTest, ConcatWithoutInputShape) {
  // One of the concat inputs (%5) is a plain `Tensor` with no shape info.
  // The `# CHECK` lines embedded in the IR are reused below as the FileCheck
  // expectations for the graph after the pass.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %3 : float = prim::Constant[value=0.5]()
          # CHECK: clamp_max
          %4 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::clamp_max(%0, %3)
          # CHECK: clamp_max
          %5 : Tensor = aten::clamp_max(%1, %3)
          # CHECK: prim::ListConstruct
          %6 : Tensor[] = prim::ListConstruct(%4, %5)
          # CHECK: aten::cat
          %7 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%6, %2)
          return (%7)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ExpandConcatAndEliminateRedundancy(graph);
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // The shape of %5, an input to `aten::cat`, is unknown, so the pass is
  // expected to leave the graph as it was.
  testing::FileCheck checker;
  checker.run(graph_ir, *graph);
}
304
TEST(ConcatOptTest, UseVariadicCat) {
  // A single `aten::cat` over a six-element list that has no other use.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %5: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %concat : Float(224, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once `aten::cat` has been replaced by `prim::VarConcat`:
  //
  //   graph(%0 : ...,
  //         ...
  //         %5 : ...):
  //     %zero : int = prim::Constant[value=0]()
  //     %varcat : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %5, %zero)
  //     return (%varcat)
  testing::FileCheck checker;
  checker.check_count("= prim::VarConcat(", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(", 0, /*exactly=*/true);
  checker.run(*graph);
}
351
TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) {
  // Two independent `aten::cat` nodes, each fed by its own list.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %concat1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %concat2 : Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input2, %10)
          return (%concat1, %concat2)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once both `aten::cat` nodes have been replaced by
  // `prim::VarConcat`:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...,
  //         %3 : ...):
  //     %zero : int = prim::Constant[value=0]()
  //     %varcat1 : Tensor = prim::VarConcat(%0, %1, %zero)
  //     %varcat2 : Tensor = prim::VarConcat(%2, %3, %zero)
  //     return (%varcat1, %varcat2)
  testing::FileCheck checker;
  checker.check_count("= prim::VarConcat(", 2, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(", 0, /*exactly=*/true);
  checker.run(*graph);
}
398
TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) {
  // The list feeding `aten::cat` is also returned from the graph, so it has
  // a second use besides the concat.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %2)
          return (%concat, %input)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU), at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once `aten::cat` has been replaced by `prim::VarConcat`;
  // the list construction stays because the list is still returned:
  //
  //   graph(%0 : ...,
  //         %1 : ...):
  //     %zero : int = prim::Constant[value=0]()
  //     %input : Tensor[] = prim::ListConstruct(%0, %1)
  //     %varcat : Tensor = prim::VarConcat(%0, %1, %zero)
  //     return (%varcat, %input)
  testing::FileCheck checker;
  checker.check_count("= prim::ListConstruct(", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.run(*graph);
}
437
TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) {
  // The list is appended to only after it has been consumed by `aten::cat`.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%concat, %input)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(UseVariadicCat(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Because the mutation happens after the concat, `aten::cat` is expected
  // to be replaced by `prim::VarConcat`. Expected graph:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...):
  //     %3 : int = prim::Constant[value=0]()
  //     %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //     %7 : Tensor = prim::VarConcat(%0, %1, %3)
  //     %6 : Tensor = aten::append(%4, %2)
  //     return (%7, %4)
  testing::FileCheck checker;
  checker.check_count("= prim::ListConstruct(", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.run(*graph);
}
482
TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) {
  // The list is appended to before it reaches `aten::cat`.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %concat : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  // Step 1: the plain pass must report no change.
  ASSERT_FALSE(UseVariadicCat(graph));
  graph->lint();
  auto after_plain_pass = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, after_plain_pass));

  // Nothing should have been transformed, since the `prim::ListConstruct`
  // result is mutated before `aten::cat` consumes it.
  testing::FileCheck unchanged_checker;
  unchanged_checker.check_count(
      "= prim::ListConstruct(", 1, /*exactly=*/true);
  unchanged_checker.check_count("= aten::cat(", 1, /*exactly=*/true);
  unchanged_checker.check_count("= prim::VarConcat(", 0, /*exactly=*/true);
  unchanged_checker.run(*graph);

  // Step 2: the pass that first removes list mutations must succeed.
  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  graph->lint();
  auto after_mutation_removal = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, after_mutation_removal));

  // The list mutation must be gone and `aten::cat` must have been replaced
  // by `prim::VarConcat`. Expected graph:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...):
  //     %3 : int = prim::Constant[value=0]()
  //     %7 : Tensor = prim::VarConcat(%0, %1, %2, %3)
  //     return (%7)
  testing::FileCheck rewritten_checker;
  rewritten_checker.check_count("= prim::VarConcat(", 1, /*exactly=*/true);
  rewritten_checker.check_count(
      "= prim::ListConstruct(", 0, /*exactly=*/true);
  rewritten_checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  rewritten_checker.run(*graph);
}
542
TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) {
  // One list that is alternately concatenated and appended to, producing
  // four concats of growing size.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
          return (%concat.1, %concat.2, %concat.3, %concat.4)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Every list mutation must be gone and every `aten::cat` must have been
  // replaced by `prim::VarConcat`. Expected graph:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...,
  //         %3 : ...,
  //         %4 : ...):
  //     %10 : int = prim::Constant[value=0]()
  //     %5 : Tensor = prim::VarConcat(%0, %1, %10)
  //     %6 : Tensor = prim::VarConcat(%0, %1, %2, %10)
  //     %7 : Tensor = prim::VarConcat(%0, %1, %2, %3, %10)
  //     %8 : Tensor = prim::VarConcat(%0, %1, %2, %3, %4, %10)
  //     return (%5, %6, %7, %8)
  testing::FileCheck checker;
  checker.check_count("= prim::VarConcat(", 4, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(", 0, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.run(*graph);
}
599
TEST(
    ConcatOptTest,
    RemoveListMutationUseVariadicCatAndCommonInputsElimination) {
  // A mutated list that is concatenated twice; exercises the mutation-removal
  // / variadic-cat pass followed by common-input elimination.
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
              %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
          %5 : int = prim::Constant[value=0]()

          %features.2 : Tensor[] = prim::ListConstruct(%0, %1)
          %6 : Tensor [] = aten::append(%features.2, %2)
          %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)

          %7 : Tensor [] = aten::append(%features.2, %0)
          %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%features.2, %5)

          %res : Tensor[] = prim::ListConstruct(%concat.2, %concat.3)
          return (%res)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{
      at::rand({64, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU),
      at::rand({32, 56, 56}, at::kCPU)};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  // Both passes must report that they changed the graph.
  ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
  ASSERT_TRUE(EliminateConcatCommonInputs(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph after
  //   1. removing the list mutations,
  //   2. switching to variadic cat, and
  //   3. eliminating common concat inputs:
  //
  //   graph(%0 : ...,
  //         %1 : ...,
  //         %2 : ...):
  //     %3 : int = prim::Constant[value=0]()
  //     %10 : Tensor = prim::VarConcat(%0, %1, %2, %3)
  //     %12 : Tensor = prim::VarConcat(%10, %0, %3)   // rewritten
  //     %8 : Tensor[] = prim::ListConstruct(%10, %12)
  //     return (%8)
  testing::FileCheck checker;
  checker.check_count(
      "= prim::VarConcat(%0, %1, %2, %3)", 1, /*exactly=*/true);
  checker.check_count("= prim::VarConcat(%10, %0, %3)", 1, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(%10, %12)", 1, /*exactly=*/true);
  checker.check_count("= aten::cat(", 0, /*exactly=*/true);
  checker.check_count("= prim::ListConstruct(", 0, /*exactly=*/true);
  checker.run(*graph);
}
657
TEST(ConcatOpt, CombineConcatsSimpleCase) {
  // The result of one `aten::cat` is directly an input of another along the
  // same dim constant.
  const std::string graph_ir =
      R"IR(
        graph(%0: Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%concat.1, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{at::rand({1})};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph after CombineConcats:
  //
  //   graph(%0 : Tensor):
  //     %dim : int = prim::Constant[value=0]()
  //     %input : Tensor[] = prim::ListConstruct(%0, %0, %0)
  //     %concat : Tensor = aten::cat(%input, %dim)
  //     return (%concat)
  testing::FileCheck checker;
  checker.check_count("prim::ListConstruct", 1, /*exactly=*/true);
  checker.check_count("aten::cat", 1, /*exactly=*/true);
  checker.run(*graph);
}
690
TEST(ConcatOpt, CombineConcatsLongChain) {
  // A chain of three `aten::cat` nodes, each nested in the middle of the
  // next one's input list.
  const std::string graph_ir =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          %input.3 : Tensor[] = prim::ListConstruct(%0, %concat.2, %0)
          %concat.3 : Tensor = aten::cat(%input.3, %dim)
          return (%concat.3)
      )IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  std::vector<at::Tensor> args{at::rand({1}), at::randn({1})};

  // Reference results from the untouched graph.
  auto expected = runGraph(graph, args);

  ASSERT_TRUE(CombineConcats(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph after CombineConcats:
  //
  //   graph(%0 : Tensor, %1 : Tensor):
  //     %dim : int = prim::Constant[value=0]()
  //     %input : Tensor[] = prim::ListConstruct(%0, %1, %0, %0, %1, %0)
  //     %concat : Tensor = aten::cat(%input, %dim)
  //     return (%concat)
  testing::FileCheck checker;
  checker.check_count("prim::ListConstruct", 1, /*exactly=*/true);
  checker.check_count("aten::cat", 1, /*exactly=*/true);
  checker.run(*graph);
}
725
TEST(ConcatOpt, CombineConcatsMutation) {
  auto graph = std::make_shared<Graph>();
  // %input.2 contains the result of another concat, but it is mutated via
  // `aten::append` before being concatenated itself.
  const std::string input =
      R"IR(
        graph(%0: Tensor, %1 : Tensor):
          %dim : int = prim::Constant[value=0]()
          %input.1 : Tensor[] = prim::ListConstruct(%0, %0)
          %concat.1 : Tensor = aten::cat(%input.1, %dim)
          %input.2 : Tensor[] = prim::ListConstruct(%1, %concat.1, %1)
          %input.3 : Tensor[] = aten::append(%input.2, %0)
          %concat.2 : Tensor = aten::cat(%input.2, %dim)
          return (%concat.2)
      )IR";
  parseIR(input, graph.get());
  // The graph is never executed here, so no input tensors are needed; only
  // the pass's "did anything change" result is checked.
  // No modifications due to aten::append
  ASSERT_FALSE(CombineConcats(graph));
}
744
745 } // namespace jit
746 } // namespace torch
747