// Source: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_llvm.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef TORCH_ENABLE_LLVM
2 #include <gtest/gtest.h>
3 
4 #include <test/cpp/tensorexpr/test_base.h>
5 
6 #include <c10/util/irange.h>
7 #include <test/cpp/tensorexpr/padded_buffer.h>
8 #include <test/cpp/tensorexpr/test_utils.h>
9 #include <torch/csrc/jit/tensorexpr/eval.h>
10 #include <torch/csrc/jit/tensorexpr/ir.h>
11 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
12 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
13 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
14 #include <torch/csrc/jit/tensorexpr/loopnest.h>
15 #include <torch/csrc/jit/tensorexpr/tensor.h>
16 #include <torch/csrc/jit/testing/file_check.h>
17 
18 #include <cmath>
19 #include <numeric>
20 
21 namespace torch {
22 namespace jit {
23 using namespace torch::jit::tensorexpr;
24 
25 using LLVMExprEval = ExprEval<LLVMCodeGen>;
26 
// Typed tests: gtest params can't be used here due to the way the tests are
// instantiated, so an X-macro list stamps out one test per scalar type.
// Each entry is X(C++ type, scalar-type name prefix, sample value).
#define TEST_LLVM_SCALAR_TYPES(X) \
  X(uint8_t, Byte, 24)            \
  X(int8_t, Char, -20)            \
  X(int16_t, Short, 3332)         \
  X(int, Int, 123456)             \
  X(int64_t, Long, 2631563121321) \
  X(float, Float, 0.122)          \
  X(double, Double, 0.21312)      \
  X(at::Half, Half, 0.128f)
37 
38 #define IMM_TEST(Type, Name, Val)                  \
39   TEST(LLVM, Name##ImmTest) {                      \
40     auto a = Name##Imm::make(Val);                 \
41     LLVMExprEval cg(a);                            \
42     if (std::is_floating_point<decltype(Val)>()) { \
43       ASSERT_NEAR(cg.value<Type>(), Val, 0.1);     \
44     } else {                                       \
45       ASSERT_EQ(cg.value<Type>(), Val);            \
46     }                                              \
47   }
48 TEST_LLVM_SCALAR_TYPES(IMM_TEST)
49 #undef IMM_TEST
50 
// For every scalar type: Val + 2*Val evaluated through the LLVM backend must
// equal 3*Val (within 0.1 for floating-point sample literals).
#define ADD_TEST(Type, Name, Val)                         \
  TEST(LLVM, Name##AddTest) {                             \
    auto lhs = Name##Imm::make(Val);                      \
    auto rhs = Name##Imm::make(Val * 2);                  \
    auto sum = Add::make(lhs, rhs);                       \
    LLVMExprEval evaluator(sum);                          \
    if (!std::is_floating_point<decltype(Val)>::value) {  \
      ASSERT_EQ(evaluator.value<Type>(), Val * 3);        \
    } else {                                              \
      ASSERT_NEAR(evaluator.value<Type>(), Val * 3, 0.1); \
    }                                                     \
  }
TEST_LLVM_SCALAR_TYPES(ADD_TEST)
#undef ADD_TEST
65 
66 #define SUB_TEST(Type, Name, Val)                  \
67   TEST(LLVM, Name##SubTest) {                      \
68     auto a = Name##Imm::make(Val * 2);             \
69     auto b = Name##Imm::make(Val);                 \
70     auto c = Sub::make(a, b);                      \
71     LLVMExprEval cg(c);                            \
72     if (std::is_floating_point<decltype(Val)>()) { \
73       ASSERT_NEAR(cg.value<Type>(), Val, 0.1);     \
74     } else {                                       \
75       ASSERT_EQ(cg.value<Type>(), Val);            \
76     }                                              \
77   }
78 TEST_LLVM_SCALAR_TYPES(SUB_TEST)
79 #undef SUB_TEST
80 
81 #define MUL_TEST(Type, Name, Val)                  \
82   TEST(LLVM, Name##MulTest) {                      \
83     auto a = Name##Imm::make(Val);                 \
84     auto b = Name##Imm::make((Type)4);             \
85     auto c = Mul::make(a, b);                      \
86     LLVMExprEval cg(c);                            \
87     if (std::is_floating_point<decltype(Val)>()) { \
88       ASSERT_NEAR(cg.value<Type>(), Val * 4, 0.1); \
89     } else {                                       \
90       ASSERT_EQ(cg.value<Type>(), Val * 4);        \
91     }                                              \
92   }
93 TEST_LLVM_SCALAR_TYPES(MUL_TEST)
94 #undef MUL_TEST
95 
96 #define DIV_TEST(Type, Name, Val)                  \
97   TEST(LLVM, Name##DivTest) {                      \
98     auto a = Name##Imm::make((Type)6);             \
99     auto b = Name##Imm::make((Type)3);             \
100     auto c = Div::make(a, b);                      \
101     LLVMExprEval cg(c);                            \
102     if (std::is_floating_point<decltype(Val)>()) { \
103       ASSERT_NEAR(cg.value<Type>(), 2, 0.1);       \
104     } else {                                       \
105       ASSERT_EQ(cg.value<Type>(), 2);              \
106     }                                              \
107   }
108 TEST_LLVM_SCALAR_TYPES(DIV_TEST)
109 #undef DIV_TEST
110 
111 TEST(LLVM, IntToFloatCastTest) {
112   auto a = IntImm::make(2);
113   auto b = Cast::make(kFloat, a);
114   LLVMExprEval cg(b, {});
115   ASSERT_EQ(cg.value<float>(), 2.0);
116 }
117 
// Casting the float immediate 2.0 to int yields 2.
TEST(LLVM, FloatToIntCastTest) {
  auto float_two = FloatImm::make(2.0);
  auto as_int = Cast::make(kInt, float_two);
  LLVMExprEval evaluator(as_int);
  ASSERT_EQ(evaluator.value<int>(), 2);
}
124 
// Widening the int immediate 12345 to a 64-bit integer keeps its value.
TEST(LLVM, IntToLongCastTest) {
  auto narrow = IntImm::make(12345);
  auto widened = Cast::make(kLong, narrow);
  LLVMExprEval evaluator(widened);
  ASSERT_EQ(evaluator.value<int64_t>(), 12345);
}
131 
// Casting the unsigned byte 250 to a signed char must match the host-side
// conversion of 250 to int8_t.
TEST(LLVM, ByteToCharCastTest) {
  auto unsigned_byte = ByteImm::make(250);
  auto as_char = Cast::make(kChar, unsigned_byte);
  LLVMExprEval evaluator(as_char);
  ASSERT_EQ(evaluator.value<int8_t>(), static_cast<int8_t>(250));
}
138 
// Casting the half-precision immediate 2.0 to a 64-bit integer yields 2.
TEST(LLVM, HalfToLongCastTest) {
  auto half_two = HalfImm::make(2.0);
  auto as_long = Cast::make(kLong, half_two);
  LLVMExprEval evaluator(as_long);
  ASSERT_EQ(evaluator.value<int64_t>(), 2);
}
145 
// Casting the byte immediate 2 to double yields 2.
TEST(LLVM, ByteToDoubleCastTest) {
  auto byte_two = ByteImm::make(2);
  auto as_double = Cast::make(kDouble, byte_two);
  LLVMExprEval evaluator(as_double);
  ASSERT_EQ(evaluator.value<double>(), 2);
}
152 
// Casting the float immediate 254.0 to an unsigned byte yields 254.
TEST(LLVM, FloatToByteCastTest) {
  auto float_value = FloatImm::make(254.0);
  auto as_byte = Cast::make(kByte, float_value);
  LLVMExprEval evaluator(as_byte);
  ASSERT_EQ(evaluator.value<uint8_t>(), 254);
}
159 
// Casting the negative float immediate -2.0 to a signed char yields -2.
TEST(LLVM, FloatToCharCastTest) {
  auto float_value = FloatImm::make(-2.0);
  auto as_char = Cast::make(kChar, float_value);
  LLVMExprEval evaluator(as_char);
  ASSERT_EQ(evaluator.value<int8_t>(), -2);
}
166 
// Casting the unsigned byte immediate 254 to float yields 254.0.
TEST(LLVM, ByteToFloatCastTest) {
  auto byte_value = ByteImm::make(254);
  auto as_float = Cast::make(kFloat, byte_value);
  LLVMExprEval evaluator(as_float);
  ASSERT_EQ(evaluator.value<float>(), 254.0);
}
173 
// Casting the signed char immediate -2 to float yields -2.0.
TEST(LLVM, CharToFloatCastTest) {
  auto char_value = CharImm::make(-2);
  auto as_float = Cast::make(kFloat, char_value);
  LLVMExprEval evaluator(as_float);
  ASSERT_EQ(evaluator.value<float>(), -2.0);
}
180 
// BitCast must reinterpret the bits without converting the value. Each case
// builds an immediate whose bit pattern (via raw_bitcast on the host) encodes
// a reference value of the other type, bit-casts it in generated code, and
// expects the reference value back.
TEST(LLVM, BitCast) {
  /* constexpr int16_t ref16 = 1337; */
  constexpr int32_t ref32 = 1337;
  constexpr int64_t ref64 = 1337;
  constexpr float reff32 = 1337.0f;
  constexpr double reff64 = 1337.0f;

  // this is broken
  /*{
    at::Half k_;
    at::Half* k = &k_;
    *reinterpret_cast<int16_t*>(k) = ref16;
    auto a = HalfImm::make(k);
    auto b = BitCast::make(kShort, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<int16_t>(), ref16);
  }*/

  // float -> int32
  {
    float float_bits = raw_bitcast<float>(ref32);
    auto source = FloatImm::make(float_bits);
    auto reinterpreted = BitCast::make(kInt, source);
    LLVMExprEval evaluator(reinterpreted);
    ASSERT_EQ(evaluator.value<int32_t>(), ref32);
  }

  // double -> int64
  {
    double double_bits = raw_bitcast<double>(ref64);
    auto source = DoubleImm::make(double_bits);
    auto reinterpreted = BitCast::make(kLong, source);
    LLVMExprEval evaluator(reinterpreted);
    ASSERT_EQ(evaluator.value<int64_t>(), ref64);
  }

  // int64 -> double
  {
    int64_t long_bits = raw_bitcast<int64_t>(reff64);
    auto source = LongImm::make(long_bits);
    auto reinterpreted = BitCast::make(kDouble, source);
    LLVMExprEval evaluator(reinterpreted);
    ASSERT_EQ(evaluator.value<double>(), reff64);
  }

  // int32 -> float
  {
    int32_t int_bits = raw_bitcast<int32_t>(reff32);
    auto source = IntImm::make(int_bits);
    auto reinterpreted = BitCast::make(kFloat, source);
    LLVMExprEval evaluator(reinterpreted);
    ASSERT_EQ(evaluator.value<float>(), reff32);
  }
}
231 
// Runs B[i] = fast_log(A[i]) over 128*128 normally distributed floats and
// compares each result against std::log; where the reference is NaN the
// generated result must be NaN too.
TEST(LLVM, fastLogFloat) {
  const int kTotalSize = 128 * 128;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle idx("index", kInt);
  ExprHandle input_elem = a_buf.load(idx);
  StmtPtr write_log = b_buf.store({idx}, fast_log(input_elem));
  StmtPtr loop = For::make(idx, 0, kTotalSize, write_log);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);

  for (const auto pos : c10::irange(kTotalSize)) {
    a_v(pos) = at::randn({1}).item().to<float>();
  }

  LLVMCodeGen ir_eval(loop, {a_buf, b_buf});
  ir_eval.call({a_v, b_v});

  for (const auto pos : c10::irange(kTotalSize)) {
    auto actual = b_v(pos);
    auto expected = std::log(a_v(pos));
    if (!std::isnan(expected)) {
      ASSERT_FLOAT_EQ(actual, expected);
    } else {
      ASSERT_EQ(std::isnan(actual), true);
    }
  }
}
262 
// A single Let binding (x = 3) feeds the stored expression 2 + (x*3 + 4).
TEST(LLVM, LetTest01) {
  BufHandle a("A", {1}, kFloat);
  VarHandle x("x", kFloat);

  auto bind_x = Let::make(x, 3.f);
  auto write_result =
      a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)));
  auto block = Block::make({bind_x, write_result});

  std::vector<float> out = {1, 0};
  std::vector<void*> args = {out.data()};

  LLVMCodeGen cg(block, {a});
  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(out[0], 2.f + 3.f * 3.f + 4.f);
}
277 
// Two Let bindings (x = 3, y = 6) feed the stored expression
// 2 + (x*3 + y*4).
TEST(LLVM, LetTest02) {
  BufHandle a("A", {1}, kFloat);
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);

  auto bind_x = Let::make(x, 3.f);
  auto bind_y = Let::make(y, 6.f);
  auto write_result = a.store(
      {IntImm::make(0)},
      ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)));
  auto block = Block::make({bind_x, bind_y, write_result});

  std::vector<float> out = {1, 0};
  std::vector<void*> args = {out.data()};

  LLVMCodeGen cg(block, {a});
  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(out[0], 2.f + 3.f * 3.f + 6.f * 4.f);
}
295 
// Let bindings of different scalar types (a byte x and a half y) are mixed
// with float constants; the result is cast to double before being stored.
TEST(LLVM, LetTestMultitype) {
  BufHandle a("A", {1}, kDouble);
  VarHandle x("x", kByte);
  VarHandle y("y", kHalf);

  auto bind_x = Let::make(x, 3);
  auto bind_y = Let::make(y, 6.f);
  auto write_result = a.store(
      {0},
      Cast::make(
          kDouble,
          ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f))));
  auto block = Block::make({bind_x, bind_y, write_result});

  std::vector<double> out = {1, 0};
  std::vector<void*> args = {out.data()};

  LLVMCodeGen cg(block, {a});
  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(out[0], 2.f + 3 * 3.f + 6.f * 4.f);
}
316 
// An expression that ignores its buffer argument can still be compiled and
// called with that buffer bound; it returns the immediate 0.
TEST(LLVM, BufferTest) {
  BufHandle a("A", {32}, kFloat);
  auto zero = IntImm::make(0);
  LLVMExprEval evaluator(zero, {a});

  std::vector<int32_t> storage(5);
  std::vector<void*> args = {storage.data()};
  ASSERT_EQ(evaluator.value<int>(args), 0);
}
325 
// Stores in a Block execute in order: A[0] is written twice (3, then 4) and
// must end up as 4; A[1] is written once with 4.
TEST(LLVM, BlockTest) {
  BufHandle a("A", {32}, kInt);

  auto first_write = a.store({0}, 3);
  auto second_write = a.store({1}, 4);
  auto overwrite = a.store({0}, 4);
  auto block = Block::make({first_write, second_write, overwrite});

  std::vector<int32_t> data = {1, 2};
  std::vector<void*> args = {data.data()};

  LLVMCodeGen cg(block, {a});
  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(data[0], 4);
  ASSERT_EQ(data[1], 4);
}
342 
// B[0] = A[0]: the source keeps its value and the destination receives it.
TEST(LLVM, LoadStoreTest) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  auto copy_elem = b.store({0}, a.load(0));
  LLVMCodeGen cg(copy_elem, {a, b});

  std::vector<int32_t> src = {42};
  std::vector<int32_t> dst = {-11};
  std::vector<void*> args = {src.data(), dst.data()};

  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(src[0], 42);
  ASSERT_EQ(dst[0], 42);
}
356 
// B[0] = C[0] ? A[0] : 0. With C[0] == 1 the true branch (A[0] == 42) is
// selected.
TEST(LLVM, IfThenElseTest) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  BufHandle c("C", {1}, kInt);

  auto select = IfThenElse::make(c.load(0), a.load(0), 0);
  auto write_selected = b.store({0}, select);
  LLVMCodeGen cg(write_selected, {a, b, c});

  std::vector<int32_t> a_data = {42};
  std::vector<int32_t> b_data = {-11};
  std::vector<int32_t> c_data = {1};
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};

  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(a_data[0], 42);
  ASSERT_EQ(b_data[0], 42);
}
372 
373 // if (x < 10) x = x + 1
TEST(LLVM,CondNoFalseBlockTest)374 TEST(LLVM, CondNoFalseBlockTest) {
375   BufHandle x("X", {1}, kInt);
376   auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
377   auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr);
378 
379   for (int32_t x_value : {0, 10, 20}) {
380     std::vector<int32_t> x_buffer = {x_value};
381     std::vector<void*> args({x_buffer.data()});
382     LLVMCodeGen cg(cond, {x});
383     ASSERT_EQ(cg.value<int>(args), 0);
384     if (x_value < 10) {
385       ASSERT_EQ(x_buffer[0], x_value + 1);
386     } else {
387       ASSERT_EQ(x_buffer[0], x_value);
388     }
389   }
390 }
391 
392 // if (x < 10) {
393 //   x = x + 1;
394 // } else {
395 //   x = x - 1;
396 // }
TEST(LLVM,CondTest)397 TEST(LLVM, CondTest) {
398   BufHandle x("X", {1}, kInt);
399   auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
400   auto cond =
401       Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
402   auto block = Block::make({
403       cond,
404       x.store({0}, x.load(0) * 2),
405   });
406 
407   for (int32_t x_value : {0, 10, 20}) {
408     std::vector<int32_t> x_buffer = {x_value};
409     std::vector<void*> args({x_buffer.data()});
410     LLVMCodeGen cg(block, {x});
411     ASSERT_EQ(cg.value<int>(args), 0);
412     if (x_value < 10) {
413       ASSERT_EQ(x_buffer[0], (x_value + 1) * 2);
414     } else {
415       ASSERT_EQ(x_buffer[0], (x_value - 1) * 2);
416     }
417   }
418 }
419 
420 // if (x < 10) {
421 //   if (x > 5) {
422 //     x = x + 1;
423 //   } else {
424 //     x = x - 1;
425 //   }
426 // } else {
427 //   if (x <= 15) {
428 //     x = x + 2;
429 //   } else {
430 //     x = x - 2;
431 //   }
432 // }
TEST(LLVM,CondNestedTest)433 TEST(LLVM, CondNestedTest) {
434   BufHandle x("X", {1}, kInt);
435   auto true_cmp =
436       CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT);
437   auto true_cond = Cond::make(
438       true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
439   auto false_cmp =
440       CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE);
441   auto false_cond = Cond::make(
442       false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2));
443   auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
444   auto cond = Cond::make(cmp, true_cond, false_cond);
445 
446   for (int32_t x_value : {0, 8, 15, 20}) {
447     std::vector<int32_t> x_buffer = {x_value};
448     std::vector<void*> args({x_buffer.data()});
449     LLVMCodeGen cg(cond, {x});
450     ASSERT_EQ(cg.value<int>(args), 0);
451     if (x_value < 10) {
452       if (x_value > 5) {
453         ASSERT_EQ(x_buffer[0], x_value + 1);
454       } else {
455         ASSERT_EQ(x_buffer[0], x_value - 1);
456       }
457     } else {
458       if (x_value <= 15) {
459         ASSERT_EQ(x_buffer[0], x_value + 2);
460       } else {
461         ASSERT_EQ(x_buffer[0], x_value - 2);
462       }
463     }
464   }
465 }
466 
// Hand-written vector IR: for each of the M rows, a 64-lane Ramp indexes a
// full row of a and b, and their product is stored to c. The test only
// checks that code generation succeeds; the kernel is never invoked.
TEST(LLVM, DirectVectorization) {
  constexpr int M = 3;
  constexpr int N = 64;
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {M, N}, kFloat);
  BufHandle c("c", {M, N}, kFloat);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  // Fresh index node per use: lanes m*64 .. m*64 + 63.
  const auto row_lanes = [&m]() { return Ramp::make(m * 64, 1, 64); };

  auto product = Load::make({kFloat, 64}, a, {row_lanes()}) *
      Load::make({kFloat, 64}, b, {row_lanes()});
  auto write_row = Store::make(c, {row_lanes()}, product);
  StmtPtr s = For::make(m, 0, M, write_row);
  LLVMCodeGen cg(s, {a, b, c});
}
486 
// A 4-lane vector load from A stored to B copies all four elements and
// leaves the source unchanged.
TEST(LLVM, VecLoadStoreTest) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);

  auto vec_load = a.load({Ramp::make(0, 1, 4)});
  auto vec_copy = b.store({Ramp::make(0, 1, 4)}, vec_load);
  LLVMCodeGen cg(vec_copy, {a, b});

  std::vector<int32_t> src = {1, 1, 1, 1};
  std::vector<int32_t> dst = {2, 2, 2, 2};
  std::vector<void*> args = {src.data(), dst.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(src[0], 1);
  ASSERT_EQ(src[1], 1);
  ASSERT_EQ(src[2], 1);
  ASSERT_EQ(src[3], 1);
  ASSERT_EQ(dst[0], 1);
  ASSERT_EQ(dst[1], 1);
  ASSERT_EQ(dst[2], 1);
  ASSERT_EQ(dst[3], 1);
}
506 
507 #define FLOAT_INTRINSICS_TEST(Name, Lanes)                                   \
508   TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) {                           \
509     BufHandle a("A", {1}, kFloat);                                           \
510     BufHandle b("B", {1}, kFloat);                                           \
511     float val = 0.5f;                                                        \
512     std::vector<float> a_buffer(Lanes, val);                                 \
513     std::vector<float> b_buffer(Lanes, val);                                 \
514     auto store = b.store(                                                    \
515         {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \
516     LLVMCodeGen cg(store, {a, b});                                           \
517     std::vector<void*> args({a_buffer.data(), b_buffer.data()});             \
518     ASSERT_EQ(cg.value<int>(args), 0);                                       \
519     for (const auto i : c10::irange(Lanes)) {                                \
520       ASSERT_FLOAT_EQ(a_buffer[i], val);                                     \
521     }                                                                        \
522   } // namespace jit
523 FLOAT_INTRINSICS_TEST(erf, 4)
524 FLOAT_INTRINSICS_TEST(erfc, 4)
525 FLOAT_INTRINSICS_TEST(acos, 4)
526 FLOAT_INTRINSICS_TEST(asin, 4)
527 FLOAT_INTRINSICS_TEST(atan, 4)
528 FLOAT_INTRINSICS_TEST(cosh, 4)
529 FLOAT_INTRINSICS_TEST(sinh, 4)
530 FLOAT_INTRINSICS_TEST(tanh, 4)
531 FLOAT_INTRINSICS_TEST(expm1, 4)
532 FLOAT_INTRINSICS_TEST(lgamma, 4)
533 FLOAT_INTRINSICS_TEST(erf, 8)
534 FLOAT_INTRINSICS_TEST(erfc, 8)
535 FLOAT_INTRINSICS_TEST(acos, 8)
536 FLOAT_INTRINSICS_TEST(asin, 8)
537 FLOAT_INTRINSICS_TEST(atan, 8)
538 FLOAT_INTRINSICS_TEST(cosh, 8)
539 FLOAT_INTRINSICS_TEST(sinh, 8)
540 FLOAT_INTRINSICS_TEST(tanh, 8)
541 FLOAT_INTRINSICS_TEST(expm1, 8)
542 FLOAT_INTRINSICS_TEST(lgamma, 8)
543 #undef FLOAT_INTRINSICS_TEST
544 
545 #define DOUBLE_INTRINSICS_TEST(Name, Lanes)                                  \
546   TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) {                          \
547     BufHandle a("A", {1}, kDouble);                                          \
548     BufHandle b("B", {1}, kDouble);                                          \
549     float val = 0.5f;                                                        \
550     std::vector<double> a_buffer(Lanes, val);                                \
551     std::vector<double> b_buffer(Lanes, val);                                \
552     auto store = b.store(                                                    \
553         {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \
554     LLVMCodeGen cg(store, {a, b});                                           \
555     std::vector<void*> args({a_buffer.data(), b_buffer.data()});             \
556     ASSERT_EQ(cg.value<int>(args), 0);                                       \
557     for (const auto i : c10::irange(Lanes)) {                                \
558       ASSERT_FLOAT_EQ(a_buffer[i], val);                                     \
559     }                                                                        \
560   } // namespace jit
561 DOUBLE_INTRINSICS_TEST(erf, 2)
562 DOUBLE_INTRINSICS_TEST(erfc, 2)
563 DOUBLE_INTRINSICS_TEST(acos, 2)
564 DOUBLE_INTRINSICS_TEST(asin, 2)
565 DOUBLE_INTRINSICS_TEST(atan, 2)
566 DOUBLE_INTRINSICS_TEST(cosh, 2)
567 DOUBLE_INTRINSICS_TEST(sinh, 2)
568 DOUBLE_INTRINSICS_TEST(tanh, 2)
569 DOUBLE_INTRINSICS_TEST(expm1, 2)
570 DOUBLE_INTRINSICS_TEST(lgamma, 2)
571 DOUBLE_INTRINSICS_TEST(erf, 4)
572 DOUBLE_INTRINSICS_TEST(erfc, 4)
573 DOUBLE_INTRINSICS_TEST(acos, 4)
574 DOUBLE_INTRINSICS_TEST(asin, 4)
575 DOUBLE_INTRINSICS_TEST(atan, 4)
576 DOUBLE_INTRINSICS_TEST(cosh, 4)
577 DOUBLE_INTRINSICS_TEST(sinh, 4)
578 DOUBLE_INTRINSICS_TEST(tanh, 4)
579 DOUBLE_INTRINSICS_TEST(expm1, 4)
580 DOUBLE_INTRINSICS_TEST(lgamma, 4)
581 #undef DOUBLE_INTRINSICS_TEST
582 
// Builds c[i] = A[i] for 4 elements, vectorizes the loop with
// LoopNest::vectorize, verifies that no For remains at the front of the root
// block, and runs the vectorized kernel.
TEST(LLVM, VectorizerLoadStoreTest) {
  BufHandle a("A", {1}, kInt);

  Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); });
  BufHandle c_buf(c.buf());

  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();
  ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(root)->front())));

  // After vectorization the front statement is no longer a For loop.
  ASSERT_TRUE(to<For>(to<Block>(root)->front()) == nullptr);

  LLVMCodeGen cg(root, {a, c_buf});

  std::vector<int> input(4, 21);
  std::vector<int> output(4, 0);
  std::vector<void*> args = {input.data(), output.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(output, 21);
}
603 
// Builds c[i] = bitcast<float>(A[i]) for 128 elements, vectorizes the loop,
// and feeds it ints holding the bit pattern of 1337.f; every output must then
// read back as 1337.f.
TEST(LLVM, VectorizeBitCast) {
  BufHandle a("A", {128}, kInt);

  Tensor c = Compute("c", {128}, [&](const VarHandle& i) {
    return bitcast<float>(a.load(i));
  });
  BufHandle c_buf(c.buf());

  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();
  ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(root)->front())));
  ASSERT_TRUE(to<For>(to<Block>(root)->front()) == nullptr);

  LLVMCodeGen cg(root, {a, c_buf});

  std::vector<int> input(128, raw_bitcast<int>(1337.f));
  std::vector<float> output(128);
  std::vector<void*> args = {input.data(), output.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(output, 1337.f);
}
628 
// Element-by-element copy loop: B[i] = A[i] for 32 ints.
TEST(LLVM, MemcpyTest) {
  constexpr int N = 32;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);

  VarHandle i("i", kInt);
  auto copy_loop = For::make(i, 0, N, b.store({i}, a.load(i)));
  LLVMCodeGen cg(copy_loop, {a, b});

  std::vector<int32_t> src(N, 42);
  std::vector<int32_t> dst(N, 0);
  std::vector<void*> args = {src.data(), dst.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(src.size(), N);
  ASSERT_EQ(dst.size(), N);
  assertAllEqual(src, 42);
  assertAllEqual(dst, 42);
}
649 
// Zero-fill loop: B[i] = 0 for 32 ints that start out as 11.
TEST(LLVM, BzeroTest) {
  constexpr int N = 32;
  BufHandle b("B", {N}, kInt);

  VarHandle i("i", kInt);
  auto zero_loop = For::make(i, 0, N, b.store({i}, 0));
  LLVMCodeGen cg(zero_loop, {b});

  std::vector<int32_t> data(N, 11);
  std::vector<void*> args = {data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(data.size(), N);
  assertAllEqual(data, 0);
}
666 
// C[i] = A[i] + B[i] over 1024 ints (41 + 1 == 42); inputs stay unchanged.
TEST(LLVM, ElemwiseAdd) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto sum = Add::make(a.load(i), b.load(i));
  auto add_loop = For::make(i, 0, N, c.store({i}, sum));
  LLVMCodeGen cg(add_loop, {a, b, c});

  std::vector<int32_t> a_data(N, 41);
  std::vector<int32_t> b_data(N, 1);
  std::vector<int32_t> c_data(N, 1);
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  ASSERT_EQ(c_data.size(), N);
  assertAllEqual(a_data, 41);
  assertAllEqual(b_data, 1);
  assertAllEqual(c_data, 42);
}
691 
// C[i] = A[i] + B[i] over 1024 floats (41 + 1 == 42); inputs stay unchanged.
TEST(LLVM, ElemwiseAddFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto add_loop = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i)));
  LLVMCodeGen cg(add_loop, {a, b, c});

  std::vector<float> a_data(N, 41);
  std::vector<float> b_data(N, 1);
  std::vector<float> c_data(N, 1);
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  ASSERT_EQ(c_data.size(), N);
  assertAllEqual(a_data, 41.0f);
  assertAllEqual(b_data, 1.0f);
  assertAllEqual(c_data, 42.0f);
}
716 
// B = log10(A) over 1024 floats, processed as N/4 iterations of 4-lane
// vectors. log10(10) must come out as exactly 1.
TEST(LLVM, ElemwiseLog10Float) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);

  VarHandle i("i", kInt);
  auto vec_input = a.load({Ramp::make(i * 4, 1, 4)});
  auto vec_store = b.store({Ramp::make(i * 4, 1, 4)}, log10(vec_input));
  auto log_loop = For::make(i, 0, N / 4, vec_store);
  LLVMCodeGen cg(log_loop, {a, b});

  std::vector<float> a_data(N, 10.0f);
  std::vector<float> b_data(N, 2.0f);
  std::vector<void*> args = {a_data.data(), b_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  assertAllEqual(a_data, 10.0f);
  assertAllEqual(b_data, 1.0f);
}
742 
// B = log1p(A) over 1024 floats, processed as N/4 iterations of 4-lane
// vectors. The input is e^3 - 1, so the result must be 3 within 1e-5.
TEST(LLVM, ElemwiseLog1pFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);

  VarHandle i("i", kInt);
  auto vec_input = a.load({Ramp::make(i * 4, 1, 4)});
  auto vec_store = b.store({Ramp::make(i * 4, 1, 4)}, log1p(vec_input));
  auto log_loop = For::make(i, 0, N / 4, vec_store);
  LLVMCodeGen cg(log_loop, {a, b});

  std::vector<float> a_data(N, expf(3.0f) - 1);
  std::vector<float> b_data(N, 42.0f);
  std::vector<void*> args = {a_data.data(), b_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  assertAllEqual(a_data, expf(3.0f) - 1);
  ExpectAllNear(b_data, 3.0f, 1e-5f);
}
768 
// C[i] = max(A[i], B[i]) over 1024 ints (max(41, 1) == 41).
TEST(LLVM, ElemwiseMaxInt) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto larger = Max::make(a.load(i), b.load(i), false);
  auto max_loop = For::make(i, 0, N, c.store({i}, larger));
  LLVMCodeGen cg(max_loop, {a, b, c});

  std::vector<int> a_data(N, 41);
  std::vector<int> b_data(N, 1);
  std::vector<int> c_data(N, 1);
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  ASSERT_EQ(c_data.size(), N);
  assertAllEqual(a_data, 41);
  assertAllEqual(b_data, 1);
  assertAllEqual(c_data, 41);
}
794 
// C[i] = min(A[i], B[i]) over 1024 ints (min(41, 1) == 1); inputs must stay
// unchanged.
TEST(LLVM, ElemwiseMinInt) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 41);
  std::vector<int> b_buffer(N, 1);
  // Seed the output with a value that differs from the expected result (1).
  // It used to be pre-filled with 1, so the final check passed even if the
  // kernel never wrote to C.
  std::vector<int> c_buffer(N, 0);

  VarHandle i("i", kInt);
  auto expr =
      For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41);
  assertAllEqual(b_buffer, 1);
  assertAllEqual(c_buffer, 1);
}
820 
// C[i] = max(A[i], B[i]) over 1024 floats (max(41, 1) == 41).
TEST(LLVM, ElemwiseMaxFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto larger = Max::make(a.load(i), b.load(i), false);
  auto max_loop = For::make(i, 0, N, c.store({i}, larger));
  LLVMCodeGen cg(max_loop, {a, b, c});

  std::vector<float> a_data(N, 41);
  std::vector<float> b_data(N, 1);
  std::vector<float> c_data(N, 1);
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  ASSERT_EQ(c_data.size(), N);
  assertAllEqual(a_data, 41.0f);
  assertAllEqual(b_data, 1.0f);
  assertAllEqual(c_data, 41.0f);
}
846 
// C[i] = max(A[i], B[i]) where A is all NaN: the test expects NaN to win,
// i.e. every output element is NaN.
TEST(LLVM, ElemwiseMaxNaNFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto larger = Max::make(a.load(i), b.load(i), false);
  auto max_loop = For::make(i, 0, N, c.store({i}, larger));
  LLVMCodeGen cg(max_loop, {a, b, c});

  std::vector<float> a_data(N, NAN);
  std::vector<float> b_data(N, 1);
  std::vector<float> c_data(N, 1);
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  ASSERT_EQ(c_data.size(), N);
  assertAllEqual(b_data, 1.0f);
  for (auto const& result : c_data) {
    ASSERT_TRUE(std::isnan(result));
  }
}
873 
// C[i] = min(A[i], B[i]) over 1024 floats (min(41, 1) == 1); inputs must stay
// unchanged.
TEST(LLVM, ElemwiseMinFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> a_buffer(N, 41);
  std::vector<float> b_buffer(N, 1);
  // Seed the output with a value that differs from the expected result (1).
  // It used to be pre-filled with 1, so the final check passed even if the
  // kernel never wrote to C.
  std::vector<float> c_buffer(N, 0);

  VarHandle i("i", kInt);
  auto expr =
      For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41.0f);
  assertAllEqual(b_buffer, 1.0f);
  assertAllEqual(c_buffer, 1.0f);
}
899 
// C[i] = min(A[i], B[i]) where A is all NaN: the test expects NaN to win,
// i.e. every output element is NaN.
TEST(LLVM, ElemwiseMinNaNFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto smaller = Min::make(a.load(i), b.load(i), false);
  auto min_loop = For::make(i, 0, N, c.store({i}, smaller));
  LLVMCodeGen cg(min_loop, {a, b, c});

  std::vector<float> a_data(N, NAN);
  std::vector<float> b_data(N, 1);
  std::vector<float> c_data(N, 1);
  std::vector<void*> args = {a_data.data(), b_data.data(), c_data.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_data.size(), N);
  ASSERT_EQ(b_data.size(), N);
  ASSERT_EQ(c_data.size(), N);
  assertAllEqual(b_data, 1.0f);
  for (auto const& result : c_data) {
    ASSERT_TRUE(std::isnan(result));
  }
}
926 
// Element-wise integer Mod: the kernel stores A[i] % B[i] into C[i] for every
// i in [0, N). With A == 41 and B == 23 the expected result is 18.
TEST(LLVM, ElemwiseMod) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int32_t> a_buffer(N, 41);
  std::vector<int32_t> b_buffer(N, 23);
  // Seed the output with a value different from the expected remainder (18),
  // so the final check fails if the kernel never writes to C.
  std::vector<int32_t> c_buffer(N, 0);

  VarHandle i("i", kInt);
  auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i))));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  // Inputs must be untouched; the output must hold 41 % 23.
  assertAllEqual(a_buffer, 41);
  assertAllEqual(b_buffer, 23);
  assertAllEqual(c_buffer, 18);
}
951 
// CompareSelect with kEQ on int buffers. The test expects C[i] == 1 where
// A[i] == B[i] and C[i] == 0 elsewhere.
TEST(LLVM, CompareSelectIntEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);

  // A is all ones; the first half of B is zeroed, so only the second half
  // compares equal.
  std::vector<int> lhs(N, 1);
  std::vector<int> rhs(N, 1);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 1);
  for (const auto idx : c10::irange(N / 2)) {
    rhs[idx] = 0;
    expected[idx] = 0;
  }

  VarHandle i("i", kInt);
  auto is_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, c.store({i}, is_equal));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args({lhs.data(), rhs.data(), out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(lhs, 1);
  for (int idx = 0; idx < N; ++idx) {
    ASSERT_EQ(expected[idx], out[idx]);
  }
}
991 
// CompareSelect with kEQ on float inputs and an int output. Both inputs are
// all 1.0f, so the test expects every output element to be 1.
TEST(LLVM, CompareSelectFloatEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto is_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, c.store({i}, is_equal));
  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<float> lhs(N, 1.0f);
  std::vector<float> rhs(N, 1.0f);
  std::vector<int> out(N, 0);
  std::vector<void*> args({lhs.data(), rhs.data(), out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  // Inputs are unchanged and every comparison came out true.
  assertAllEqual(lhs, 1.0f);
  assertAllEqual(rhs, 1.0f);
  assertAllEqual(out, 1);
}
1024 
// CompareSelect with kGT on uint8_t inputs. The first half of A is 128 and B
// is all zeros, so the test expects 1 in the first half of C and 0 after it.
// NOTE(review): 128 is presumably chosen because it would be negative if the
// comparison were done as signed 8-bit - confirm.
TEST(LLVM, CompareSelectByteGT) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  std::vector<uint8_t> lhs(N, 0);
  std::vector<uint8_t> rhs(N, 0);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 0);
  for (const auto idx : c10::irange(N / 2)) {
    lhs[idx] = 128;
    expected[idx] = 1;
  }

  VarHandle i("i", kInt);
  auto is_greater =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kGT);
  auto loop = For::make(i, 0, N, c.store({i}, is_greater));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args({lhs.data(), rhs.data(), out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(rhs, uint8_t(0));
  for (int idx = 0; idx < N; ++idx) {
    ASSERT_EQ(expected[idx], out[idx]);
  }
}
1064 
// CompareSelect with kGE on uint8_t inputs. A and B are both all zeros, so
// the test expects every output element to be 1.
TEST(LLVM, CompareSelectByteGE) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto is_greater_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kGE);
  auto loop = For::make(i, 0, N, c.store({i}, is_greater_equal));
  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<uint8_t> lhs(N, 0);
  std::vector<uint8_t> rhs(N, 0);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 1);
  std::vector<void*> args({lhs.data(), rhs.data(), out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(rhs, uint8_t(0));
  for (int idx = 0; idx < N; ++idx) {
    ASSERT_EQ(expected[idx], out[idx]);
  }
}
1099 
// CompareSelect with kLT on uint8_t inputs. B is all 128; the first half of A
// is 128 (not less than B) and the second half is 0 (less than B), so the
// test expects 0 in the first half of C and 1 in the second half.
TEST(LLVM, CompareSelectByteLT) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  std::vector<uint8_t> lhs(N, 0);
  std::vector<uint8_t> rhs(N, 128);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 1);
  for (const auto idx : c10::irange(N / 2)) {
    lhs[idx] = 128;
    expected[idx] = 0;
  }

  VarHandle i("i", kInt);
  auto is_less =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kLT);
  auto loop = For::make(i, 0, N, c.store({i}, is_less));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args({lhs.data(), rhs.data(), out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(rhs, uint8_t(128));
  for (int idx = 0; idx < N; ++idx) {
    ASSERT_EQ(expected[idx], out[idx]);
  }
}
1139 
// CompareSelect with kLE on uint8_t inputs. A is all 0 and B is all 128, so
// the test expects every output element to be 1.
TEST(LLVM, CompareSelectByteLE) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto is_less_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kLE);
  auto loop = For::make(i, 0, N, c.store({i}, is_less_equal));
  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<uint8_t> lhs(N, 0);
  std::vector<uint8_t> rhs(N, 128);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 1);
  std::vector<void*> args({lhs.data(), rhs.data(), out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(rhs, uint8_t(128));
  for (int idx = 0; idx < N; ++idx) {
    ASSERT_EQ(expected[idx], out[idx]);
  }
}
1174 
// Stores a float immediate into a one-element buffer and reads it back.
TEST(LLVM, StoreFloat) {
  BufHandle result("result", {1}, kFloat);
  auto store_stmt = result.store({0}, FloatImm::make(3.14f));
  LLVMCodeGen cg(store_stmt, {result});

  std::vector<float> out = {0.0f};
  std::vector<void*> args({out.data()});
  ASSERT_EQ(cg.value<int>(args), 0);
  // Same literal on both sides, so an exact comparison is intended.
  ASSERT_EQ(out[0], 3.14f);
}
1184 
// Computes f[i] = float(i * i + 1) through a Compute tensor and compares the
// kernel output against a host-side reference.
TEST(LLVM, SimpleMath01) {
  const int N = 1024;
  Tensor tensor = Compute(
      "f", {N}, [](const VarHandle& i) { return cast<float>(i * i + 1); });
  LoopNest nest({tensor});
  LLVMCodeGen cg(nest.root_stmt(), {BufHandle(tensor.buf())});

  PaddedBuffer<float> f_v(N, "f_v");
  std::vector<void*> args({f_v.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  // Host reference, computed with the same integer expression.
  PaddedBuffer<float> f_ref(N, "f_ref");
  for (int idx = 0; idx < N; ++idx) {
    f_ref(idx) = idx * idx + 1;
  }
  ExpectAllNear(f_v, f_ref, 1e-5);
}
1204 
// Element-wise product of two float buffers expressed as a Compute tensor:
// c[i] = a[i] * b[i]. With a == 21 and b == 2 every output must be 42.
TEST(LLVM, ComputeMul) {
  const int N = 1024;
  BufHandle a("a", {N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute(
      "c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); });

  LoopNest nest({c});
  LLVMCodeGen cg(nest.root_stmt(), {a, b, BufHandle(c.buf())});

  std::vector<float> lhs(N, 21.0f);
  std::vector<float> rhs(N, 2.0f);
  std::vector<float> product(N, 0.0f);
  std::vector<void*> args({lhs.data(), rhs.data(), product.data()});
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(product, 42.0f);
}
1225 
// Adds a 1-D buffer b[N] to every row of a 2-D buffer a[M, N]:
// c[i, j] = a[i, j] + b[j].
TEST(LLVM, BroadcastAdd) {
  const int M = 32;
  const int N = 1024;
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
    return a.load(i, j) + b.load(j);
  });

  LoopNest nest({c});
  nest.prepareForCodegen();
  LLVMCodeGen cg(nest.root_stmt(), {a, b, BufHandle(c.buf())});

  // Inputs are 0, 1, 2, ... so every element is distinct.
  std::vector<float> av(M * N);
  std::vector<float> bv(N);
  std::iota(av.begin(), av.end(), 0);
  std::iota(bv.begin(), bv.end(), 0);
  std::vector<float> cv(M * N, 0);
  std::vector<void*> args({av.data(), bv.data(), cv.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      const int flat = row * N + col;
      ASSERT_EQ(cv[flat], av[flat] + bv[col]);
    }
  }
}
1256 
// Evaluates (((59 ^ (11 << 1)) & 101) >> 2) | 2 on int immediates.
// Step by step: 11 << 1 = 22, 59 ^ 22 = 45, 45 & 101 = 37, 37 >> 2 = 9,
// 9 | 2 = 11.
TEST(LLVM, BitwiseOps) {
  auto a = IntImm::make(59);
  auto b = IntImm::make(11);
  auto c = IntImm::make(101);
  auto d = IntImm::make(2);

  ExprHandle shifted = b << 1;
  ExprHandle mixed = a ^ shifted;
  ExprHandle masked = mixed & c;
  ExprHandle f = (masked >> 2) | d;

  LLVMExprEval cg(f);
  ASSERT_EQ(cg.value<int>(), 11);
}
1268 
// Right shift of a signed 8-bit value: the test expects -4 >> 1 == -2, i.e.
// the sign is preserved.
TEST(LLVM, ArithmeticRightShift) {
  auto value = CharImm::make(-4);
  auto amount = CharImm::make(1);
  LLVMExprEval cg(value >> amount);
  ASSERT_EQ(cg.value<int8_t>(), -2);
}
1276 
// Right shift of an unsigned 8-bit value: the test expects 0xfc >> 1 == 0x7e,
// i.e. a zero is shifted into the top bit.
TEST(LLVM, LogicalRightShift) {
  auto value = ByteImm::make(0xfc);
  auto amount = ByteImm::make(1);
  LLVMExprEval cg(value >> amount);
  ASSERT_EQ(cg.value<uint8_t>(), 0x7e);
}
1284 
// Element-wise add where the buffer extent is a runtime variable `n`. The
// extent is passed to the kernel as the last raw argument.
TEST(LLVM, DynamicShapeAdd) {
  auto testWithSize = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    LLVMCodeGen cg(s, {a, b, c, n});
    std::vector<void*> args({aData.data(), bData.data(), cData.data(), &size});
    // Call the kernel as int-returning and check the status, matching the
    // other statement-kernel tests in this file (previously the result was
    // requested as float and discarded).
    ASSERT_EQ(cg.value<int>(args), 0);
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };
  testWithSize(1);
  testWithSize(16);
  testWithSize(37);
}
1305 
// Same dynamic-extent add as a raw For loop, but invoked through
// LLVMCodeGen::call with the vectors and the extent passed directly.
TEST(LLVM, BindDynamicShapeAdd) {
  auto runCase = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    auto sum = a.load(i) + b.load(i);
    StmtPtr loop = For::make(i, 0, n, c.store({i}, sum));
    LLVMCodeGen cg(loop, {a, b, c, n});

    std::vector<float> lhs(size, 1.0f);
    std::vector<float> rhs(size, 2.0f);
    std::vector<float> out(size, 0.0f);
    cg.call({lhs, rhs, out, size});
    ExpectAllNear(out, std::vector<float>(size, 3.0f), 1e-7);
  };
  for (int32_t size : {1, 16, 37}) {
    runCase(size);
  }
}
1325 
// Dynamic-extent add expressed as a Compute tensor over a runtime variable
// `n`, run for several sizes.
TEST(LLVM, TensorDynamicShapeAdd) {
  auto runCase = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    Tensor c = Compute(
        "c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); });
    LoopNest nest({c});
    LLVMCodeGen cg(nest.root_stmt(), {a, b, c, n});

    std::vector<float> lhs(size, 1.0f);
    std::vector<float> rhs(size, 2.0f);
    std::vector<float> out(size, 0.0f);
    cg.call({lhs, rhs, out, size});
    ExpectAllNear(out, std::vector<float>(size, 3.0f), 1e-7);
  };
  for (int32_t size : {1, 16, 37}) {
    runCase(size);
  }
}
1346 
// Two-dimensional add where both extents (m, n) are runtime variables passed
// to the kernel after the buffers.
TEST(LLVM, DynamicShape2D) {
  auto runCase = [](int32_t M, int32_t N) {
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    BufHandle a("a", {m, n}, kFloat);
    BufHandle b("b", {m, n}, kFloat);
    Tensor c =
        Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {
          return a.load(i, j) + b.load(i, j);
        });
    LoopNest nest({c});
    nest.prepareForCodegen();
    LLVMCodeGen cg(nest.root_stmt(), {a, b, c, m, n});

    std::vector<float> lhs(M * N, 1.0f);
    std::vector<float> rhs(M * N, 2.0f);
    std::vector<float> out(M * N, 0.0f);
    cg.call({lhs, rhs, out, M, N});
    ExpectAllNear(out, std::vector<float>(M * N, 3.0f), 1e-7);
  };
  runCase(1, 8);
  runCase(16, 32);
  runCase(37, 11);
}
1371 
// Compiles and calls a kernel whose body is an empty Block. There is nothing
// to verify beyond the fact that neither step crashes.
TEST(LLVM, EmptyStmt) {
  std::vector<StmtPtr> no_stmts;
  StmtPtr empty_block = alloc<Block>(no_stmts);

  LLVMCodeGen cg(empty_block, {});
  cg.call({});
}
1379 
// Compiles a Compute tensor with a zero-sized dimension after running it
// through the IR simplifier, then calls the kernel with an empty output.
// There are no value assertions; the test only requires that this works.
TEST(LLVM, EliminatedStmt) {
  BufHandle a("a", {1}, kFloat);
  Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; });

  LoopNest nest({c});
  nest.prepareForCodegen();
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(simplified, {a, c});

  std::vector<float> a_data(1, 1.0f);
  std::vector<float> c_data(0, 0.0f);
  cg.call({a_data, c_data});
}
1394 
// Sum-reduces a [1, M, N] float buffer over its last two dimensions and
// compares the single result against a host-side sum.
TEST(LLVM, SimpleReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});

  LoopNest nest({b});
  nest.prepareForCodegen();
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(simplified, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  // Fill the input with i + j and accumulate the reference in the same pass.
  b_ref(0) = 0;
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      int v = row + col;
      a_v(0, row, col) = v;
      b_ref(0) += v;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1427 
// Same sum reduction as SimpleReduction, but the two reduction loops are
// swapped and the reduction is rfactor'ed on the loop at index 1 before
// code generation.
TEST(LLVM, RFactorReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
  LoopNest nest({b});

  // Swap the loops at positions 1 and 2.
  {
    const std::vector<ForPtr> before = nest.getLoopStmtsFor(b);
    nest.reorderAxis(before.at(1), before.at(2));
  }

  // Re-query the loops after the reorder and rfactor the second write to
  // b's buffer on the loop now at position 1.
  const std::vector<ForPtr> after = nest.getLoopStmtsFor(b);
  ForPtr rfactor_loop = after.at(1);
  auto reduce_store = nest.getAllWritesToBuf(b.buf())[1];
  ASSERT_TRUE(nest.rfactor(reduce_store, rfactor_loop));

  nest.prepareForCodegen();
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(simplified, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  // Fill the input with i + j and accumulate the reference in the same pass.
  b_ref(0) = 0;
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      int v = row + col;
      a_v(0, row, col) = v;
      b_ref(0) += v;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1471 
// Sum reduction that is reordered, rfactor'ed, distributed, and then has the
// two distributed loops vectorized before code generation.
TEST(LLVM, RFactorVectorizedReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
  LoopNest nest({b});

  // Reorder n and m loops
  {
    std::vector<ForPtr> axes = nest.getLoopStmtsFor(b);
    nest.reorderAxis(axes.at(1), axes.at(2));
  }

  auto reduce_store = nest.getAllWritesToBuf(b.buf()).at(1);
  auto writers = nest.getAllLoopNestsWritingToBuf(b.buf());
  // Expect two loop nests writing to b, the second one three loops deep.
  ASSERT_TRUE(writers.size() == 2 && writers[1].size() == 3);
  ASSERT_TRUE(nest.rfactor(reduce_store, writers[1][1]));
  auto split_loops = nest.distributeLoop(writers[1][1]);

  // Vectorize initializer of rfac_buf
  ASSERT_TRUE(LoopNest::vectorize(split_loops[0]));
  // Vectorize producer of rfac_buf
  ASSERT_TRUE(LoopNest::vectorize(split_loops[1]));
  nest.simplify();

  nest.prepareForCodegen();

  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(simplified, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  // Fill the input with i + j and accumulate the reference in the same pass.
  b_ref(0) = 0;
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      int v = row + col;
      a_v(0, row, col) = v;
      b_ref(0) += v;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1517 
1518 template <bool outer, bool inner>
testSimpleParallel()1519 static void testSimpleParallel() {
1520   // Compute a simple operation, and try all loop-axis combination to be
1521   // parallel or sequential.
1522   const int M = 4;
1523   const int N = 6;
1524   Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) {
1525     return cast<float>(m + n);
1526   });
1527   LoopNest loop_nest({f});
1528   auto const& loops = loop_nest.getLoopStmtsFor(f);
1529   ForPtr m = loops[0];
1530   ForPtr n = loops[1];
1531   if (outer) {
1532     m->set_parallel();
1533   }
1534   if (inner) {
1535     n->set_parallel();
1536   }
1537   loop_nest.prepareForCodegen();
1538   StmtPtr stmt = loop_nest.root_stmt();
1539   LLVMCodeGen cg(stmt, {f});
1540 
1541   PaddedBuffer<float> f_v(M, N, "f_v");
1542   std::vector<void*> args({f_v.data()});
1543   int value = cg.value<int>(args);
1544   ASSERT_EQ(value, 0);
1545   PaddedBuffer<float> f_ref(M, N, "f_ref");
1546   for (const auto m : c10::irange(M)) {
1547     for (const auto n : c10::irange(N)) {
1548       f_ref(m, n) = m + n;
1549     }
1550   }
1551   ExpectAllNear(f_v, f_ref, 1e-5);
1552 }
1553 
// Neither loop is marked parallel.
TEST(LLVM, SimpleParallelSS) {
  testSimpleParallel</*outer=*/false, /*inner=*/false>();
}
// Only the inner loop is marked parallel.
TEST(LLVM, SimpleParallelSP) {
  testSimpleParallel</*outer=*/false, /*inner=*/true>();
}
// Only the outer loop is marked parallel.
TEST(LLVM, SimpleParallelPS) {
  testSimpleParallel</*outer=*/true, /*inner=*/false>();
}
// Both loops are marked parallel.
TEST(LLVM, SimpleParallelPP) {
  testSimpleParallel</*outer=*/true, /*inner=*/true>();
}
1566 
TEST(LLVM,CompositeParallel)1567 TEST(LLVM, CompositeParallel) {
1568   int loop_count = 6;
1569   int test_count = 1 << loop_count;
1570   // Compute a composite operation, and try all loop-axis combination to be
1571   // parallel or sequential.
1572   for (const auto test_cfg : c10::irange(test_count)) {
1573     int M = 5;
1574     int N = 7;
1575     Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; });
1576     Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; });
1577     Tensor t3 =
1578         Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
1579           return t1.load(m) * t2.load(n);
1580         });
1581     Tensor t4 =
1582         Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
1583           return t3.load(m, n) + m + n;
1584         });
1585     LoopNest loop_nest({t4}, {t1, t2, t3, t4});
1586     std::vector<ForPtr> loop_list;
1587     {
1588       auto const& loops = loop_nest.getLoopStmtsFor(t1);
1589       loop_list.push_back(loops[0]);
1590     }
1591     {
1592       auto const& loops = loop_nest.getLoopStmtsFor(t2);
1593       loop_list.push_back(loops[0]);
1594     }
1595     {
1596       auto const& loops = loop_nest.getLoopStmtsFor(t3);
1597       loop_list.push_back(loops[0]);
1598       loop_list.push_back(loops[1]);
1599     }
1600     {
1601       auto const& loops = loop_nest.getLoopStmtsFor(t4);
1602       loop_list.push_back(loops[0]);
1603       loop_list.push_back(loops[1]);
1604     }
1605     ASSERT_EQ(loop_list.size(), loop_count);
1606     for (const auto i : c10::irange(loop_count)) {
1607       if (test_cfg & (1 << i)) {
1608         loop_list[i]->set_parallel();
1609       }
1610     }
1611     loop_nest.prepareForCodegen();
1612     StmtPtr stmt = loop_nest.root_stmt();
1613     LLVMCodeGen cg(stmt, {t4});
1614 
1615     PaddedBuffer<float> t4_v(M, N, "t4_v");
1616     std::vector<void*> args({t4_v.data()});
1617     int value = cg.value<int>(args);
1618     ASSERT_EQ(value, 0);
1619     PaddedBuffer<float> t4_ref(M, N, "t4_ref");
1620     for (const auto m : c10::irange(M)) {
1621       for (const auto n : c10::irange(N)) {
1622         t4_ref(m, n) = (m + 1) * (n + 2) + m + n;
1623       }
1624     }
1625     ExpectAllNear(t4_v, t4_ref, 1e-5);
1626   }
1627 }
1628 
TEST(LLVM,VectorizedGEMM)1629 TEST(LLVM, VectorizedGEMM) {
1630   int M = 32;
1631   int N = 32;
1632   int K = 48;
1633 
1634   BufHandle AP("A", {M, K}, kFloat);
1635   BufHandle BP("B", {K, N}, kFloat);
1636   Tensor CT = Reduce(
1637       "gemm",
1638       {M, N},
1639       Sum(),
1640       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
1641         return AP.load(m, k) * BP.load(k, n);
1642       },
1643       {K});
1644   LoopNest loop({CT});
1645 
1646   {
1647     auto const& loops = loop.getLoopStmtsFor(CT);
1648     ForPtr m = loops[0];
1649     loop.splitWithMask(m, 16);
1650   }
1651   {
1652     auto const& loops = loop.getLoopStmtsFor(CT);
1653     ForPtr n = loops[2];
1654     loop.splitWithMask(n, 16);
1655   }
1656   // mo, mi, no, ni, k ->
1657   // mo, no, mi, ni, k
1658   {
1659     auto const& loops = loop.getLoopStmtsFor(CT);
1660     ForPtr mi = loops[1];
1661     ForPtr no = loops[2];
1662     loop.reorderAxis(mi, no);
1663   }
1664   // mo, no, mi, ni, k ->
1665   // mo, no, mi, k, ni
1666   {
1667     auto const& loops = loop.getLoopStmtsFor(CT);
1668     ForPtr ni = loops[3];
1669     ForPtr k = loops[4];
1670     loop.reorderAxis(ni, k);
1671   }
1672   // mo, no, mi, k, ni ->
1673   // mo, no, k, mi, ni
1674   {
1675     auto const& loops = loop.getLoopStmtsFor(CT);
1676     ForPtr mi = loops[2];
1677     ForPtr k = loops[3];
1678     loop.reorderAxis(mi, k);
1679   }
1680   {
1681     auto loops = NodeFinder<For>::find(loop.root_stmt());
1682     ASSERT_TRUE(LoopNest::vectorize(loops[3]));
1683     ASSERT_TRUE(LoopNest::vectorize(loops.back()));
1684   }
1685 
1686   loop.prepareForCodegen();
1687 
1688   StmtPtr s = loop.root_stmt();
1689   s = IRSimplifier::simplify(s);
1690   LLVMCodeGen cg(s, {AP, BP, CT});
1691 
1692   PaddedBuffer<float> a_v(M, K, "a_v");
1693   PaddedBuffer<float> b_v(K, N, "b_v");
1694   PaddedBuffer<float> c_v(M, N, "c_v");
1695   PaddedBuffer<float> c_ref(M, N, "c_ref");
1696 
1697   for (const auto m : c10::irange(M)) {
1698     for (const auto n : c10::irange(N)) {
1699       c_ref(m, n) = 0.f;
1700       for (const auto k : c10::irange(K)) {
1701         c_ref(m, n) += a_v(m, k) * b_v(k, n);
1702       }
1703     }
1704   }
1705 
1706   cg.call({a_v, b_v, c_v});
1707 
1708   ExpectAllNear(c_v, c_ref, 1e-5);
1709 }
1710 
// Runs the same broadcast-add statement through call_raw on both LLVMCodeGen
// and SimpleIREvaluator, passing the dynamic extent N as the last raw
// argument, and verifies the output after each run.
TEST(LLVM, CallRaw) {
  const int M = 32;
  VarHandle N("N", kInt);
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
    return a.load(i, j) + b.load(j);
  });

  LoopNest l({c});
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  int32_t N_value = 1024;
  std::vector<float> av(M * N_value);
  std::iota(av.begin(), av.end(), 0);
  std::vector<float> bv(N_value);
  std::iota(bv.begin(), bv.end(), 0);
  std::vector<float> cv(M * N_value, 0);
  std::vector<void*> args({av.data(), bv.data(), cv.data(), &N_value});

  LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N});
  cg.call_raw(args);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N_value)) {
      ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);
    }
  }

  // Clear the output left behind by the LLVM run; otherwise the check below
  // would pass even if the evaluator wrote nothing. assign() keeps the
  // storage, so the pointer already stored in `args` stays valid.
  cv.assign(cv.size(), 0.0f);

  SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N});
  eval.call_raw(args);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N_value)) {
      ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);
    }
  }
}
1750 
// Builds a multiply-add kernel for an explicit target (triple "i686-elf",
// cpu "i386") and inspects the generated assembly text: it must contain
// "fadds" followed by "fmuls" and no "vfmadd" after that.
TEST(LLVM, CustomTarget) {
  constexpr int M = 16;
  BufHandle a("a", {M}, kFloat);
  BufHandle b("b", {M}, kFloat);
  BufHandle c("c", {M}, kFloat);
  Tensor d = Compute("d", {M}, [&](const VarHandle& m) {
    return a.load(m) * b.load(m) + c.load(m);
  });
  LoopNest nest({d});
  nest.prepareForCodegen();

  LLVMCodeGenBuilder builder(nest.root_stmt(), {a, b, c, d});
  auto cg = builder.triple("i686-elf").cpu("i386").build();

  std::ostringstream asm_text;
  asm_text << cg->getCodeText("asm");

  torch::jit::testing::FileCheck checker;
  checker.check("fadds")->check("fmuls")->check_not("vfmadd");
  checker.run(asm_text.str());
}
1773 
// Checks the kernel function name reported by LLVMCodeGen: non-empty when no
// name is supplied, and equal to the supplied name otherwise. The kernel is
// never executed, so no host buffers are needed.
TEST(LLVM, CodeGenKernelFuncName) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  auto store = b.store({0}, a.load(0));

  {
    LLVMCodeGen cg(store, {a, b});
    // Check that the kernel function name used by LLVMCodeGen
    // is not empty.
    ASSERT_NE(cg.kernel_func_name(), "");
  }

  {
    LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func");
    // Check that the kernel function name used by LLVMCodeGen
    // is the one that was given above.
    ASSERT_EQ(cg.kernel_func_name(), "new_func");
  }
}
1795 
1796 } // namespace jit
1797 } // namespace torch
1798 
1799 #endif // TORCH_ENABLE_LLVM
1800