1 #include <ATen/Functions.h>
2 #include <gtest/gtest.h>
3
4 #include <test/cpp/jit/test_utils.h>
5 #include <torch/csrc/jit/ir/irparser.h>
6 #include <torch/csrc/jit/passes/variadic_ops.h>
7 #include <torch/csrc/jit/runtime/interpreter.h>
8 #include <torch/csrc/jit/testing/file_check.h>
9
10 namespace torch {
11 namespace jit {
12
// Verifies that `UseVariadicStack` rewrites a single `aten::stack` fed by an
// unmutated `prim::ListConstruct` into one `prim::VarStack`, removes the list
// construction, and that `runGraph` returns exactly equal results before and
// after the pass.
TEST(StackOptTest, UseVariadicStack) {
  auto graph = std::make_shared<Graph>();

  // Six (56, 56, 56) tensors stacked along dim 0 give a (6, 56, 56, 56)
  // result; the output annotation below reflects that.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56),
              %5: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %stack : Float(6, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference outputs are captured before the graph is transformed.
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::stack` with `prim::VarStack` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varstack)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}
59
// Verifies that `UseVariadicStack` rewrites every eligible `aten::stack` in a
// graph: two independent list/stack pairs must become two `prim::VarStack`
// nodes with no `prim::ListConstruct` left, and `runGraph` must return exactly
// equal results before and after the pass.
TEST(StackOptTest, UseVariadicStackReplaceMultiple) {
  auto graph = std::make_shared<Graph>();

  // Each stack combines two (56, 56, 56) tensors along dim 0, so each result
  // is annotated as (2, 56, 56, 56).
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %stack1 : Float(2, 56, 56, 56) = aten::stack(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %stack2 : Float(2, 56, 56, 56) = aten::stack(%input2, %10)
          return (%stack1, %stack2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference outputs are captured before the graph is transformed.
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full stack optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack1 : Tensor = prim::VarStack(%0, %1, %zero)
  //    %varstack2 : Tensor = prim::VarStack(%2, %3, %zero)
  //    return (%varstack1, %varstack2)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 2, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}
106
// The list handed to `aten::stack` is also a graph output here. The pass is
// still expected to swap in `prim::VarStack`, but the `prim::ListConstruct`
// has to survive because the return statement uses it.
TEST(StackOptTest, UseVariadicStackWithMultipleListUses) {
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2)
          return (%stack, %input)
      )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  // One random CPU tensor per graph input.
  std::vector<at::Tensor> args;
  for (int i = 0; i < 2; ++i) {
    args.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  auto expected = runGraph(graph, args);

  // The pass must report a change and leave a well-formed graph behind.
  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();

  auto actual = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected shape of the transformed graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varstack : Tensor = prim::VarStack(%0, %1, %zero)
  //    return (%varstack, %input)
  testing::FileCheck checker;
  checker.check_count("= prim::ListConstruct(", 1, /*exactly*/ true);
  checker.check_count("= prim::VarStack(", 1, /*exactly*/ true);
  checker.check_count("= aten::stack(", 0, /*exactly*/ true);
  checker.run(*graph);
}
145
// Verifies that a list mutation occurring only *after* `aten::stack` does not
// block the rewrite: `UseVariadicStack` must replace the stack with
// `prim::VarStack` while keeping the `prim::ListConstruct` (still used by the
// later `aten::append` and by the graph output).
TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) {
  auto graph = std::make_shared<Graph>();

  // At the point of the stack the list holds only %0 and %1, so the stacked
  // result is annotated as (2, 56, 56, 56); %2 is appended afterwards.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%stack, %input)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference outputs are captured before the graph is transformed.
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list to `aten::stack` is mutated only after `aten::stack` op. So,
  // it should have been replaced with `prim::VarStack`. The transformed graph
  // should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarStack(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}
190
// The list is appended to *before* it reaches `aten::stack`. Two passes are
// exercised in sequence on the same graph:
//   1. `UseVariadicStack` alone must decline to transform it.
//   2. `RemoveListMutationAndUseVariadicStack` must eliminate the mutation
//      and then perform the rewrite.
// After each step the graph is linted and its results are compared against
// the results recorded before any pass ran.
TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) {
  const std::string graph_ir =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_ir, graph.get());

  // One random CPU tensor per graph input.
  std::vector<at::Tensor> args;
  for (int i = 0; i < 3; ++i) {
    args.push_back(at::rand({56, 56, 56}, at::kCPU));
  }
  auto expected = runGraph(graph, args);

  // Step 1: the plain pass reports "no change" and the graph keeps its
  // `prim::ListConstruct` and `aten::stack`, because the list is mutated
  // ahead of the stack.
  ASSERT_FALSE(UseVariadicStack(graph));
  graph->lint();
  auto after_plain_pass = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, after_plain_pass));

  testing::FileCheck untouched;
  untouched.check_count("= prim::ListConstruct(", 1, /*exactly*/ true);
  untouched.check_count("= aten::stack(", 1, /*exactly*/ true);
  untouched.check_count("= prim::VarStack(", 0, /*exactly*/ true);
  untouched.run(*graph);

  // Step 2: the mutation-removing variant reports a change. The list and its
  // mutation disappear and the stack becomes `prim::VarStack`, giving:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3)
  //    return (%7)
  ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
  graph->lint();
  auto after_mutation_removal = runGraph(graph, args);
  ASSERT_TRUE(exactlyEqual(expected, after_mutation_removal));

  testing::FileCheck rewritten;
  rewritten.check_count("= prim::VarStack(", 1, /*exactly*/ true);
  rewritten.check_count("= prim::ListConstruct(", 0, /*exactly*/ true);
  rewritten.check_count("= aten::stack(", 0, /*exactly*/ true);
  rewritten.run(*graph);
}
250
// Verifies that `RemoveListMutationAndUseVariadicStack` handles a list that is
// repeatedly appended to between stacks: all four `aten::stack` ops must
// become `prim::VarStack`, no `prim::ListConstruct` may remain, and `runGraph`
// must return exactly equal results before and after the pass.
TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();

  // The list grows by one tensor before each subsequent stack, so the stacked
  // results are annotated with leading dimensions 2, 3, 4 and 5 respectively.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack.1 : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %stack.2 : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %stack.3 : Float(4, 56, 56, 56) = aten::stack(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack.1, %stack.2, %stack.3, %stack.4)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference outputs are captured before the graph is transformed.
  auto orig_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::stack` ops
  // must be replaced with `prim::VarStack` ops in the graph. The transformed
  // graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarStack(%0, %1, %10)
  //    %6 : Tensor = prim::VarStack(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}
307
308 } // namespace jit
309 } // namespace torch
310