#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <sstream>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

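// RAII helper: enables (or disables) the CPU fuser on construction and
// restores the previous setting on destruction.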
struct WithCPUFuser {
  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
    overrideCanFuseOnCPU(val);
  }

  ~WithCPUFuser() {
    overrideCanFuseOnCPU(cpuFuserEnabled);
  }

  bool cpuFuserEnabled;
};

TEST(TEFuserPass, FuserPass_1) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
      %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
      %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
      %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("aten::add_")
      ->check("prim::TensorExprGroup_")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_2) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
      %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
      %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
      return (%d))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("aten::add_")
      ->check("prim::TensorExprGroup_0")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_3) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(128, strides=[1], device=cpu),
          %y : Float(128, strides=[1], device=cpu)):
      %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR";
  {
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // We should not create a fusion group since its size would be too small
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should create a fusion group since its size is above the threshold
    testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
  }
}

TEST(TEFuserPass, FuserPass_0DimInput) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(device=cpu),
          %y : Float(device=cpu)):
      %one : int = prim::Constant[value=1]()
      %a : Float(device=cpu) = aten::mul(%x, %y)
      %b : Float(device=cpu) = aten::add(%x, %a, %one)
      return (%b))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // We should fuse 0-dim tensors too
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_UnfusibleDevice) {
  WithCPUFuser cf(false);
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(10, strides=[1], device=cpu)):
      %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%a))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 1);

  // Test that we're not starting fusion groups from nodes with an unfusible
  // device
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_UnknownShapes) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %a : Tensor = aten::mul(%x, %y)
      %b : Tensor = aten::mul(%x, %a)
      return (%b))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // Test that we're not generating fusion groups when shapes are not known
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Multidevice) {
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should be able to fuse this
    testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cuda:0),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should not fuse this aten::cat since its inputs are from different
    // devices
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(10, strides=[1], device=cuda:0)):
      %dim : int = prim::Constant[value=0]()
      %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
      %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
      %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
      return (%r))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // Test that we check device before merging one node (cat) into another
    // (mul)
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(10, strides=[1], device=cuda:0)):
      %z2 : Tensor = aten::mul(%z, %z)
      %dim : int = prim::Constant[value=0]()
      %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
      return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // Test that we check device before merging one node (mul) into another
    // (cat)
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cuda:0)):
      %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should not fuse this graph since its inputs are from different devices
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cuda:0),
          %y : Float(20, strides=[1], device=cuda:1),
          %z : Float(20, strides=[1], device=cpu)):
      %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
      %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
      %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
      return (%x2, %y2, %z2))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // We should not fuse these computations since they use different
    // devices
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
}

TEST(TEFuserPass, FuserPass_MergeGroups) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%a : Float(128, strides=[1], device=cpu),
          %b : Float(128, strides=[1], device=cpu)):
      %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
      %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
      return (%x, %y))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 1);

  // The %x and %y computations are completely independent and yet we should
  // put them into a single fusion group rather than having two separate ones.
  testing::FileCheck()
      .check("= prim::TensorExprGroup_")
      ->check_not("= prim::TensorExprGroup_")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Bool(8, strides=[1], device=cpu),
          %y : Bool(8, strides=[1], device=cpu)):
      %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
      %b : Tensor = aten::__or__(%a, %y)
      return (%b)
    )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
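  // %b has an unknown shape, so at most the single aten::__and__ node could
  // be fused, which is below min_group_size; no group should be created.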
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Where) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
      return (%b)
    )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
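  // The ternary aten::where is fusible, so aten::eq and aten::where should
  // form a single fusion group.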
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_WhereList) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Tensor[] = aten::where(%cond)
      return (%b)
    )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
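  // The single-argument, list-returning overload of aten::where is not
  // fusible, so no group should be created here.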
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, DynamicShapeFusion) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(10, 5, strides=[5, 1], device=cpu),
          %1 : Float(10, 5, strides=[5, 1], device=cpu)):
      %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1)
      %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1)
      return (%3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(
      g,
      /* min_group_size = */ 2,
      /* add_composed_op = */ true,
      /* fuse_to_dynamic_shapes = */ true);
  Code code(g, "");

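  // With dynamic shapes enabled, the fusion group should be wrapped in a
  // guarded dynamic group.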
  testing::FileCheck()
      .check("prim::TensorExprDynamicGroup_")
      ->check("prim::TensorExprDynamicGuard")
      ->check("prim::TensorExprGroup_")
      ->run(*g);

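  // Run the compiled code on the given inputs and compare the result against
  // an eager-mode reference.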
  auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) {
    TORCH_INTERNAL_ASSERT(inputs.size() == 2);

    auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]);

    InterpreterState interp(code);
    Stack stack(inputs.begin(), inputs.end());
    interp.run(stack);
    at::Tensor out = pop(stack).toTensor();
    ASSERT_TRUE(at::allclose(out, ref));
  };

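  // The same compiled code should handle inputs with different shapes.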
  std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})};
  run_and_compare(inputs);

  std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})};
  run_and_compare(inputs2);

  std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})};
  run_and_compare(inputs3);
}

} // namespace jit
} // namespace torch