xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_stack_opt.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Functions.h>
2 #include <gtest/gtest.h>
3 
4 #include <test/cpp/jit/test_utils.h>
5 #include <torch/csrc/jit/ir/irparser.h>
6 #include <torch/csrc/jit/passes/variadic_ops.h>
7 #include <torch/csrc/jit/runtime/interpreter.h>
8 #include <torch/csrc/jit/testing/file_check.h>
9 
10 namespace torch {
11 namespace jit {
12 
// Checks the basic rewrite: a single `aten::stack` fed by a
// `prim::ListConstruct` is expected to be replaced by one `prim::VarStack`,
// with the list construction removed and the computed result unchanged.
TEST(StackOptTest, UseVariadicStack) {
  auto graph = std::make_shared<Graph>();

  // Six (56, 56, 56) inputs stacked along dim 0, so the annotated result
  // shape is (6, 56, 56, 56).
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56),
              %5: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5)
          %stack : Float(6, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference result from the untransformed graph.
  auto orig_outputs = runGraph(graph, inputs);

  // The pass must report that it changed the graph, and the graph must still
  // pass lint afterwards.
  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  // Same inputs must give the same outputs before and after the rewrite.
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After replacing `aten::stack` with `prim::VarStack` we should have the
  // following graph:
  //
  //  graph(%0 : ...,
  //        ...,
  //        %5 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero)
  //    return (%varstack)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}
59 
// Checks that every eligible `aten::stack` in a graph is rewritten: two
// independent list-construct + stack pairs are expected to become two
// `prim::VarStack` nodes with no list constructions left.
TEST(StackOptTest, UseVariadicStackReplaceMultiple) {
  auto graph = std::make_shared<Graph>();

  // Each stack combines two (56, 56, 56) tensors along dim 0, so each
  // annotated result shape is (2, 56, 56, 56).
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input1 : Tensor[] = prim::ListConstruct(%0, %1)
          %stack1 : Float(2, 56, 56, 56) = aten::stack(%input1, %10)
          %input2 : Tensor[] = prim::ListConstruct(%2, %3)
          %stack2 : Float(2, 56, 56, 56) = aten::stack(%input2, %10)
          return (%stack1, %stack2)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference result from the untransformed graph.
  auto orig_outputs = runGraph(graph, inputs);

  // The pass must report a change and leave a lint-clean graph.
  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);

  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // After full stack optimization we should have the following graph:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %varstack1 : Tensor = prim::VarStack(%0, %1, %zero)
  //    %varstack2 : Tensor = prim::VarStack(%2, %3, %zero)
  //    return (%varstack1, %varstack2)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 2, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->run(*graph);
}
106 
// The list passed to `aten::stack` is also returned from the graph, i.e. it
// has a second use. The stack op is still expected to be rewritten, while the
// list construction has to remain for its other use.
TEST(StackOptTest, UseVariadicStackWithMultipleListUses) {
  const std::string ir_text = R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56)):
          %2 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2)
          return (%stack, %input)
      )IR";

  auto g = std::make_shared<Graph>();
  parseIR(ir_text, g.get());

  // One random CPU tensor per graph input.
  std::vector<at::Tensor> args;
  args.reserve(2);
  for (int i = 0; i < 2; ++i) {
    args.push_back(at::rand({56, 56, 56}, at::kCPU));
  }

  // Run once before the pass to capture the reference outputs.
  auto expected = runGraph(g, args);

  ASSERT_TRUE(UseVariadicStack(g));
  g->lint();

  // Run again on the rewritten graph; results must match exactly.
  auto actual = runGraph(g, args);
  ASSERT_TRUE(exactlyEqual(expected, actual));

  // Expected graph once `aten::stack` has become `prim::VarStack`:
  //
  //  graph(%0 : ...,
  //        %1 : ...):
  //    %zero : int = prim::Constant[value=0]()
  //    %input : Tensor[] = prim::ListConstruct(%0, %1)
  //    %varstack : Tensor = prim::VarStack(%0, %1, %zero)
  //    return (%varstack, %input)
  testing::FileCheck checker;
  checker.check_count("= prim::ListConstruct(", 1, /*exactly*/ true);
  checker.check_count("= prim::VarStack(", 1, /*exactly*/ true);
  checker.check_count("= aten::stack(", 0, /*exactly*/ true);
  checker.run(*g);
}
145 
// The list is mutated (`aten::append`) only after it has been consumed by
// `aten::stack`. The stack op is therefore still expected to be rewritten to
// `prim::VarStack`; the list construction and the append stay in the graph.
TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) {
  auto graph = std::make_shared<Graph>();

  // At the point of the stack the list holds %0 and %1 only, so the annotated
  // result shape is (2, 56, 56, 56).
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          return (%stack, %input)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference result from the untransformed graph.
  auto orig_outputs = runGraph(graph, inputs);

  // The pass must report a change and leave a lint-clean graph.
  ASSERT_TRUE(UseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // The input list to `aten::stack` is mutated only after `aten::stack` op. So,
  // it should have been replaced with `prim::VarStack`. The transformed graph
  // should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...):
  //    %3 : int = prim::Constant[value=0]()
  //    %4 : Tensor[] = prim::ListConstruct(%0, %1)
  //    %7 : Tensor = prim::VarStack(%0, %1, %3)
  //    %6 : Tensor = aten::append(%4, %2)
  //    return (%7, %4)
  testing::FileCheck()
      .check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
      ->check_count("= prim::VarStack(", 1, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}
190 
// The list is appended to before it reaches `aten::stack`. Two passes are
// exercised on the same graph, in order:
//   1. `UseVariadicStack` is expected to report no change and leave the graph
//      as it was.
//   2. `RemoveListMutationAndUseVariadicStack` is expected to report a change,
//      dropping the list and producing a single `prim::VarStack`.
TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) {
  const std::string ir_text = R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %11 : Tensor = aten::append(%input, %2)
          %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack)
      )IR";

  auto g = std::make_shared<Graph>();
  parseIR(ir_text, g.get());

  // One random CPU tensor per graph input.
  std::vector<at::Tensor> args;
  args.reserve(3);
  for (int i = 0; i < 3; ++i) {
    args.push_back(at::rand({56, 56, 56}, at::kCPU));
  }

  // Reference outputs from the untouched graph; both phases compare to these.
  auto expected = runGraph(g, args);

  // Phase 1: the plain pass must decline to transform.
  {
    ASSERT_FALSE(UseVariadicStack(g));
    g->lint();
    auto actual = runGraph(g, args);
    ASSERT_TRUE(exactlyEqual(expected, actual));

    // Since the `prim::ListConstruct` result is mutated ahead of
    // `aten::stack`, the graph should be unchanged.
    testing::FileCheck unchanged;
    unchanged.check_count("= prim::ListConstruct(", 1, /*exactly*/ true);
    unchanged.check_count("= aten::stack(", 1, /*exactly*/ true);
    unchanged.check_count("= prim::VarStack(", 0, /*exactly*/ true);
    unchanged.run(*g);
  }

  // Phase 2: the mutation-removing pass must transform.
  {
    ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(g));
    g->lint();
    auto actual = runGraph(g, args);
    ASSERT_TRUE(exactlyEqual(expected, actual));

    // The list mutation should be gone and `aten::stack` should have been
    // swapped for `prim::VarStack`, giving a graph like:
    //
    //  graph(%0 : ...,
    //        %1 : ...,
    //        %2 : ...):
    //    %3 : int = prim::Constant[value=0]()
    //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3)
    //    return (%7)
    testing::FileCheck rewritten;
    rewritten.check_count("= prim::VarStack(", 1, /*exactly*/ true);
    rewritten.check_count("= prim::ListConstruct(", 0, /*exactly*/ true);
    rewritten.check_count("= aten::stack(", 0, /*exactly*/ true);
    rewritten.run(*g);
  }
}
250 
// One list is stacked four times with an `aten::append` between consecutive
// stacks. `RemoveListMutationAndUseVariadicStack` is expected to eliminate the
// list and all appends, leaving four `prim::VarStack` nodes.
TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) {
  auto graph = std::make_shared<Graph>();

  // The list grows by one tensor before each subsequent stack, so the
  // annotated leading dimensions are 2, 3, 4 and 5 respectively.
  const std::string input =
      R"IR(
        graph(%0: Float(56, 56, 56),
              %1: Float(56, 56, 56),
              %2: Float(56, 56, 56),
              %3: Float(56, 56, 56),
              %4: Float(56, 56, 56)):
          %10 : int = prim::Constant[value=0]()
          %input : Tensor[] = prim::ListConstruct(%0, %1)
          %stack.1 : Float(2, 56, 56, 56) = aten::stack(%input, %10)
          %11 : Tensor = aten::append(%input, %2)
          %stack.2 : Float(3, 56, 56, 56) = aten::stack(%input, %10)
          %12 : Tensor = aten::append(%input, %3)
          %stack.3 : Float(4, 56, 56, 56) = aten::stack(%input, %10)
          %13 : Tensor = aten::append(%input, %4)
          %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10)
          return (%stack.1, %stack.2, %stack.3, %stack.4)
      )IR";
  parseIR(input, graph.get());
  std::vector<at::Tensor> inputs = {
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU),
      at::rand({56, 56, 56}, at::kCPU)};
  // Reference result from the untransformed graph.
  auto orig_outputs = runGraph(graph, inputs);

  // The pass must report a change and leave a lint-clean graph.
  ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph));
  graph->lint();
  auto opt_outputs = runGraph(graph, inputs);
  ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs));

  // All the mutations of the list must be removed and the `aten::stack` ops
  // must be replaced with `prim::VarStack` ops in the graph. The transformed
  // graph should look like the following:
  //
  //  graph(%0 : ...,
  //        %1 : ...,
  //        %2 : ...,
  //        %3 : ...,
  //        %4 : ...):
  //    %10 : int = prim::Constant[value=0]()
  //    %5 : Tensor = prim::VarStack(%0, %1, %10)
  //    %6 : Tensor = prim::VarStack(%0, %1, %2, %10)
  //    %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10)
  //    %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10)
  //    return (%5, %6, %7, %8)
  testing::FileCheck()
      .check_count("= prim::VarStack(", 4, /*exactly*/ true)
      ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
      ->check_count("= aten::stack(", 0, /*exactly*/ true)
      ->run(*graph);
}
307 
308 } // namespace jit
309 } // namespace torch
310