#include <gtest/gtest.h>

#include <ATen/code_template.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include <stdexcept>

namespace torch {
namespace jit {

using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

class Kernel : public ::testing::Test {
 public:
  void SetUp() override {
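    // Allow kernels in these tests to fall back to non-LLVM codegen on CPU.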
    getTEMustUseLLVMOnCPU() = false;
  }
};

TEST_F(Kernel, ParallelExternalCallBuf) {
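  // aten::matmul is lowered to an external call; the pattern below checks
  // that a parallelized loop is still emitted alongside it.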
  const auto graph_string = R"IR(
    graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu),
          %1 : Float(1000, 5000, strides=[5000, 1], device=cpu),
          %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)):
      %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1)
      %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2)
      return (%4))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, &*graph);
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i = 0ll; i < 5000ll; i++)  /* parallel */{)IR";

#ifdef TORCH_ENABLE_LLVM
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
#endif
}

TEST_F(Kernel, InliningIntermediates) {
  // here, each mul has only one use, so it should be completely inlined
  {
    const auto graph_string = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
              %1 : Float(5, 3, strides=[3, 1], device=cpu)):
          %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
          %one : int = prim::Constant[value=1]()
          %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
          %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one)
          return (%5))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);
    auto stmt = k.getCodeGenStmt();
    std::ostringstream oss;
    oss << *stmt;
    torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());
  }
  {
    const auto graph_template = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=${device}),
              %1 : Float(5, 3, strides=[3, 1], device=${device})):
          %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
          %one : int = prim::Constant[value=1]()
          %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one)
          %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one)
          %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0)
          return (%4, %5))IR";
    for (bool use_cuda : {false, true}) {
      if (!torch::cuda::is_available() && use_cuda) {
        continue;
      }

      at::jit::TemplateEnv env;
      env.s("device", use_cuda ? "cuda:0" : "cpu");
      const auto graph_string = format(graph_template, env);
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      auto stmt = k.getCodeGenStmt();
      std::ostringstream oss;
      oss << *stmt;
      // aten_mul only has one use, so it is inlined completely.
      torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());

      // aten_sub should be removed by the CUDA backend via metavar rewriting
      // and by the CPU backend via horizontal fusion.
      torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str());
    }
  }
}

TEST_F(Kernel, PreAllocIntermediateBufs) {
  const auto graph_string = R"IR(
graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu),
      %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=1]()
  %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12
  %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15
  return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::matmul(a, b) + a;
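  // The trailing 'true' below (pre-allocation, as the test name suggests)
  // asks the kernel to allocate intermediate buffers ahead of time.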
  TensorExprKernel k(graph, {}, {}, true);

  std::vector<at::Tensor> inputs = {a, b};
  auto stmt = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *stmt;

  // Check whether the intermediate buffer has been added to constants
  auto constants = k.getConstantDescriptors();
  ASSERT_EQ(constants.size(), 1);

  // Check the IR we produced
  torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str());
  torch::jit::testing::FileCheck().check_not("Free")->run(oss.str());

  // Check correctness
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, _1) {
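  // _1, _2, and _3 run the same mul(mul) graph with contiguous, transposed,
  // and sliced (strided) second inputs, respectively.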
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, _2) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, _3) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2), Slice(None, None, 2)});
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, Huge) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)):
        %1 : int = prim::Constant[value=0]()
        %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1)
        %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  std::ostringstream oss;
  oss << *k.getCodeGenStmt();
  // The 4000000000-iteration loop will be split into 500000000 x 8, and the
  // outer loop will be parallelized. If LLVM is not present the loop will not
  // be split, so to cover both cases we look for "00000000ll;" in the output.
  const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

TEST_F(Kernel, ParallelStrided) {
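  // As the test name suggests, this checks parallelized codegen over inputs
  // with non-trivial strides; same mul(mul) graph as above, but larger.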
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
            %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
        %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
               .index(
                   {Slice(None, None, 2),
                    Slice(None, None, 2),
                    Slice(None, None, 2)});
  auto ref = a * (a * b);
  auto o = at::zeros_like(ref);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, DISABLED_Shape_Inference) {
  // disabled: doesn't do stride propagation, and isn't being used currently

  // Test TensorExpr shape inference capabilities: it should only require shapes
  // for the inputs
  {
    const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor = aten::mul(%0, %2)
        return (%3))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
                 .index({Slice(None, None, 2), Slice(None, None, 2)});
    auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = a * (a * b);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    for (size_t i = 0; i < 5 * 3; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    const auto graph_string = R"IR(
      graph(%0 : Float(8, 8, strides=[8, 1], device=cpu),
            %1 : Float(8, 8, strides=[8, 1], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
        %r : Tensor = aten::mul(%3, %4)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat));
    auto t = torch::chunk(a * b, 2, 1);
    auto ref = t[0] * t[1];
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    TORCH_CHECK_EQ(o.sizes()[0], 8);
    TORCH_CHECK_EQ(o.sizes()[1], 4);
    for (size_t i = 0; i < 8 * 4; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::unsqueeze

    const auto graph_string = R"IR(
      graph(%a : Float(4, 2, strides=[2, 1], device=cpu),
            %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu),
            %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)):
        %one : int = prim::Constant[value=1]()
        %minus_one : int = prim::Constant[value=-1]()
        %three : int = prim::Constant[value=3]()
        %minus_four : int = prim::Constant[value=-4]()
        %a1 : Tensor = aten::unsqueeze(%a, %one)        # new size: [4,1,2]
        %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
        %b1 : Tensor = aten::unsqueeze(%b, %three)      # new size: [4,3,2,1]
        %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
        %ab : Tensor = aten::mul(%a2, %b1)         # expected size: [4,3,2,1]
        %abc : Tensor = aten::mul(%ab, %c1)        # expected size: [4,3,2,2]
        return (%abc))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
        at::unsqueeze(c, -4);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_mul)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::cat

    const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Tensor = aten::cat(%inputs, %dim)               # new size: [5,19,2]
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that we throw an error when input list for aten::cat is empty

    const auto graph_string = R"IR(
      graph():
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct()
        %r : Tensor = aten::cat(%inputs, %dim)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    auto compile = [&]() {
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat");
  }
  {
    // Test that we throw an error when 'dim' passed to aten::cat is invalid

    const auto ir_dim_99 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=99]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";
    const auto ir_dim_minus_6 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=-6]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";

    auto compile = [](const std::string& graph_string) {
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index");
    ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index");
  }
}

TEST_F(Kernel, CatInputTypesPromotion) {
  {
    // Test that we properly promote input types for aten::cat

    const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    TORCH_CHECK_EQ(o.dtype(), ref.dtype());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]);
    }
  }
}

TEST_F(Kernel, ToDType) {
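  // The autocast/aten::to chain below should collapse into a single cast;
  // the FileCheck pattern expects one aten_to inside the fused loop nest.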
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
        %1 : NoneType = prim::Constant()
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=6]()
        %4 : int = prim::Constant[value=15]()
        %5 : int = prim::Constant[value=5]()
        %6 : bool = prim::Constant[value=1]()
        %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1)
        %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4)
        %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6)
        %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1)
        %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1)
        %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1)
        return (%k.3))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_to
# CHECK-NEXT: }
# CHECK-NEXT: })IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16));
  auto ref =
      at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat));

  std::vector<at::Tensor> inputs = {a};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
#endif
}

TEST_F(Kernel, CatAndInlineWithAConstantDim) {
  const auto graph_string = R"IR(
      graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
            %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor[] = prim::ListConstruct(%0, %1)
        %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
        %6 : Tensor[] = prim::ListConstruct(%5)
        %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
        %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
        return (%8, %7))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);

  auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::_cast_Float(at::cat({a, b}, 1), 0);

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, CatWithEmptyInputs) {
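  // aten::cat where one input has zero rows should work both with and
  // without the cat-without-conditionals lowering.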
  bool curr_cat_wo_conditionals = getCatWoConditionals();
  for (auto cat_wo_conditionals : {true, false}) {
    getCatWoConditionals() = cat_wo_conditionals;
    const auto graph_string = R"IR(
        graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu),
              %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)):
          %3 : int = prim::Constant[value=0]()
          %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0)
          %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1)
          %10 : Tensor[] = prim::ListConstruct(%6, %7)
          %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3)
          return (%11))IR";

    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);

    auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0);

    std::vector<at::Tensor> inputs = {a, b};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
  getCatWoConditionals() = curr_cat_wo_conditionals;
}

TEST_F(Kernel, CatWoConditionals) {
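  // With getCatWoConditionals() set, aten::cat is lowered as one copy loop
  // nest per input instead of a single loop with conditionals; the pattern
  // below expects three aten_cat loop nests.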
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
        return (%r))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::cat({a, b, c}, 1);

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  getCatWoConditionals() = old_cat_wo_conditionals;
}

TEST_F(Kernel, OptimizeConditionals) {
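  // With getOptConditionals() set, the relu over the cat result should be
  // distributed into one conditional-free loop per cat input, without an
  // intermediate buffer (no Allocate/Free in the pattern below).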
  bool old_cat_wo_conditionals = getCatWoConditionals();
  bool old_opt_conditionals = getOptConditionals();
  getCatWoConditionals() = false;
  getOptConditionals() = true;
  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, strides=[3, 1], device=cpu),
            %b : Float(5, 7, strides=[7, 1], device=cpu),
            %c : Float(5, 9, strides=[9, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim)
        %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r)
        return (%t))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK-NOT: Allocate
# CHECK-NOT: Free)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::relu(at::cat({a, b, c}, 1));

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  getOptConditionals() = old_opt_conditionals;
  getCatWoConditionals() = old_cat_wo_conditionals;
}

namespace {

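// Returns the IR text for a dtype constant, e.g. ScalarType::Double (enum
// value 7) yields "int = prim::Constant[value=7]()"; Undefined yields a
// NoneType constant, i.e. "no dtype specified".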
std::string dtypeConstant(ScalarType scalar_type) {
  if (scalar_type == ScalarType::Undefined) {
    return "None = prim::Constant()";
  } else {
    at::jit::TemplateEnv env_dtype;
    env_dtype.d("scalar_type", static_cast<int>(scalar_type));
    return format("int = prim::Constant[value=${scalar_type}]()", env_dtype);
  }
}

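// Builds a tensor of the given sizes filled with 0, 1, 2, ... in row-major
// order; e.g. iotaTensor({2, 2}, options) yields [[0, 1], [2, 3]].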
at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) {
  int64_t numel = std::accumulate(
      sizes.begin(),
      sizes.end(),
      1,
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  std::vector<float> values(numel);
  std::iota(values.begin(), values.end(), 0);
  auto a = at::tensor(values, options);
  return a.reshape(sizes);
}

} // namespace

TEST_F(Kernel, SumAllAxes) {
  // Test lowering of sum on all axes.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : ${dtype}
        %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1)
        return (%2))IR";
  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
    at::jit::TemplateEnv env;
    env.s("dtype", dtypeConstant(scalar_type));
    if (scalar_type == ScalarType::Undefined) {
      env.s("out_dtype", "Float");
    } else {
      env.s("out_dtype", "Double");
    }
    const auto graph_string = format(graph_template, env);

    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto o = at::empty({}, TensorOptions(kCPU));
    std::optional<c10::ScalarType> dtype;
    if (scalar_type != ScalarType::Undefined) {
      dtype = static_cast<c10::ScalarType>(scalar_type);
    }
    auto ref = a.sum(/*dtype=*/dtype);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
}

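// Renders an integer list as comma-separated text (e.g. {5, 3} -> "5, 3")
// for splicing sizes and strides into the IR templates below.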
std::string li_to_str(at::ArrayRef<int64_t> li) {
  std::stringstream out;
  bool first = true;
  for (auto elem : li) {
    if (!first) {
      out << ", ";
    }
    out << elem;
    first = false;
  }
  return out.str();
}

TEST_F(Kernel, SumOneAxis) {
  // Test lowering of sum on one axis.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : int[] = prim::Constant[value=[${dim}]]()
        %2 : bool = prim::Constant[value=${keepdim}]()
        %3 : ${dtype}
        %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3)
        return (%4))IR";
  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (int dim = -a.dim(); dim < a.dim(); ++dim) {
    for (bool keepdim : {false, true}) {
      for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
        at::jit::TemplateEnv env;
        env.d("dim", dim);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(scalar_type));
        std::optional<c10::ScalarType> dtype;
        if (scalar_type != ScalarType::Undefined) {
          dtype = static_cast<c10::ScalarType>(scalar_type);
        }
        auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype);
        if (scalar_type == ScalarType::Undefined) {
          env.s("out_dtype", "Float");
        } else {
          env.s("out_dtype", "Double");
        }
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        const auto graph_string = format(graph_template, env);
        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        auto o = at::empty({}, TensorOptions(kCPU));
        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        // Check the IR we produced
        const std::string& verification_pattern =
            R"IR(
# CHECK: for (int64_t
# CHECK-NEXT: sum
# CHECK-NEXT: for (int64_t
# CHECK-NEXT:   sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
      }
    }
  }
}

TEST_F(Kernel, SumMultipleAxes) {
  // Test lowering of sum on multiple axes.
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)):
        %1 : int = prim::Constant[value=${dim1}]()
        %2 : int = prim::Constant[value=${dim2}]()
        %3 : int[] = prim::ListConstruct(%1, %2)
        %4 : bool = prim::Constant[value=${keepdim}]()
        %5 : ${dtype}
        %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5)
        return (%6))IR";
  auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  // Only iterate over positive values of axes to keep the running time
  // reasonable, since the number of pairs is quadratic.
  for (const auto dim1 : c10::irange(a.dim())) {
    for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) {
      for (bool keepdim : {false, true}) {
        at::jit::TemplateEnv env;
        env.d("dim1", dim1);
        env.d("dim2", dim2);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(ScalarType::Undefined));
        auto o = at::empty({}, TensorOptions(kCPU));
        auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim);

        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        // Check the IR we produced
        const std::string& verification_pattern =
            R"IR(
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref));
      }
    }
  }
}

// This test and the following Softmax tests only use dim values that are
// valid input dimensions. They do not test with dim=None because that is
// supposed to be deprecated.
TEST_F(Kernel, Softmax2D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %dt_float : int = prim::Constant[value=7]()
        %dt_none : NoneType = prim::Constant()
        %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt})
        return (%4))IR";

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
        # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 5
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
        # CHECK-NEXT: aten_softmax)IR";

  for (bool empty_dtype : {false, true}) {
    for (auto log_softmax : {false, true}) {
      for (const auto softmax_dim : c10::irange(a.dim())) {
        auto softmax_dim_size = a.sizes()[softmax_dim];
        auto other_dim = (softmax_dim + 1) % a.dim();
        auto ref =
            log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
        at::jit::TemplateEnv env;
        env.d("dim", softmax_dim);
        env.s("op", log_softmax ? "log_softmax" : "softmax");
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        env.s("dt", empty_dtype ? "dt_none" : "dt_float");

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        at::jit::TemplateEnv ver_env;
        ver_env.d("other_dim", other_dim);
        ver_env.d("other_dim_size", a.sizes()[other_dim]);
        ver_env.d("softmax_dim", softmax_dim);
        ver_env.d("softmax_dim_size", softmax_dim_size);
        const auto verification_pattern =
            format(verification_template, ver_env);

        // Verification string temporarily disabled until inlining of exp()
        // is benchmarked and a decision is made.
        // torch::jit::testing::FileCheck().run(verification_pattern,
        // oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        auto output = stack[0].toTensor();
        ASSERT_EQ(output.sizes(), ref.sizes());
        ASSERT_TRUE(at::allclose(output, ref));
      }
    }
  }
}

TEST_F(Kernel, Softmax3D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
        # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 3
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4
        # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
        # CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // Verification string temporarily disabled until inlining of exp()
      // is benchmarked and a decision is made.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();

      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}

TEST_F(Kernel, Softmax4D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
        # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
        # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
        # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 2
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
        # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2
        # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
        # CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("dim3", other_dims[2]);
      ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // Verification string temporarily disabled until inlining of exp()
      // is benchmarked and a decision is made.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1281 
1282       std::vector<IValue> stack = fmap<IValue>(inputs);
1283       k.run(stack);
1284       auto output = stack[0].toTensor();
1285       ASSERT_EQ(output.sizes(), ref.sizes());
1286       ASSERT_TRUE(at::allclose(output, ref));
1287     }
1288   }
1289 }
1290 
TEST_F(Kernel,SignTest)1291 TEST_F(Kernel, SignTest) {
1292   const auto graph_template = R"IR(
1293       graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)):
1294         %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0)
1295         return (%2))IR";
1296 
1297   auto run_test = [](const std::string& graph_string, const at::Tensor& input) {
1298     auto graph = std::make_shared<Graph>();
1299     parseIR(graph_string, &*graph);
1300 
1301     TensorExprKernel k(graph);
1302     StmtPtr s = k.getCodeGenStmt();
1303 
1304     std::vector<at::Tensor> inputs = {input};
1305     std::vector<IValue> stack = fmap<IValue>(inputs);
1306     k.run(stack);
1307     auto o = stack[0].toTensor();
1308     auto ref = at::sign(input);
1309     ASSERT_TRUE(at::allclose(o, ref));
1310   };
1311   auto common_options = at::TensorOptions()
1312                             .layout(at::kStrided)
1313                             .device(at::kCPU)
1314                             .requires_grad(false);
1315   int default_input_size = 100;
1316   for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) {
1317     at::Tensor corner_case_inputs;
1318     at::jit::TemplateEnv env;
1319     auto options = common_options;
1320     switch (scalar_type) {
1321       case ScalarType::Float: {
1322         env.s("dtype", "Float");
1323         options = options.dtype(at::kFloat);
1324         std::vector<float> input_float = {
1325             0.0f,
1326             -0.0f,
1327             std::numeric_limits<float>::infinity(),
1328             -std::numeric_limits<float>::infinity(),
1329             std::nanf("1"),
1330             -std::nanf("1")};
1331         corner_case_inputs = at::from_blob(
1332             input_float.data(),
1333             {static_cast<long>(input_float.size())},
1334             options);
1335         auto rand_input = at::rand({default_input_size}, options);
1336         auto input = at::cat({rand_input, corner_case_inputs});
1337         env.d("size", at::numel(input));
1338         const auto graph_string = format(graph_template, env);
1339         run_test(graph_string, input);
1340         break;
1341       }
1342       case ScalarType::Double: {
1343         env.s("dtype", "Double");
1344         options = options.dtype(at::kDouble);
1345         std::vector<double> input_double = {
1346             0.0,
1347             -0.0,
1348             std::numeric_limits<double>::infinity(),
1349             -std::numeric_limits<double>::infinity(),
1350             std::nan("1"),
1351             -std::nan("1")};
1352         corner_case_inputs = at::from_blob(
1353             input_double.data(),
1354             {static_cast<long>(input_double.size())},
1355             options);
1356         auto rand_input = at::rand({default_input_size}, options);
1357         auto input = at::cat({rand_input, corner_case_inputs});
1358         env.d("size", at::numel(input));
1359         const auto graph_string = format(graph_template, env);
1360         run_test(graph_string, input);
1361         break;
1362       }
1363       default:
1364         throw unsupported_dtype();
1365     }
1366   }
1367 }
1368 
TEST_F(Kernel,InlineProducerIntoReduction)1369 TEST_F(Kernel, InlineProducerIntoReduction) {
1370   // Inline producer (mul) into reduction (sum).
1371   const auto graph_string = R"IR(
1372       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
1373             %1 : Float(5, 3, strides=[3, 1], device=cpu)):
1374         %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
1375         %3 : int = prim::Constant[value=7]()
1376         %4 : Double(device=cpu) = aten::sum(%2, %3)
1377         return (%4))IR";
1378   auto graph = std::make_shared<Graph>();
1379   parseIR(graph_string, &*graph);
1380 
1381   TensorExprKernel k(graph);
1382   StmtPtr s = k.getCodeGenStmt();
1383   std::ostringstream oss;
1384   oss << *s;
1385 
1386   // Check the IR we produced.
1387   // We should end up with a single fused loop nest.
1388   const std::string& verification_pattern =
1389       R"IR(
1390         # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
1391         # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
1392         # CHECK-NEXT:   sum
1393         # CHECK-NOT: for)IR";
1394   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1395 
1396   auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1397   auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1398   std::vector<at::Tensor> inputs = {a, b};
1399   std::vector<IValue> stack = fmap<IValue>(inputs);
1400   k.run(stack);
1401   auto o = stack[0].toTensor();
1402   auto ref = (a * b).sum(at::kDouble);
1403   ASSERT_TRUE(at::allclose(o, ref));
1404 }
1405 
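// In the test above, inlining lets the mul's expression feed the reduction
// directly, so conceptually the kernel is a single accumulation loop nest,
// roughly (pseudo-code, not the exact NNC output):
//
//   sum = 0.0;
//   for (i = 0; i < 5; i++)
//     for (j = 0; j < 3; j++)
//       sum += double(a[i, j] * b[i, j]);
//
// rather than a separate loop that first materializes aten_mul into a
// temporary buffer.
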
1406 TEST_F(Kernel, InlineReductionIntoConsumer) {
1407   // Inline producer (mul %2) into reduction (sum %4) but DO NOT
1408   // inline the reduction into its consumer (mul %5).
1409   const auto graph_string = R"IR(
1410       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
1411             %1 : Float(5, 3, strides=[3, 1], device=cpu)):
1412         %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
1413         %3 : int = prim::Constant[value=6]()
1414         %4 : Float(device=cpu) = aten::sum(%2, %3)
1415         %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4)
1416         return (%5))IR";
1417   auto graph = std::make_shared<Graph>();
1418   parseIR(graph_string, &*graph);
1419 
1420   TensorExprKernel k(graph);
1421   StmtPtr s = k.getCodeGenStmt();
1422   std::ostringstream oss;
1423   oss << *s;
1424 
1425   // Check the IR we produced.
1426   // We should end up with two loop nests.
1427   const std::string& verification_pattern =
1428       R"IR(
1429         # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
1430         # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
1431         # CHECK-NEXT:   sum
1432         # CHECK: for (int64_t i_2 = 0ll; i_2 < 5
1433         # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3
1434         # CHECK-NEXT:   aten_mul
1435         # CHECK-NOT: for)IR";
1436   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1437 
1438   auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1439   auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1440   std::vector<at::Tensor> inputs = {a, b};
1441   std::vector<IValue> stack = fmap<IValue>(inputs);
1442   k.run(stack);
1443   auto o = stack[0].toTensor();
1444   auto ref = (a * b).sum(at::kFloat) * (a * b);
1445   ASSERT_TRUE(at::allclose(o, ref));
1446 }
1447 
1448 TEST_F(Kernel, SanitizeNames_CUDA) {
1449   const auto graph_string = R"IR(
1450       graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0),
1451             %1 : Float(5, 3, strides=[3, 1], device=cuda:0)):
1452         %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
1453         %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
1454         return (%4))IR";
1455   auto graph = std::make_shared<Graph>();
1456   parseIR(graph_string, &*graph);
1457   graph->inputs().at(0)->setDebugName("aten::add:");
1458   graph->inputs().at(1)->setDebugName("aten::add_");
1459   TensorExprKernel k(graph);
1460   auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
1461   auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
1462   auto ref = a * (a * b);
1463   std::vector<at::Tensor> inputs = {a, b};
1464   std::vector<IValue> stack = fmap<IValue>(inputs);
1465   k.run(stack);
1466   auto o = stack[0].toTensor();
1467   ASSERT_TRUE(at::allclose(o, ref));
1468 }
1469 
1470 TEST_F(Kernel, SanitizeConstants_CUDA) {
1471   const auto graph_string = R"IR(
1472         graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)):
1473           %none : NoneType = prim::Constant()
1474           %size : int = prim::Constant[value=16]()
1475           %sizes : int[] = prim::ListConstruct(%size, %size)
1476           %30 : Device = prim::Constant[value="cuda"]()
1477           %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none)
1478           %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y)
1479           return (%z))IR";
1480   auto graph = std::make_shared<Graph>();
1481   parseIR(graph_string, &*graph);
1482   // IRParser doesn't support tensor constants, so we insert a call to
1483   // aten::ones and then const-prop it
1484   ConstantPropagation(graph);
1485 
1486   // We set the name of the constant to include special characters that are
1487   // not allowed. This should be fixed by the sanitizer in TensorExprKernel.
1488   graph->nodes().front()->output()->setDebugName("illegal.name");
1489 
1490   // Check that we have a constant node with an illegal name in the graph.
1491   auto const_node = graph->nodes().front();
1492   ASSERT_EQ(const_node->kind(), prim::Constant);
1493   ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos);
1494 
1495   TensorExprKernel k(graph);
1496 
1497   auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
1498   std::vector<at::Tensor> inputs = {x};
1499   std::vector<IValue> stack = fmap<IValue>(inputs);
1500   k.run(stack);
1501   auto o = stack[0].toTensor();
1502   auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
1503   auto ref = x * y;
1504   ASSERT_TRUE(at::allclose(o, ref));
1505 }
1506 
1507 TEST_F(Kernel, ConstantTensors) {
1508   const auto graph_string = R"IR(
1509         graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
1510           %none : NoneType = prim::Constant()
1511           %size : int = prim::Constant[value=16]()
1512           %sizes : int[] = prim::ListConstruct(%size, %size)
1513           %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none)
1514           %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
1515           return (%z))IR";
1516   auto graph = std::make_shared<Graph>();
1517   parseIR(graph_string, &*graph);
1518   // IRParser doesn't support tensor constants, so we insert a call to
1519   // aten::ones and then const-prop it
1520   ConstantPropagation(graph);
1521 
1522   TensorExprKernel k(graph);
1523 
1524   auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
1525   std::vector<at::Tensor> inputs = {x};
1526   std::vector<IValue> stack = fmap<IValue>(inputs);
1527   k.run(stack);
1528   auto o = stack[0].toTensor();
1529   auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
1530   auto ref = x * y;
1531   ASSERT_TRUE(at::allclose(o, ref));
1532 }
1533 
1534 TEST_F(Kernel, ConstantTensorsNonContiguous) {
1535   const auto graph_string = R"IR(
1536         graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
1537           %none : NoneType = prim::Constant()
1538           %dtype : int = prim::Constant[value=6]()
1539           %c0 : int = prim::Constant[value=0]()
1540           %c256 : int = prim::Constant[value=256]()
1541           %c16 : int = prim::Constant[value=16]()
1542           %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
1543           %sizes : int[] = prim::ListConstruct(%c16, %c16)
1544           %y_t : Tensor = aten::view(%y_flat, %sizes)
1545           %y : Tensor = aten::t(%y_t)
1546           %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
1547           return (%z))IR";
1548   auto graph = std::make_shared<Graph>();
1549   parseIR(graph_string, &*graph);
1550   // IRParser doesn't support tensor constants, so we generate several aten
1551   // calls to produce a non-contiguous constant tensor and then const-prop it
1552   ConstantPropagation(graph);
1553 
1554   TensorExprKernel k(graph);
1555 
1556   auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
1557   std::vector<at::Tensor> inputs = {x};
1558   std::vector<IValue> stack = fmap<IValue>(inputs);
1559   k.run(stack);
1560   auto o = stack[0].toTensor();
1561   auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat))
1562                .view({16, 16})
1563                .t();
1564   auto ref = x * y;
1565   ASSERT_TRUE(at::allclose(o, ref));
1566 }
1567 
1568 TEST_F(Kernel, RunFast) {
1569 #ifdef TORCH_ENABLE_LLVM
1570   // TODO: Implement call_raw in IREval and remove the ifdef
1571 
1572   const auto graph_string = R"IR(
1573       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
1574             %1 : Float(5, 3, strides=[1, 5], device=cpu)):
1575         %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
1576         %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
1577         return (%3))IR";
1578   auto graph = std::make_shared<Graph>();
1579   parseIR(graph_string, &*graph);
1580 
1581   auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1582   auto b =
1583       at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
1584   auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1585   auto ref = a * (a * b);
1586   TensorExprKernel k(graph);
1587 
1588   k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()});
1589   for (size_t i = 0; i < 5 * 3; i++) {
1590     TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
1591   }
1592 #endif
1593 }
1594 
1595 TEST_F(Kernel, RunWithAllocatedOutputs) {
1596 #ifdef TORCH_ENABLE_LLVM
1597   const auto graph_string = R"IR(
1598       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
1599             %1 : Float(5, 3, strides=[1, 5], device=cpu)):
1600         %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
1601         %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
1602         return (%3))IR";
1603   auto graph = std::make_shared<Graph>();
1604   parseIR(graph_string, &*graph);
1605 
1606   auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1607   auto b =
1608       at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
1609   auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1610   auto ref = a * (a * b);
1611   TensorExprKernel k(graph);
1612 
1613   std::vector<at::Tensor> args = {o, a, b};
1614   std::vector<IValue> stack = fmap<IValue>(args);
1615   k.runWithAllocatedOutputs(stack);
1616   for (size_t i = 0; i < 5 * 3; i++) {
1617     TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
1618   }
1619 #endif
1620 }
1621 
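// The entry points exercised above differ only in calling convention:
// TEK::run takes the inputs as IValues and pushes freshly allocated outputs
// onto the stack; TEK::runWithAllocatedOutputs expects the preallocated
// outputs to precede the inputs on the stack (hence {o, a, b} above); and
// TEK::runFast bypasses IValues entirely, taking raw input and output
// pointers. The latter two rely on codegen paths currently implemented only
// for LLVM, hence the TORCH_ENABLE_LLVM guards.
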
1622 TEST_F(Kernel, CodegenInspection) {
1623 #ifdef TORCH_ENABLE_LLVM
1624   const auto graph_string = R"IR(
1625         graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
1626           %none : NoneType = prim::Constant()
1627           %dtype : int = prim::Constant[value=6]()
1628           %c0 : int = prim::Constant[value=0]()
1629           %c256 : int = prim::Constant[value=256]()
1630           %c16 : int = prim::Constant[value=16]()
1631           %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
1632           %sizes : int[] = prim::ListConstruct(%c16, %c16)
1633           %y_t : Tensor = aten::view(%y_flat, %sizes)
1634           %y : Tensor = aten::t(%y_t)
1635           %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
1636           return (%z))IR";
1637   auto graph = std::make_shared<Graph>();
1638   parseIR(graph_string, &*graph);
1639   // IRParser doesn't support tensor constants, so we generate several aten
1640   // calls to produce a non-contiguous constant tensor and then const-prop it
1641   ConstantPropagation(graph);
1642 
1643   TensorExprKernel k(graph);
1644 
1645   // Check that we can retrieve the generated assembly
1646   auto asm_str = k.getCodeText("asm");
1647   const std::string& asm_verification_pattern =
1648       R"ASM(
1649         # CHECK: .text
1650         # CHECK: retq)ASM";
1651   torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str);
1652 
1653   // Check that we can retrieve info about codegen parameters
1654   auto constants = k.getConstantDescriptors();
1655   auto buf_args = k.getBufferArgs();
1656   // Expected buf args: [input0, output0, constant0]
1657   ASSERT_EQ(buf_args.size(), 3);
1658   ASSERT_EQ(constants.size(), 1);
1659   ASSERT_TRUE(
1660       !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar());
1661 #endif
1662 }
1663 
1664 Tensor lowerNanToNum(
1665     const std::vector<ArgValue>& inputs,
1666     const std::vector<ExprHandle>& outputShape,
1667     const std::vector<ExprHandle>& outputStrides,
1668     const std::optional<ScalarType>& outputType,
1669     at::Device device) {
1670   auto input_buf = std::get<BufHandle>(inputs[0]);
1671   auto e = Compute(
1672       "custom_nan_to_num",
1673       outputShape,
1674       outputStrides,
1675       [&](const std::vector<VarHandle>& axes) {
1676         std::vector<ExprHandle> indices(axes.begin(), axes.end());
1677         auto load = input_buf.load(indices);
1678         return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load);
1679       });
1680   return e;
1681 }
1682 
1683 TEST_F(Kernel, CustomLowering) {
1684   const auto graph_string = R"IR(
1685       graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
1686           %none : NoneType = prim::Constant()
1687           %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
1688           return (%y)
1689 )IR";
1690   auto graph = std::make_shared<Graph>();
1691   parseIR(graph_string, &*graph);
1692 
1693   std::unordered_map<c10::Symbol, NNCLoweringFunction> lowerings = {
1694       {aten::nan_to_num, lowerNanToNum}};
1695   TensorExprKernel k(graph, lowerings);
1696 
1697   auto stmt = k.getCodeGenStmt();
1698   std::ostringstream oss;
1699   oss << *stmt;
1700 
1701   // Check that our custom lowering is actually used
1702   torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str());
1703   torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
1704 }
1705 
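// A hedged sketch of how another op could be lowered through the same
// mechanism; the op choice (a relu-like clamp), the helper name, and the
// use of CompareSelect are illustrative assumptions, not exercised by any
// test in this file. It would be registered the same way, e.g. as
// {aten::relu, lowerReluSketch}.
Tensor lowerReluSketch(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto input_buf = std::get<BufHandle>(inputs[0]);
  return Compute(
      "custom_relu",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        auto load = input_buf.load(indices);
        // Select the loaded value when it is > 0, otherwise 0.
        return CompareSelect::make(load, 0.0f, load, 0.0f, kGT);
      });
}
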
1706 TEST_F(Kernel, Vectorize) {
1707 #ifdef TORCH_ENABLE_LLVM
1708   const auto graph_string = R"IR(
1709       graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
1710             %1 : Float(100, 16, strides=[16, 1], device=cpu)):
1711         %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
1712         %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
1713         return (%3))IR";
1714   auto graph = std::make_shared<Graph>();
1715   parseIR(graph_string, &*graph);
1716 
1717   auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
1718   auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
1719   auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
1720   auto ref = a * (a * b);
1721   TensorExprKernel k(graph);
1722   std::vector<at::Tensor> inputs = {a, b};
1723   StmtPtr s = k.getCodeGenStmt();
1724 
1725   std::ostringstream oss;
1726   oss << *s;
1727 
1728   // Check the IR we produced
1729   const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
1730   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1731 
1732   std::vector<IValue> stack = fmap<IValue>(inputs);
1733   k.run(stack);
1734   o = stack[0].toTensor();
1735   for (size_t i = 0; i < 100 * 16; i++) {
1736     TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
1737   }
1738 #endif
1739 }
1740 
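// Vectorized code shows up as Ramp nodes in the IR: with an innermost extent
// of 16, the inner loop above maps cleanly onto SIMD lanes. The disabled
// test below documents the remaining gap: with an innermost extent of only
// 3 there is no profitable vector width, so the 100x3 nest would first have
// to be flattened into a single 300-element loop.
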
1741 // TODO: To vectorize the loopnest for the 100x3 case, we need to flatten the loops first.
1742 TEST_F(Kernel, DISABLED_FlattenVectorize) {
1743 #ifdef TORCH_ENABLE_LLVM
1744   const auto graph_string = R"IR(
1745       graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
1746             %1 : Float(100, 3, strides=[3, 1], device=cpu)):
1747         %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
1748         %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
1749         return (%3))IR";
1750   auto graph = std::make_shared<Graph>();
1751   parseIR(graph_string, &*graph);
1752 
1753   auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1754   auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1755   auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1756   auto ref = a * (a * b);
1757   TensorExprKernel k(graph);
1758   std::vector<at::Tensor> inputs = {a, b};
1759   StmtPtr s = k.getCodeGenStmt();
1760 
1761   std::ostringstream oss;
1762   oss << *s;
1763 
1764   // Check the IR we produced
1765   const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
1766   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1767 
1768   std::vector<IValue> stack = fmap<IValue>(inputs);
1769   k.run(stack);
1770   o = stack[0].toTensor();
1771   for (size_t i = 0; i < 100 * 3; i++) {
1772     TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
1773   }
1774 #endif
1775 }
1776 
1777 TEST_F(Kernel, Strided1dWithinBounds) {
1778   auto ir = R"IR(
1779     graph(%0 : Float(3, strides=[1], device=cpu),
1780           %1 : Float(3, strides=[2], device=cpu)):
1781         %2 : int = prim::Constant[value=1]()
1782         %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
1783         return (%3))IR";
1784   auto graph = std::make_shared<Graph>();
1785   std::unordered_map<std::string, Value*> vmap;
1786   parseIR(ir, graph.get(), vmap);
1787   TensorExprKernel k(graph);
1788 
1789   auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
1790   auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
1791                .index({Slice(None, None, 2)});
1792   auto expect = a + b;
1793 
1794   std::vector<at::Tensor> inputs = {a, b};
1795 
1796   std::vector<IValue> stack = fmap<IValue>(inputs);
1797   k.run(stack);
1798 
1799   auto output = stack[0].toTensor();
1800 
1801   for (size_t i = 0; i < 3; ++i) {
1802     TORCH_CHECK_EQ(
1803         ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]);
1804   }
1805 }
1806 
1807 TEST_F(Kernel, InputAsOutput) {
1808   const auto graph_string = R"IR(
1809       graph(%x : Float(5, 3, strides=[3, 1], device=cpu),
1810             %y : Float(5, 3, strides=[1, 5], device=cpu)):
1811         return (%x, %y))IR";
1812   auto graph = std::make_shared<Graph>();
1813   parseIR(graph_string, &*graph);
1814 
1815   auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1816   auto y =
1817       at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
1818   TensorExprKernel k(graph);
1819   std::vector<at::Tensor> inputs = {x, y};
1820 
1821   std::vector<IValue> stack = fmap<IValue>(inputs);
1822   k.run(stack);
1823   CHECK(at::allclose(x, stack[0].toTensor()));
1824   CHECK(at::allclose(y, stack[1].toTensor()));
1825 }
1826 
1827 TEST_F(Kernel, ScalarOut) {
1828   auto ir = R"IR(
1829 graph(%x : int, %y : int):
1830   %z : int = aten::mul(%x, %y)
1831   %r : int = aten::mul(%z, %x)
1832   return (%r, %z))IR";
1833   auto graph = std::make_shared<Graph>();
1834   std::unordered_map<std::string, Value*> vmap;
1835   parseIR(ir, graph.get(), vmap);
1836   TensorExprKernel k(graph);
1837 
1838   auto stmt = k.getCodeGenStmt();
1839   std::ostringstream oss;
1840   oss << *stmt;
1841 
1842   // Verify the generated IR. We expect to see a scalar variable (Let) followed
1843   // by a store to a 0-dim buffer.
1844   const std::string& verification_pattern = R"IR(
1845 # CHECK: int64_t
1846 # CHECK-NEXT: [0ll] =
1847 # CHECK-NEXT: int64_t
1848 # CHECK-NEXT: [0ll] =
1849 )IR";
1850   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1851 
1852   int64_t x = 2, y = 3, r = 0, z = 0;
1853 
1854   // Verify that TEK::runFast works correctly with scalar outputs
1855   std::vector<void*> inputs = {&x, &y};
1856   std::vector<void*> outputs = {&r, &z};
1857   k.runFast(inputs, outputs);
1858   TORCH_CHECK_EQ(z, x * y);
1859   TORCH_CHECK_EQ(r, z * x);
1860 
1861   // Verify that TEK::run works correctly with scalar outputs
1862   std::vector<IValue> stack = {x, y};
1863   k.run(stack);
1864   TORCH_CHECK_EQ(stack[0], x * y * x);
1865   TORCH_CHECK_EQ(stack[1], x * y);
1866 }
1867 
1868 TEST_F(Kernel, ScalarTensorOut) {
1869   auto ir = R"IR(
1870 graph(%x : int,
1871       %xt : Long(3, strides=[1], device=cpu),
1872       %y : int,
1873       %yt : Long(3, strides=[1], device=cpu)):
1874   %z : int = aten::mul(%x, %y)
1875   %r : int = aten::mul(%z, %x)
1876   %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y)
1877   %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt)
1878   return (%r, %rt, %z, %zt))IR";
1879   auto graph = std::make_shared<Graph>();
1880   std::unordered_map<std::string, Value*> vmap;
1881   parseIR(ir, graph.get(), vmap);
1882   TensorExprKernel k(graph);
1883   int64_t x = 2, y = 3, r = 0, z = 0;
1884   auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2;
1885   auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3;
1886   auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));
1887   auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));
1888 
1889   // Verify that TEK::runFast works correctly with mixed scalar and tensor
1890   // inputs/outputs
1891   std::vector<void*> inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()};
1892   std::vector<void*> outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()};
1893   k.runFast(inputs, outputs);
1894   TORCH_CHECK_EQ(z, x * y);
1895   TORCH_CHECK_EQ(r, z * x);
1896   ASSERT_TRUE(at::equal(zt, xt * yt));
1897   ASSERT_TRUE(at::equal(rt, zt * xt));
1898 
1899   // Verify that TEK::run works correctly with mixed scalar and tensor
1900   // inputs/outputs
1901   std::vector<IValue> stack = {x, xt, y, yt};
1902   k.run(stack);
1903   TORCH_CHECK_EQ(stack[0], x * y * x);
1904   ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt));
1905   TORCH_CHECK_EQ(stack[2], x * y);
1906   ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt));
1907 }
1908 
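// The remaining tests exercise symbolic shapes: SS(-n) in the IR denotes a
// dimension whose concrete extent is passed at run time through the
// matching %SS_n scalar input; symbolic_shape_inputs lists those ids for
// the kernel, and symbolic_strides maps each tensor to a stride layout
// (TENSOR_CONT, i.e. contiguous) that the kernel may assume.
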
1909 TEST_F(Kernel, FuseLoopsWithVariableBounds) {
1910 #ifdef TORCH_ENABLE_LLVM
1911   bool old_cat_wo_conditionals = getCatWoConditionals();
1912   getCatWoConditionals() = true;
1913   const auto graph_string = R"IR(
1914       graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu),
1915             %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu),
1916             %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu),
1917             %SS_2 : int,
1918             %SS_3 : int):
1919         %dim : int = prim::Constant[value=1]()
1920         %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
1921         %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)               # new size: [SS(-2), 19, SS(-3)]
1922         return (%r))IR";
1923   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
1924   torch::jit::parseIR(graph_string, graph.get());
1925 
1926   std::vector<int64_t> symbolic_shape_inputs = {-2, -3};
1927 
1928   std::vector<torch::jit::StrideInput> input_desc = {
1929       torch::jit::StrideInput::TENSOR_CONT};
1930   std::unordered_map<
1931       const torch::jit::Value*,
1932       std::vector<torch::jit::StrideInput>>
1933       symbolic_strides;
1934   symbolic_strides[graph->inputs().at(0)] = input_desc;
1935   symbolic_strides[graph->inputs().at(1)] = input_desc;
1936   symbolic_strides[graph->inputs().at(2)] = input_desc;
1937   symbolic_strides[graph->outputs().at(0)] = input_desc;
1938 
1939   TensorExprKernel kernel(
1940       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
1941 
1942   std::ostringstream oss;
1943   oss << *kernel.getCodeGenStmt();
1944   const std::string& verification_pattern =
1945       R"IR(
1946 # CHECK: for (int64_t i
1947 # CHECK-NEXT: for (int64_t j
1948 # CHECK-NEXT: for (int64_t k
1949 # CHECK: for (int64_t j
1950 # CHECK-NEXT: for (int64_t k
1951 # CHECK: for (int64_t j
1952 # CHECK-NEXT: for (int64_t k
1953 # CHECK-NOT: for (int64_t i
1954       )IR";
1955   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1956 
1957   auto run_kernel = [&](int dim1, int dim2) {
1958     auto a =
1959         at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
1960     auto b =
1961         at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
1962     auto c =
1963         at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
1964 
1965     auto ref = at::cat({a, b, c}, 1);
1966 
1967     std::vector<IValue> stack =
1968         fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
1969     stack.emplace_back(dim1);
1970     stack.emplace_back(dim2);
1971     kernel.run(stack);
1972 
1973     auto o = stack[0].toTensor();
1974     ASSERT_TRUE(at::allclose(o, ref));
1975   };
1976 
1977   run_kernel(10, 20);
1978   getCatWoConditionals() = old_cat_wo_conditionals;
1979 #endif
1980 }
1981 
1982 TEST_F(Kernel, FuseLoopsWithVariableConcatDim) {
1983 #ifdef TORCH_ENABLE_LLVM
1984   bool old_cat_wo_conditionals = getCatWoConditionals();
1985   getCatWoConditionals() = true;
1986   const auto graph_string = R"IR(
1987       graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
1988             %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
1989             %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
1990             %SS_2 : int,
1991             %SS_3 : int,
1992             %SS_4 : int,
1993             %SS_5 : int):
1994         %dim : int = prim::Constant[value=1]()
1995         %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
1996         %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)               # new size: [SS(-2), SS(-5), SS(-3)]
1997         return (%r))IR";
1998   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
1999   torch::jit::parseIR(graph_string, graph.get());
2000 
2001   std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5};
2002 
2003   std::vector<torch::jit::StrideInput> input_desc = {
2004       torch::jit::StrideInput::TENSOR_CONT};
2005   std::unordered_map<
2006       const torch::jit::Value*,
2007       std::vector<torch::jit::StrideInput>>
2008       symbolic_strides;
2009   symbolic_strides[graph->inputs().at(0)] = input_desc;
2010   symbolic_strides[graph->inputs().at(1)] = input_desc;
2011   symbolic_strides[graph->inputs().at(2)] = input_desc;
2012   symbolic_strides[graph->outputs().at(0)] = input_desc;
2013 
2014   TensorExprKernel kernel(
2015       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
2016 
2017   std::ostringstream oss;
2018   oss << *kernel.getCodeGenStmt();
2019   const std::string& verification_pattern =
2020       R"IR(
2021 # CHECK: for (int64_t i
2022 # CHECK-NEXT: for (int64_t j
2023 # CHECK-NEXT: for (int64_t k
2024 # CHECK: for (int64_t j
2025 # CHECK-NEXT: for (int64_t k
2026 # CHECK: for (int64_t j
2027 # CHECK-NEXT: for (int64_t k
2028 # CHECK-NOT: for (int64_t i
2029       )IR";
2030   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2031 
2032   auto run_kernel = [&](int dim1, int dim2, int dim3) {
2033     auto a =
2034         at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
2035     auto b =
2036         at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
2037     auto c =
2038         at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
2039 
2040     auto ref = at::cat({a, b, c}, 1);
2041 
2042     std::vector<IValue> stack =
2043         fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
2044     stack.emplace_back(dim1);
2045     stack.emplace_back(dim2);
2046     stack.emplace_back(dim3);
2047     stack.emplace_back(3 * dim3);
2048     kernel.run(stack);
2049 
2050     auto o = stack[0].toTensor();
2051     ASSERT_TRUE(at::allclose(o, ref));
2052   };
2053 
2054   run_kernel(10, 20, 15);
2055   getCatWoConditionals() = old_cat_wo_conditionals;
2056 #endif
2057 }
2058 
2059 TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) {
2060 #ifdef TORCH_ENABLE_LLVM
2061   bool old_cat_wo_conditionals = getCatWoConditionals();
2062   getCatWoConditionals() = true;
2063   const auto graph_string = R"IR(
2064       graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
2065             %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu),
2066             %SS_2 : int,
2067             %SS_3 : int,
2068             %SS_4 : int,
2069             %SS_5 : int,
2070             %SS_6 : int):
2071         %dim : int = prim::Constant[value=1]()
2072         %inputs : Tensor[] = prim::ListConstruct(%a, %b)
2073         %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)               # new size: [SS(-2), SS(-6), SS(-3)]
2074         return (%r))IR";
2075   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
2076   torch::jit::parseIR(graph_string, graph.get());
2077 
2078   std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5, -6};
2079 
2080   std::vector<torch::jit::StrideInput> input_desc = {
2081       torch::jit::StrideInput::TENSOR_CONT};
2082   std::unordered_map<
2083       const torch::jit::Value*,
2084       std::vector<torch::jit::StrideInput>>
2085       symbolic_strides;
2086   symbolic_strides[graph->inputs().at(0)] = input_desc;
2087   symbolic_strides[graph->inputs().at(1)] = input_desc;
2088   symbolic_strides[graph->outputs().at(0)] = input_desc;
2089 
2090   TensorExprKernel kernel(
2091       graph, {}, symbolic_shape_inputs, false, symbolic_strides);
2092 
2093   std::ostringstream oss;
2094   oss << *kernel.getCodeGenStmt();
2095   const std::string& verification_pattern =
2096       R"IR(
2097 # CHECK: for (int64_t i
2098 # CHECK-NEXT: for (int64_t j
2099 # CHECK-NEXT: for (int64_t k
2100 # CHECK: for (int64_t j
2101 # CHECK-NEXT: for (int64_t k
2102 # CHECK-NOT: for (int64_t j
2103 # CHECK-NOT: for (int64_t i
2104       )IR";
2105   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2106 
2107   auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) {
2108     auto a =
2109         at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));
2110     auto b =
2111         at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));
2112 
2113     auto ref = at::cat({a, b}, 1);
2114 
2115     std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
2116     stack.emplace_back(dim2);
2117     stack.emplace_back(dim3);
2118     stack.emplace_back(dim4);
2119     stack.emplace_back(dim5);
2120     stack.emplace_back(dim4 + dim5);
2121     kernel.run(stack);
2122 
2123     auto o = stack[0].toTensor();
2124     ASSERT_TRUE(at::allclose(o, ref));
2125   };
2126 
2127   run_kernel(10, 20, 15, 8);
2128   getCatWoConditionals() = old_cat_wo_conditionals;
2129 #endif
2130 }
2131 
2132 } // namespace jit
2133 } // namespace torch
2134