// xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_graph_opt.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/tensorexpr/test_base.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/ir/irparser.h>
6 #include <torch/csrc/jit/passes/lower_tuples.h>
7 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
8 #include <torch/csrc/jit/tensorexpr/kernel.h>
9 #include <torch/csrc/jit/testing/file_check.h>
10 #include <torch/torch.h>
11 
12 #include <limits>
13 
14 namespace torch {
15 namespace jit {
16 
17 using namespace torch::jit::tensorexpr;
18 
19 class GraphOpt : public ::testing::Test {
20  public:
SetUp()21   void SetUp() override {
22     old_cat_wo_conditionals_ = getCatWoConditionals();
23     getCatWoConditionals() = true;
24   }
25 
TearDown()26   void TearDown() override {
27     getCatWoConditionals() = old_cat_wo_conditionals_;
28   }
29 
30  private:
31   bool old_cat_wo_conditionals_;
32 };
33 
TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  // Single element-wise consumer (`aten::log`) of an `aten::cat` result.
  const auto ir = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // The optimization must push `aten::log` onto each of the three inputs of
  // `aten::cat`, leaving no `log` after the `cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::log(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs{x, y, z};
  auto stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}
74 
TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  // A chain of two element-wise consumers (`log` then `tanh`) of `aten::cat`.
  const auto ir = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // Both `aten::log` and `aten::tanh` must be replicated onto each input of
  // `aten::cat`; neither op may remain after the `cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<at::Tensor> inputs{x, y, z};
  auto stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}
121 
TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  // `tanh(cat)` is followed by a two-tensor `mul`; only the single-tensor op
  // is eligible to move past the `cat`.
  const auto ir = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // `aten::tanh` must be replicated onto each input of `aten::cat`, but
  // `aten::mul` must stay after the `cat` because it consumes two tensors.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<at::Tensor> inputs{a, x, y, z};
  auto stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}
168 
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  // The consumer (`tanh`) promotes Int inputs to Float; the `cat` itself does
  // no promotion, so the transformation should still apply.
  const auto ir = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // `aten::tanh` must be replicated onto each input of `aten::cat`; the
  // `cat` then concatenates the Float results of `tanh`.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  // Numerical check against the eager-mode reference.
  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto ref = at::tanh(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs{x, y, z};
  auto stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}
211 
TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  // Here the `cat` itself performs type promotion (Float + Double inputs).
  const auto ir = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // No transformation should have happened because the `aten::cat` op performs
  // type promotion. This case is currently not handled: the graph must still
  // contain exactly one `cat` followed by one `log`.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}
239 
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  // The only consumer of `cat` is a two-tensor `mul`, so nothing can move.
  const auto ir = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // No transformation is expected since the consumer of `cat` is not a
  // single-tensor element-wise op: exactly one `cat` and one `mul` remain.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}
268 
TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  // The consumers of `cat` (`mul` and `add`) each take multiple tensor
  // inputs, so nothing can be moved past the `cat`.
  const auto ir = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir, graph.get());
  graph->lint();

  TensorExprKernel kernel(graph);

  // No transformation is expected since the consumers of `cat` are not
  // single-tensor element-wise ops: one `cat`, one `mul`, one `add` remain.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}
302 
TEST_F(GraphOpt, AOTGraphPrepPasses) {
  // Checks the AOT graph-preparation passes: dropping a graph output,
  // converting a list output to a tuple, and lowering all tuples should leave
  // the graph returning the original tensors directly.
  const auto graph_string = R"IR(
    graph(%x, %y, %z, %i : int):
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      return (%xyz_list, %i))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  // Validate the parsed graph before running passes, consistent with the
  // other tests in this file.
  g->lint();

  removeGraphOutput(g, 1); // drop the `%i` output
  replaceListOutputWithTuple(g);
  LowerAllTuples(g);

  testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
}
317 
318 } // namespace jit
319 } // namespace torch
320