1 #ifdef TORCH_ENABLE_LLVM
2 #include <gtest/gtest.h>
3
4 #include <test/cpp/tensorexpr/test_base.h>
5
6 #include <c10/util/irange.h>
7 #include <test/cpp/tensorexpr/padded_buffer.h>
8 #include <test/cpp/tensorexpr/test_utils.h>
9 #include <torch/csrc/jit/tensorexpr/eval.h>
10 #include <torch/csrc/jit/tensorexpr/ir.h>
11 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
12 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
13 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
14 #include <torch/csrc/jit/tensorexpr/loopnest.h>
15 #include <torch/csrc/jit/tensorexpr/tensor.h>
16 #include <torch/csrc/jit/testing/file_check.h>
17
18 #include <cmath>
19 #include <numeric>
20
21 namespace torch {
22 namespace jit {
23 using namespace torch::jit::tensorexpr;
24
25 using LLVMExprEval = ExprEval<LLVMCodeGen>;
26
// X-macro table of (C++ type, scalar-type name, sample value) rows. The
// per-type tests below are stamped out from it; gtest parameterized tests
// can't be used here due to the way the tests are instantiated.
#define TEST_LLVM_SCALAR_TYPES(_)    \
  _(uint8_t, Byte, 24)               \
  _(int8_t, Char, -20)               \
  _(int16_t, Short, 3332)            \
  _(int, Int, 123456)                \
  _(int64_t, Long, 2631563121321)    \
  _(float, Float, 0.122)             \
  _(double, Double, 0.21312)         \
  _(at::Half, Half, 0.128f)
37
38 #define IMM_TEST(Type, Name, Val) \
39 TEST(LLVM, Name##ImmTest) { \
40 auto a = Name##Imm::make(Val); \
41 LLVMExprEval cg(a); \
42 if (std::is_floating_point<decltype(Val)>()) { \
43 ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \
44 } else { \
45 ASSERT_EQ(cg.value<Type>(), Val); \
46 } \
47 }
48 TEST_LLVM_SCALAR_TYPES(IMM_TEST)
49 #undef IMM_TEST
50
// One test per scalar type: Val + 2 * Val evaluated through LLVM must equal
// 3 * Val (within 0.1 for floating-point samples).
#define ADD_TEST(Type, Name, Val)                                 \
  TEST(LLVM, Name##AddTest) {                                     \
    auto lhs = Name##Imm::make(Val);                              \
    auto rhs = Name##Imm::make(Val * 2);                          \
    auto sum = Add::make(lhs, rhs);                               \
    LLVMExprEval cg(sum);                                         \
    const auto result = cg.value<Type>();                         \
    if (std::is_floating_point<decltype(Val)>::value) {           \
      ASSERT_NEAR(result, Val * 3, 0.1);                          \
    } else {                                                      \
      ASSERT_EQ(result, Val * 3);                                 \
    }                                                             \
  }
TEST_LLVM_SCALAR_TYPES(ADD_TEST)
#undef ADD_TEST
65
66 #define SUB_TEST(Type, Name, Val) \
67 TEST(LLVM, Name##SubTest) { \
68 auto a = Name##Imm::make(Val * 2); \
69 auto b = Name##Imm::make(Val); \
70 auto c = Sub::make(a, b); \
71 LLVMExprEval cg(c); \
72 if (std::is_floating_point<decltype(Val)>()) { \
73 ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \
74 } else { \
75 ASSERT_EQ(cg.value<Type>(), Val); \
76 } \
77 }
78 TEST_LLVM_SCALAR_TYPES(SUB_TEST)
79 #undef SUB_TEST
80
81 #define MUL_TEST(Type, Name, Val) \
82 TEST(LLVM, Name##MulTest) { \
83 auto a = Name##Imm::make(Val); \
84 auto b = Name##Imm::make((Type)4); \
85 auto c = Mul::make(a, b); \
86 LLVMExprEval cg(c); \
87 if (std::is_floating_point<decltype(Val)>()) { \
88 ASSERT_NEAR(cg.value<Type>(), Val * 4, 0.1); \
89 } else { \
90 ASSERT_EQ(cg.value<Type>(), Val * 4); \
91 } \
92 }
93 TEST_LLVM_SCALAR_TYPES(MUL_TEST)
94 #undef MUL_TEST
95
96 #define DIV_TEST(Type, Name, Val) \
97 TEST(LLVM, Name##DivTest) { \
98 auto a = Name##Imm::make((Type)6); \
99 auto b = Name##Imm::make((Type)3); \
100 auto c = Div::make(a, b); \
101 LLVMExprEval cg(c); \
102 if (std::is_floating_point<decltype(Val)>()) { \
103 ASSERT_NEAR(cg.value<Type>(), 2, 0.1); \
104 } else { \
105 ASSERT_EQ(cg.value<Type>(), 2); \
106 } \
107 }
108 TEST_LLVM_SCALAR_TYPES(DIV_TEST)
109 #undef DIV_TEST
110
111 TEST(LLVM, IntToFloatCastTest) {
112 auto a = IntImm::make(2);
113 auto b = Cast::make(kFloat, a);
114 LLVMExprEval cg(b, {});
115 ASSERT_EQ(cg.value<float>(), 2.0);
116 }
117
// Casting the Float immediate 2.0 to Int must evaluate to 2.
TEST(LLVM, FloatToIntCastTest) {
  auto source = FloatImm::make(2.0);
  auto casted = Cast::make(kInt, source);
  LLVMExprEval cg(casted);
  const int result = cg.value<int>();
  ASSERT_EQ(result, 2);
}
124
// Widening the Int immediate 12345 to Long must keep its value.
TEST(LLVM, IntToLongCastTest) {
  auto source = IntImm::make(12345);
  auto casted = Cast::make(kLong, source);
  LLVMExprEval cg(casted);
  const int64_t result = cg.value<int64_t>();
  ASSERT_EQ(result, 12345);
}
131
// Casting the Byte immediate 250 to Char must agree with the host's
// conversion of 250 to int8_t.
TEST(LLVM, ByteToCharCastTest) {
  auto source = ByteImm::make(250);
  auto casted = Cast::make(kChar, source);
  LLVMExprEval cg(casted);
  const int8_t expected = static_cast<int8_t>(250);
  ASSERT_EQ(cg.value<int8_t>(), expected);
}
138
// Casting the Half immediate 2.0 to Long must evaluate to 2.
TEST(LLVM, HalfToLongCastTest) {
  auto source = HalfImm::make(2.0);
  auto casted = Cast::make(kLong, source);
  LLVMExprEval cg(casted);
  const int64_t result = cg.value<int64_t>();
  ASSERT_EQ(result, 2);
}
145
// Casting the Byte immediate 2 to Double must evaluate to 2.
TEST(LLVM, ByteToDoubleCastTest) {
  auto source = ByteImm::make(2);
  auto casted = Cast::make(kDouble, source);
  LLVMExprEval cg(casted);
  const double result = cg.value<double>();
  ASSERT_EQ(result, 2);
}
152
// Casting the Float immediate 254.0 to Byte must evaluate to 254 (a value
// above the signed 8-bit range).
TEST(LLVM, FloatToByteCastTest) {
  auto source = FloatImm::make(254.0);
  auto casted = Cast::make(kByte, source);
  LLVMExprEval cg(casted);
  const uint8_t result = cg.value<uint8_t>();
  ASSERT_EQ(result, 254);
}
159
// Casting the negative Float immediate -2.0 to Char must evaluate to -2.
TEST(LLVM, FloatToCharCastTest) {
  auto source = FloatImm::make(-2.0);
  auto casted = Cast::make(kChar, source);
  LLVMExprEval cg(casted);
  const int8_t result = cg.value<int8_t>();
  ASSERT_EQ(result, -2);
}
166
// Casting the Byte immediate 254 to Float must evaluate to 254.0.
TEST(LLVM, ByteToFloatCastTest) {
  auto source = ByteImm::make(254);
  auto casted = Cast::make(kFloat, source);
  LLVMExprEval cg(casted);
  const float result = cg.value<float>();
  ASSERT_EQ(result, 254.0);
}
173
// Casting the negative Char immediate -2 to Float must evaluate to -2.0.
TEST(LLVM, CharToFloatCastTest) {
  auto source = CharImm::make(-2);
  auto casted = Cast::make(kFloat, source);
  LLVMExprEval cg(casted);
  const float result = cg.value<float>();
  ASSERT_EQ(result, -2.0);
}
180
// Round-trips a bit pattern through BitCast for each same-width pair
// (float<->int32, double<->int64): the host reinterprets the reference value
// with raw_bitcast, the generated code bitcasts it back, and the result must
// equal the reference exactly.
TEST(LLVM, BitCast) {
  /* constexpr int16_t ref16 = 1337; */
  constexpr int32_t ref32 = 1337;
  constexpr int64_t ref64 = 1337;
  constexpr float reff32 = 1337.0f;
  constexpr double reff64 = 1337.0f;

  // this is broken
  /*{
    at::Half k_;
    at::Half* k = &k_;
    *reinterpret_cast<int16_t*>(k) = ref16;
    auto a = HalfImm::make(k);
    auto b = BitCast::make(kShort, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<int16_t>(), ref16);
  }*/

  // float -> int32
  {
    const float bits = raw_bitcast<float>(ref32);
    auto imm = FloatImm::make(bits);
    auto recast = BitCast::make(kInt, imm);
    LLVMExprEval cg(recast);
    ASSERT_EQ(cg.value<int32_t>(), ref32);
  }

  // double -> int64
  {
    const double bits = raw_bitcast<double>(ref64);
    auto imm = DoubleImm::make(bits);
    auto recast = BitCast::make(kLong, imm);
    LLVMExprEval cg(recast);
    ASSERT_EQ(cg.value<int64_t>(), ref64);
  }

  // int64 -> double
  {
    const int64_t bits = raw_bitcast<int64_t>(reff64);
    auto imm = LongImm::make(bits);
    auto recast = BitCast::make(kDouble, imm);
    LLVMExprEval cg(recast);
    ASSERT_EQ(cg.value<double>(), reff64);
  }

  // int32 -> float
  {
    const int32_t bits = raw_bitcast<int32_t>(reff32);
    auto imm = IntImm::make(bits);
    auto recast = BitCast::make(kFloat, imm);
    LLVMExprEval cg(recast);
    ASSERT_EQ(cg.value<float>(), reff32);
  }
}
231
// Applies fast_log element-wise to 128*128 random floats and compares each
// result with std::log. Where std::log yields NaN, the generated code must
// yield NaN as well.
TEST(LLVM, fastLogFloat) {
  const int kTotalSize = 128 * 128;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle index("index", kInt);
  StmtPtr store_b = b_buf.store({index}, fast_log(a_buf.load(index)));
  StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);

  for (const auto pos : c10::irange(kTotalSize)) {
    a_v(pos) = at::randn({1}).item().to<float>();
  }

  LLVMCodeGen ir_eval(stmt, {a_buf, b_buf});
  ir_eval.call({a_v, b_v});

  for (const auto pos : c10::irange(kTotalSize)) {
    const auto actual = b_v(pos);
    const auto expected = std::log(a_v(pos));
    if (!std::isnan(expected)) {
      ASSERT_FLOAT_EQ(actual, expected);
    } else {
      ASSERT_EQ(std::isnan(actual), true);
    }
  }
}
262
// Binds x = 3 with a Let and stores 2 + (x * 3 + 4) into A[0]; the stored
// value must match the same arithmetic done on the host.
TEST(LLVM, LetTest01) {
  BufHandle a("A", {1}, kFloat);
  std::vector<float> storage = {1, 0};
  std::vector<void*> args = {storage.data()};
  VarHandle x("x", kFloat);
  auto body = Block::make({
      Let::make(x, 3.f),
      a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))),
  });

  LLVMCodeGen cg(body, {a});
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(storage[0], 2.f + 3.f * 3.f + 4.f);
}
277
// Binds x = 3 and y = 6 with two Lets and stores 2 + (x * 3 + y * 4) into
// A[0]; the stored value must match the same arithmetic done on the host.
TEST(LLVM, LetTest02) {
  BufHandle a("A", {1}, kFloat);
  std::vector<float> storage = {1, 0};
  std::vector<void*> args = {storage.data()};
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  auto body = Block::make(
      {Let::make(x, 3.f),
       Let::make(y, 6.f),
       a.store(
           {IntImm::make(0)},
           ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))});

  LLVMCodeGen cg(body, {a});
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(storage[0], 2.f + 3.f * 3.f + 6.f * 4.f);
}
295
// Same shape as the other Let tests, but mixes types: x is a Byte variable,
// y a Half variable, the arithmetic uses float immediates, and the result is
// cast to Double before being stored into a Double buffer.
TEST(LLVM, LetTestMultitype) {
  BufHandle a("A", {1}, kDouble);
  std::vector<double> storage = {1, 0};
  std::vector<void*> args = {storage.data()};
  VarHandle x("x", kByte);
  VarHandle y("y", kHalf);
  auto body = Block::make(
      {Let::make(x, 3),
       Let::make(y, 6.f),
       a.store(
           {0},
           Cast::make(
               kDouble,
               ExprHandle(2.f) +
                   (x * ExprHandle(3.f) + y * ExprHandle(4.f))))});

  LLVMCodeGen cg(body, {a});
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(storage[0], 2.f + 3 * 3.f + 6.f * 4.f);
}
316
// Evaluates the constant 0 with a buffer argument bound but never accessed;
// the point is that passing a buffer pointer alongside an expression works.
TEST(LLVM, BufferTest) {
  BufHandle a("A", {32}, kFloat);
  std::vector<int32_t> storage(5);
  std::vector<void*> args = {storage.data()};
  auto zero = IntImm::make(0);
  LLVMExprEval cg(zero, {a});
  const int result = cg.value<int>(args);
  ASSERT_EQ(result, 0);
}
325
// Executes three stores in sequence (A[0]=3, A[1]=4, A[0]=4); the later store
// to A[0] must win, leaving both elements equal to 4.
TEST(LLVM, BlockTest) {
  BufHandle a("A", {32}, kInt);
  std::vector<int32_t> storage = {1, 2};
  std::vector<void*> args = {storage.data()};

  auto stores = Block::make({
      a.store({0}, 3),
      a.store({1}, 4),
      a.store({0}, 4),
  });

  LLVMCodeGen cg(stores, {a});
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(storage[0], 4);
  ASSERT_EQ(storage[1], 4);
}
342
// Copies A[0] into B[0]: afterwards B holds 42 and A is unchanged.
TEST(LLVM, LoadStoreTest) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  std::vector<int32_t> src = {42};
  std::vector<int32_t> dst = {-11};

  auto copy = b.store({0}, a.load(0));
  LLVMCodeGen cg(copy, {a, b});
  std::vector<void*> args = {src.data(), dst.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(src[0], 42);
  ASSERT_EQ(dst[0], 42);
}
356
// Stores IfThenElse(C[0], A[0], 0) into B[0]. With C[0] == 1 the true branch
// is selected, so B must receive A's value (42).
TEST(LLVM, IfThenElseTest) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  BufHandle c("C", {1}, kInt);
  std::vector<int32_t> true_value = {42};
  std::vector<int32_t> result = {-11};
  std::vector<int32_t> condition = {1};

  auto select = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0));
  LLVMCodeGen cg(select, {a, b, c});
  std::vector<void*> args = {
      true_value.data(), result.data(), condition.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);
  ASSERT_EQ(true_value[0], 42);
  ASSERT_EQ(result[0], 42);
}
372
373 // if (x < 10) x = x + 1
TEST(LLVM,CondNoFalseBlockTest)374 TEST(LLVM, CondNoFalseBlockTest) {
375 BufHandle x("X", {1}, kInt);
376 auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
377 auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr);
378
379 for (int32_t x_value : {0, 10, 20}) {
380 std::vector<int32_t> x_buffer = {x_value};
381 std::vector<void*> args({x_buffer.data()});
382 LLVMCodeGen cg(cond, {x});
383 ASSERT_EQ(cg.value<int>(args), 0);
384 if (x_value < 10) {
385 ASSERT_EQ(x_buffer[0], x_value + 1);
386 } else {
387 ASSERT_EQ(x_buffer[0], x_value);
388 }
389 }
390 }
391
392 // if (x < 10) {
393 // x = x + 1;
394 // } else {
395 // x = x - 1;
396 // }
TEST(LLVM,CondTest)397 TEST(LLVM, CondTest) {
398 BufHandle x("X", {1}, kInt);
399 auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
400 auto cond =
401 Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
402 auto block = Block::make({
403 cond,
404 x.store({0}, x.load(0) * 2),
405 });
406
407 for (int32_t x_value : {0, 10, 20}) {
408 std::vector<int32_t> x_buffer = {x_value};
409 std::vector<void*> args({x_buffer.data()});
410 LLVMCodeGen cg(block, {x});
411 ASSERT_EQ(cg.value<int>(args), 0);
412 if (x_value < 10) {
413 ASSERT_EQ(x_buffer[0], (x_value + 1) * 2);
414 } else {
415 ASSERT_EQ(x_buffer[0], (x_value - 1) * 2);
416 }
417 }
418 }
419
420 // if (x < 10) {
421 // if (x > 5) {
422 // x = x + 1;
423 // } else {
424 // x = x - 1;
425 // }
426 // } else {
427 // if (x <= 15) {
428 // x = x + 2;
429 // } else {
430 // x = x - 2;
431 // }
432 // }
TEST(LLVM,CondNestedTest)433 TEST(LLVM, CondNestedTest) {
434 BufHandle x("X", {1}, kInt);
435 auto true_cmp =
436 CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT);
437 auto true_cond = Cond::make(
438 true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
439 auto false_cmp =
440 CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE);
441 auto false_cond = Cond::make(
442 false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2));
443 auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
444 auto cond = Cond::make(cmp, true_cond, false_cond);
445
446 for (int32_t x_value : {0, 8, 15, 20}) {
447 std::vector<int32_t> x_buffer = {x_value};
448 std::vector<void*> args({x_buffer.data()});
449 LLVMCodeGen cg(cond, {x});
450 ASSERT_EQ(cg.value<int>(args), 0);
451 if (x_value < 10) {
452 if (x_value > 5) {
453 ASSERT_EQ(x_buffer[0], x_value + 1);
454 } else {
455 ASSERT_EQ(x_buffer[0], x_value - 1);
456 }
457 } else {
458 if (x_value <= 15) {
459 ASSERT_EQ(x_buffer[0], x_value + 2);
460 } else {
461 ASSERT_EQ(x_buffer[0], x_value - 2);
462 }
463 }
464 }
465 }
466
// Builds a loop over M rows whose body is a single N-lane vector multiply
// (c[row] = a[row] * b[row], expressed directly with Ramp indices) and checks
// that LLVMCodeGen can be constructed from it. The kernel is not executed.
TEST(LLVM, DirectVectorization) {
  constexpr int M = 3;
  constexpr int N = 64;
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {M, N}, kFloat);
  BufHandle c("c", {M, N}, kFloat);
  VarHandle m("m", kInt);
  // The row width N doubles as the vector lane count and the row stride, so
  // it is used everywhere instead of repeating the literal.
  StmtPtr s = For::make(
      m,
      0,
      M,
      Store::make(
          c,
          {Ramp::make(m * N, 1, N)},
          Load::make({kFloat, N}, a, {Ramp::make(m * N, 1, N)}) *
              Load::make({kFloat, N}, b, {Ramp::make(m * N, 1, N)})));
  LLVMCodeGen cg(s, {a, b, c});
}
486
// Copies four lanes from A to B with one vector load/store (Ramp index of
// stride 1). Afterwards every lane of B holds A's value and A is untouched.
TEST(LLVM, VecLoadStoreTest) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  std::vector<int32_t> src = {1, 1, 1, 1};
  std::vector<int32_t> dst = {2, 2, 2, 2};

  auto copy = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)}));
  LLVMCodeGen cg(copy, {a, b});
  std::vector<void*> args = {src.data(), dst.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(src[0], 1);
  ASSERT_EQ(src[1], 1);
  ASSERT_EQ(src[2], 1);
  ASSERT_EQ(src[3], 1);

  ASSERT_EQ(dst[0], 1);
  ASSERT_EQ(dst[1], 1);
  ASSERT_EQ(dst[2], 1);
  ASSERT_EQ(dst[3], 1);
}
506
507 #define FLOAT_INTRINSICS_TEST(Name, Lanes) \
508 TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \
509 BufHandle a("A", {1}, kFloat); \
510 BufHandle b("B", {1}, kFloat); \
511 float val = 0.5f; \
512 std::vector<float> a_buffer(Lanes, val); \
513 std::vector<float> b_buffer(Lanes, val); \
514 auto store = b.store( \
515 {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \
516 LLVMCodeGen cg(store, {a, b}); \
517 std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \
518 ASSERT_EQ(cg.value<int>(args), 0); \
519 for (const auto i : c10::irange(Lanes)) { \
520 ASSERT_FLOAT_EQ(a_buffer[i], val); \
521 } \
522 } // namespace jit
523 FLOAT_INTRINSICS_TEST(erf, 4)
524 FLOAT_INTRINSICS_TEST(erfc, 4)
525 FLOAT_INTRINSICS_TEST(acos, 4)
526 FLOAT_INTRINSICS_TEST(asin, 4)
527 FLOAT_INTRINSICS_TEST(atan, 4)
528 FLOAT_INTRINSICS_TEST(cosh, 4)
529 FLOAT_INTRINSICS_TEST(sinh, 4)
530 FLOAT_INTRINSICS_TEST(tanh, 4)
531 FLOAT_INTRINSICS_TEST(expm1, 4)
532 FLOAT_INTRINSICS_TEST(lgamma, 4)
533 FLOAT_INTRINSICS_TEST(erf, 8)
534 FLOAT_INTRINSICS_TEST(erfc, 8)
535 FLOAT_INTRINSICS_TEST(acos, 8)
536 FLOAT_INTRINSICS_TEST(asin, 8)
537 FLOAT_INTRINSICS_TEST(atan, 8)
538 FLOAT_INTRINSICS_TEST(cosh, 8)
539 FLOAT_INTRINSICS_TEST(sinh, 8)
540 FLOAT_INTRINSICS_TEST(tanh, 8)
541 FLOAT_INTRINSICS_TEST(expm1, 8)
542 FLOAT_INTRINSICS_TEST(lgamma, 8)
543 #undef FLOAT_INTRINSICS_TEST
544
545 #define DOUBLE_INTRINSICS_TEST(Name, Lanes) \
546 TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \
547 BufHandle a("A", {1}, kDouble); \
548 BufHandle b("B", {1}, kDouble); \
549 float val = 0.5f; \
550 std::vector<double> a_buffer(Lanes, val); \
551 std::vector<double> b_buffer(Lanes, val); \
552 auto store = b.store( \
553 {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \
554 LLVMCodeGen cg(store, {a, b}); \
555 std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \
556 ASSERT_EQ(cg.value<int>(args), 0); \
557 for (const auto i : c10::irange(Lanes)) { \
558 ASSERT_FLOAT_EQ(a_buffer[i], val); \
559 } \
560 } // namespace jit
561 DOUBLE_INTRINSICS_TEST(erf, 2)
562 DOUBLE_INTRINSICS_TEST(erfc, 2)
563 DOUBLE_INTRINSICS_TEST(acos, 2)
564 DOUBLE_INTRINSICS_TEST(asin, 2)
565 DOUBLE_INTRINSICS_TEST(atan, 2)
566 DOUBLE_INTRINSICS_TEST(cosh, 2)
567 DOUBLE_INTRINSICS_TEST(sinh, 2)
568 DOUBLE_INTRINSICS_TEST(tanh, 2)
569 DOUBLE_INTRINSICS_TEST(expm1, 2)
570 DOUBLE_INTRINSICS_TEST(lgamma, 2)
571 DOUBLE_INTRINSICS_TEST(erf, 4)
572 DOUBLE_INTRINSICS_TEST(erfc, 4)
573 DOUBLE_INTRINSICS_TEST(acos, 4)
574 DOUBLE_INTRINSICS_TEST(asin, 4)
575 DOUBLE_INTRINSICS_TEST(atan, 4)
576 DOUBLE_INTRINSICS_TEST(cosh, 4)
577 DOUBLE_INTRINSICS_TEST(sinh, 4)
578 DOUBLE_INTRINSICS_TEST(tanh, 4)
579 DOUBLE_INTRINSICS_TEST(expm1, 4)
580 DOUBLE_INTRINSICS_TEST(lgamma, 4)
581 #undef DOUBLE_INTRINSICS_TEST
582
// Lets LoopNest::vectorize rewrite a 4-element copy loop, confirms the loop
// is gone from the root block afterwards, then runs the vectorized statement
// and checks the copied values.
TEST(LLVM, VectorizerLoadStoreTest) {
  BufHandle a("A", {1}, kInt);

  Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); });

  BufHandle c_buf(c.buf());
  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();
  const bool vectorized =
      LoopNest::vectorize(to<For>(to<Block>(root)->front()));
  ASSERT_TRUE(vectorized);

  // The first statement of the block must no longer be a For.
  ASSERT_TRUE(to<For>(to<Block>(root)->front()) == nullptr);

  LLVMCodeGen cg(root, {a, c_buf});

  std::vector<int> input(4, 21);
  std::vector<int> output(4, 0);
  std::vector<void*> args = {input.data(), output.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(output, 21);
}
603
// Vectorizes a 128-element loop whose body bitcasts int to float. The input
// holds the bit pattern of 1337.f, so every output element must read back as
// exactly 1337.f.
TEST(LLVM, VectorizeBitCast) {
  BufHandle a("A", {128}, kInt);

  Tensor c = Compute("c", {128}, [&](const VarHandle& i) {
    return bitcast<float>(a.load(i));
  });

  BufHandle c_buf(c.buf());
  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();
  const bool vectorized =
      LoopNest::vectorize(to<For>(to<Block>(root)->front()));
  ASSERT_TRUE(vectorized);
  // The first statement of the block must no longer be a For.
  ASSERT_TRUE(to<For>(to<Block>(root)->front()) == nullptr);

  LLVMCodeGen cg(root, {a, c_buf});

  std::vector<int> input(128);
  std::vector<float> output(128);
  for (const auto pos : c10::irange(128)) {
    input[pos] = raw_bitcast<int>(1337.f);
  }
  std::vector<void*> args = {input.data(), output.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(output, 1337.f);
}
628
// Element-wise copy loop B[i] = A[i] over 32 ints: the destination starts
// zeroed and must end up equal to the source, which stays untouched.
TEST(LLVM, MemcpyTest) {
  constexpr int N = 32;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  std::vector<int32_t> src(N, 42);
  std::vector<int32_t> dst(N, 0);

  VarHandle i("i", kInt);
  auto copy_loop = For::make(i, 0, N, b.store({i}, a.load(i)));

  LLVMCodeGen cg(copy_loop, {a, b});

  std::vector<void*> args = {src.data(), dst.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(src.size(), N);
  ASSERT_EQ(dst.size(), N);
  assertAllEqual(src, 42);
  assertAllEqual(dst, 42);
}
649
// Zero-fill loop B[i] = 0 over 32 ints: the buffer starts at 11 and must be
// all zeros afterwards.
TEST(LLVM, BzeroTest) {
  constexpr int N = 32;
  BufHandle b("B", {N}, kInt);
  std::vector<int32_t> target(N, 11);

  VarHandle i("i", kInt);
  auto zero_loop = For::make(i, 0, N, b.store({i}, 0));

  LLVMCodeGen cg(zero_loop, {b});

  std::vector<void*> args = {target.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(target.size(), N);
  assertAllEqual(target, 0);
}
666
// C[i] = A[i] + B[i] over 1024 ints (41 + 1): the output must be 42
// everywhere and the inputs unchanged.
TEST(LLVM, ElemwiseAdd) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int32_t> lhs(N, 41);
  std::vector<int32_t> rhs(N, 1);
  std::vector<int32_t> out(N, 1);

  VarHandle i("i", kInt);
  auto sum = Add::make(a.load(i), b.load(i));
  auto loop = For::make(i, 0, N, c.store({i}, sum));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);
  assertAllEqual(lhs, 41);
  assertAllEqual(rhs, 1);
  assertAllEqual(out, 42);
}
691
// Float version of the element-wise add, written with operator+ on the
// loads: 41 + 1 must give 42 everywhere, inputs unchanged.
TEST(LLVM, ElemwiseAddFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> lhs(N, 41);
  std::vector<float> rhs(N, 1);
  std::vector<float> out(N, 1);

  VarHandle i("i", kInt);
  auto loop = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i)));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);
  assertAllEqual(lhs, 41.0f);
  assertAllEqual(rhs, 1.0f);
  assertAllEqual(out, 42.0f);
}
716
// Applies log10 four lanes at a time (N / 4 iterations of a 4-wide Ramp)
// to a buffer of 10.0f; every output element must be exactly 1.0f.
TEST(LLVM, ElemwiseLog10Float) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  std::vector<float> input(N, 10.0f);
  std::vector<float> output(N, 2.0f);

  VarHandle i("i", kInt);
  auto loop = For::make(
      i,
      0,
      N / 4,
      b.store(
          {Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)}))));

  LLVMCodeGen cg(loop, {a, b});

  std::vector<void*> args = {input.data(), output.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(input.size(), N);
  ASSERT_EQ(output.size(), N);
  assertAllEqual(input, 10.0f);
  assertAllEqual(output, 1.0f);
}
742
// Applies log1p four lanes at a time to a buffer of exp(3) - 1; every output
// element must be within 1e-5 of 3.
TEST(LLVM, ElemwiseLog1pFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  std::vector<float> input(N, expf(3.0f) - 1);
  std::vector<float> output(N, 42.0f);

  VarHandle i("i", kInt);
  auto loop = For::make(
      i,
      0,
      N / 4,
      b.store(
          {Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)}))));

  LLVMCodeGen cg(loop, {a, b});

  std::vector<void*> args = {input.data(), output.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(input.size(), N);
  ASSERT_EQ(output.size(), N);
  assertAllEqual(input, expf(3.0f) - 1);
  ExpectAllNear(output, 3.0f, 1e-5f);
}
768
// C[i] = Max(A[i], B[i]) over 1024 ints (41 vs 1): the output must be 41
// everywhere and the inputs unchanged.
TEST(LLVM, ElemwiseMaxInt) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> lhs(N, 41);
  std::vector<int> rhs(N, 1);
  std::vector<int> out(N, 1);

  VarHandle i("i", kInt);
  auto larger = Max::make(a.load(i), b.load(i), false);
  auto loop = For::make(i, 0, N, c.store({i}, larger));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);
  assertAllEqual(lhs, 41);
  assertAllEqual(rhs, 1);
  assertAllEqual(out, 41);
}
794
// C[i] = Min(A[i], B[i]) over 1024 ints (41 vs 1): the output must be 1
// everywhere and the inputs unchanged.
TEST(LLVM, ElemwiseMinInt) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 41);
  std::vector<int> b_buffer(N, 1);
  // Seed the output with a value that differs from the expected result (1),
  // so the final check fails if the kernel never writes to C.
  std::vector<int> c_buffer(N, 0);

  VarHandle i("i", kInt);
  auto expr =
      For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41);
  assertAllEqual(b_buffer, 1);
  assertAllEqual(c_buffer, 1);
}
820
// Float version of the element-wise Max (41 vs 1): the output must be 41
// everywhere and the inputs unchanged.
TEST(LLVM, ElemwiseMaxFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> lhs(N, 41);
  std::vector<float> rhs(N, 1);
  std::vector<float> out(N, 1);

  VarHandle i("i", kInt);
  auto larger = Max::make(a.load(i), b.load(i), false);
  auto loop = For::make(i, 0, N, c.store({i}, larger));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);
  assertAllEqual(lhs, 41.0f);
  assertAllEqual(rhs, 1.0f);
  assertAllEqual(out, 41.0f);
}
846
// Max with a NaN operand: A is all NaN, B is all 1. Every output element is
// expected to be NaN.
TEST(LLVM, ElemwiseMaxNaNFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> lhs(N, NAN);
  std::vector<float> rhs(N, 1);
  std::vector<float> out(N, 1);

  VarHandle i("i", kInt);
  auto larger = Max::make(a.load(i), b.load(i), false);
  auto loop = For::make(i, 0, N, c.store({i}, larger));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);
  assertAllEqual(rhs, 1.0f);
  for (const float value : out) {
    ASSERT_TRUE(std::isnan(value));
  }
}
873
// Float version of the element-wise Min (41 vs 1): the output must be 1
// everywhere and the inputs unchanged.
TEST(LLVM, ElemwiseMinFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> a_buffer(N, 41);
  std::vector<float> b_buffer(N, 1);
  // Seed the output with a value that differs from the expected result
  // (1.0f), so the final check fails if the kernel never writes to C.
  std::vector<float> c_buffer(N, 0);

  VarHandle i("i", kInt);
  auto expr =
      For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41.0f);
  assertAllEqual(b_buffer, 1.0f);
  assertAllEqual(c_buffer, 1.0f);
}
899
// Min with a NaN operand: A is all NaN, B is all 1. Every output element is
// expected to be NaN.
TEST(LLVM, ElemwiseMinNaNFloat) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> lhs(N, NAN);
  std::vector<float> rhs(N, 1);
  std::vector<float> out(N, 1);

  VarHandle i("i", kInt);
  auto smaller = Min::make(a.load(i), b.load(i), false);
  auto loop = For::make(i, 0, N, c.store({i}, smaller));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);
  assertAllEqual(rhs, 1.0f);
  for (const float value : out) {
    ASSERT_TRUE(std::isnan(value));
  }
}
926
// C[i] = A[i] % B[i] over 1024 ints (41 % 23): the output must be 18
// everywhere and the inputs unchanged.
TEST(LLVM, ElemwiseMod) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int32_t> a_buffer(N, 41);
  std::vector<int32_t> b_buffer(N, 23);
  // Seed the output with a value that differs from the expected result (18),
  // so the final check fails if the kernel never writes to C.
  std::vector<int32_t> c_buffer(N, 0);

  VarHandle i("i", kInt);
  auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i))));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41);
  assertAllEqual(b_buffer, 23);
  assertAllEqual(c_buffer, 18);
}
951
// C[i] = (A[i] == B[i]) over 1024 ints. A is all 1; the first half of B is 0
// (not equal -> 0) and the second half is 1 (equal -> 1).
TEST(LLVM, CompareSelectIntEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> lhs(N, 1);
  std::vector<int> rhs(N, 1);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 1);

  // Make the first half of the operands differ.
  for (const auto pos : c10::irange(N / 2)) {
    rhs[pos] = 0;
    expected[pos] = 0;
  }

  VarHandle i("i", kInt);
  auto is_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, c.store({i}, is_equal));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(lhs, 1);
  for (const auto pos : c10::irange(N)) {
    ASSERT_EQ(expected[pos], out[pos]);
  }
}
991
// C[i] = (A[i] == B[i]) with float operands and an int result. Both inputs
// are all 1.0f, so every output element must be 1.
TEST(LLVM, CompareSelectFloatEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kInt);
  std::vector<float> lhs(N, 1.0f);
  std::vector<float> rhs(N, 1.0f);
  std::vector<int> out(N, 0);

  VarHandle i("i", kInt);
  auto is_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, c.store({i}, is_equal));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(lhs, 1.0f);
  assertAllEqual(rhs, 1.0f);
  assertAllEqual(out, 1);
}
1024
// C[i] = (A[i] > B[i]) with Byte operands. B is all 0; the first half of A is
// 128 (expected 1) and the second half is 0 (expected 0).
TEST(LLVM, CompareSelectByteGT) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);
  std::vector<uint8_t> lhs(N, 0);
  std::vector<uint8_t> rhs(N, 0);
  std::vector<int> out(N, 0);
  std::vector<int> expected(N, 0);

  // Raise the first half of the left operand above the right operand.
  for (const auto pos : c10::irange(N / 2)) {
    lhs[pos] = 128;
    expected[pos] = 1;
  }

  VarHandle i("i", kInt);
  auto is_greater =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kGT);
  auto loop = For::make(i, 0, N, c.store({i}, is_greater));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> args = {lhs.data(), rhs.data(), out.data()};
  const int rc = cg.value<int>(args);
  ASSERT_EQ(rc, 0);

  ASSERT_EQ(lhs.size(), N);
  ASSERT_EQ(rhs.size(), N);
  ASSERT_EQ(out.size(), N);

  assertAllEqual(rhs, uint8_t(0));
  for (const auto pos : c10::irange(N)) {
    ASSERT_EQ(expected[pos], out[pos]);
  }
}
1064
// Element-wise kGE CompareSelect over uint8 buffers that are both all zeros;
// every output element is expected to be 1.
TEST(LLVM, CompareSelectByteGE) {
  constexpr int N = 1024;

  std::vector<uint8_t> a_host(N, 0);
  std::vector<uint8_t> b_host(N, 0);
  std::vector<int> c_host(N, 0);
  std::vector<int> expected(N, 1);

  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  // for i in [0, N): C[i] = CompareSelect(A[i], B[i], kGE)
  VarHandle idx("i", kInt);
  auto greater_or_equal = CompareSelect::make(
      a.load(idx), b.load(idx), CompareSelectOperation::kGE);
  auto loop = For::make(idx, 0, N, c.store({idx}, greater_or_equal));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> raw_args{a_host.data(), b_host.data(), c_host.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);

  ASSERT_EQ(a_host.size(), N);
  ASSERT_EQ(b_host.size(), N);
  ASSERT_EQ(c_host.size(), N);

  // B is an input only and must still be all zeros.
  assertAllEqual(b_host, static_cast<uint8_t>(0));
  for (const auto pos : c10::irange(N)) {
    ASSERT_EQ(expected[pos], c_host[pos]);
  }
}
1099
// Element-wise kLT CompareSelect over uint8 buffers. B is 128 everywhere; the
// first half of A is also 128 (so A < B is false there) while the second half
// is 0, so only the second half of the output is expected to be 1.
TEST(LLVM, CompareSelectByteLT) {
  constexpr int N = 1024;

  std::vector<uint8_t> a_host(N, 0);
  std::vector<uint8_t> b_host(N, 128);
  std::vector<int> c_host(N, 0);
  std::vector<int> expected(N, 1);
  for (const auto k : c10::irange(N / 2)) {
    a_host[k] = 128;
    expected[k] = 0;
  }

  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  // for i in [0, N): C[i] = CompareSelect(A[i], B[i], kLT)
  VarHandle idx("i", kInt);
  auto less = CompareSelect::make(
      a.load(idx), b.load(idx), CompareSelectOperation::kLT);
  auto loop = For::make(idx, 0, N, c.store({idx}, less));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> raw_args{a_host.data(), b_host.data(), c_host.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);

  ASSERT_EQ(a_host.size(), N);
  ASSERT_EQ(b_host.size(), N);
  ASSERT_EQ(c_host.size(), N);

  // B is an input only and must still be 128 everywhere.
  assertAllEqual(b_host, static_cast<uint8_t>(128));
  for (const auto pos : c10::irange(N)) {
    ASSERT_EQ(expected[pos], c_host[pos]);
  }
}
1139
// Element-wise kLE CompareSelect over uint8 buffers with A all zeros and B all
// 128; every output element is expected to be 1.
TEST(LLVM, CompareSelectByteLE) {
  constexpr int N = 1024;

  std::vector<uint8_t> a_host(N, 0);
  std::vector<uint8_t> b_host(N, 128);
  std::vector<int> c_host(N, 0);
  std::vector<int> expected(N, 1);

  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);

  // for i in [0, N): C[i] = CompareSelect(A[i], B[i], kLE)
  VarHandle idx("i", kInt);
  auto less_or_equal = CompareSelect::make(
      a.load(idx), b.load(idx), CompareSelectOperation::kLE);
  auto loop = For::make(idx, 0, N, c.store({idx}, less_or_equal));

  LLVMCodeGen cg(loop, {a, b, c});

  std::vector<void*> raw_args{a_host.data(), b_host.data(), c_host.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);

  ASSERT_EQ(a_host.size(), N);
  ASSERT_EQ(b_host.size(), N);
  ASSERT_EQ(c_host.size(), N);

  // B is an input only and must still be 128 everywhere.
  assertAllEqual(b_host, static_cast<uint8_t>(128));
  for (const auto pos : c10::irange(N)) {
    ASSERT_EQ(expected[pos], c_host[pos]);
  }
}
1174
// Stores a float immediate into a one-element buffer and reads it back.
TEST(LLVM, StoreFloat) {
  std::vector<float> out_data = {0.0f};

  BufHandle result("result", {1}, kFloat);
  auto store_stmt = result.store({0}, FloatImm::make(3.14f));

  LLVMCodeGen cg(store_stmt, {result});

  std::vector<void*> raw_args{out_data.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);
  ASSERT_EQ(out_data[0], 3.14f);
}
1184
// Computes f[i] = float(i * i + 1) through a Compute tensor and compares the
// kernel output against the same formula evaluated on the host.
TEST(LLVM, SimpleMath01) {
  const int N = 1024;

  auto square_plus_one = [](const VarHandle& i) {
    return cast<float>(i * i + 1);
  };
  Tensor tensor = Compute("f", {N}, square_plus_one);

  LoopNest nest({tensor});
  StmtPtr root = nest.root_stmt();
  BufHandle f_buf(tensor.buf());
  LLVMCodeGen cg(root, {f_buf});

  PaddedBuffer<float> f_v(N, "f_v");
  std::vector<void*> raw_args{f_v.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);

  // Host-side reference.
  PaddedBuffer<float> f_ref(N, "f_ref");
  for (const auto k : c10::irange(N)) {
    f_ref(k) = k * k + 1;
  }
  ExpectAllNear(f_v, f_ref, 1e-5);
}
1204
// Element-wise product of two float buffers expressed as a Compute tensor;
// 21 * 2 is expected to give 42 in every output element.
TEST(LLVM, ComputeMul) {
  const int N = 1024;

  BufHandle a("a", {N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  auto product = [&](const VarHandle& i) { return a.load(i) * b.load(i); };
  Tensor c = Compute("c", {N}, product);

  BufHandle c_buf(c.buf());
  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();

  LLVMCodeGen cg(root, {a, b, c_buf});

  std::vector<float> a_host(N, 21.0f);
  std::vector<float> b_host(N, 2.0f);
  std::vector<float> c_host(N, 0.0f);
  std::vector<void*> raw_args{a_host.data(), b_host.data(), c_host.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);
  assertAllEqual(c_host, 42.0f);
}
1225
// Adds a length-N vector to every row of an M x N matrix:
// c[i, j] = a[i, j] + b[j], with both inputs filled by std::iota.
TEST(LLVM, BroadcastAdd) {
  const int M = 32;
  const int N = 1024;

  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  auto row_plus_vec = [&](const VarHandle& i, const VarHandle& j) {
    return a.load(i, j) + b.load(j);
  };
  Tensor c = Compute("c", {M, N}, row_plus_vec);

  BufHandle c_buf(c.buf());
  LoopNest nest({c});
  nest.prepareForCodegen();
  StmtPtr root = nest.root_stmt();

  LLVMCodeGen cg(root, {a, b, c_buf});

  // Inputs are 0, 1, 2, ... in storage order; the output starts zeroed.
  std::vector<float> a_host(M * N);
  std::vector<float> b_host(N);
  std::vector<float> c_host(M * N, 0);
  std::iota(a_host.begin(), a_host.end(), 0);
  std::iota(b_host.begin(), b_host.end(), 0);

  std::vector<void*> raw_args{a_host.data(), b_host.data(), c_host.data()};
  ASSERT_EQ(cg.value<int>(raw_args), 0);

  for (const auto row : c10::irange(M)) {
    for (const auto col : c10::irange(N)) {
      const auto flat = row * N + col;
      ASSERT_EQ(c_host[flat], a_host[flat] + b_host[col]);
    }
  }
}
1256
// Evaluates (((59 ^ (11 << 1)) & 101) >> 2) | 2 on int immediates.
// Step by step: 11 << 1 = 22, 59 ^ 22 = 45, 45 & 101 = 37, 37 >> 2 = 9,
// 9 | 2 = 11.
TEST(LLVM, BitwiseOps) {
  auto a = IntImm::make(59);
  auto b = IntImm::make(11);
  auto c = IntImm::make(101);
  auto d = IntImm::make(2);

  ExprHandle shifted_left = b << 1;
  ExprHandle xored = a ^ shifted_left;
  ExprHandle masked = xored & c;
  ExprHandle shifted_right = masked >> 2;
  ExprHandle f = shifted_right | d;

  LLVMExprEval cg(f);
  ASSERT_EQ(cg.value<int>(), 11);
}
1268
// Right shift on a signed 8-bit immediate: -4 >> 1 is expected to be -2,
// i.e. the sign is preserved.
TEST(LLVM, ArithmeticRightShift) {
  auto value = CharImm::make(-4);
  auto amount = CharImm::make(1);
  ExprHandle shifted = value >> amount;

  LLVMExprEval cg(shifted);
  ASSERT_EQ(cg.value<int8_t>(), -2);
}
1276
// Right shift on an unsigned 8-bit immediate: 0xfc >> 1 is expected to be
// 0x7e, i.e. a zero is shifted into the top bit.
TEST(LLVM, LogicalRightShift) {
  auto value = ByteImm::make(0xfc);
  auto amount = ByteImm::make(1);
  ExprHandle shifted = value >> amount;

  LLVMExprEval cg(shifted);
  ASSERT_EQ(cg.value<uint8_t>(), 0x7e);
}
1284
// Adds two float buffers whose length `n` is a runtime kernel argument. The
// kernel is driven through the raw-pointer interface, so the size is passed
// as a pointer alongside the data pointers.
TEST(LLVM, DynamicShapeAdd) {
  auto testWithSize = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    // for i in [0, n): c[i] = a[i] + b[i]
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    LLVMCodeGen cg(s, {a, b, c, n});
    std::vector<void*> args({aData.data(), bData.data(), cData.data(), &size});
    // Read the kernel's result as int and require 0, exactly like the other
    // statement kernels in this file. The earlier value<float> call ignored
    // the result and did not match how those kernels are invoked.
    ASSERT_EQ(cg.value<int>(args), 0);
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };
  // Cover a single element, a power of two and an odd size.
  testWithSize(1);
  testWithSize(16);
  testWithSize(37);
}
1305
// Same dynamic-length addition as DynamicShapeAdd, but the arguments are
// handed to the kernel through LLVMCodeGen::call instead of raw pointers.
TEST(LLVM, BindDynamicShapeAdd) {
  auto testWithSize = [](int32_t size) {
    // Host data: 1 + 2 is expected to give 3 in every element.
    std::vector<float> lhs(size, 1.0f);
    std::vector<float> rhs(size, 2.0f);
    std::vector<float> sum(size, 0.0f);

    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);

    // for i in [0, n): c[i] = a[i] + b[i]
    VarHandle idx("i", kInt);
    auto add_elem = c.store({idx}, a.load(idx) + b.load(idx));
    StmtPtr loop = For::make(idx, 0, n, add_elem);

    LLVMCodeGen cg(loop, {a, b, c, n});
    cg.call({lhs, rhs, sum, size});
    ExpectAllNear(sum, std::vector<float>(size, 3.0f), 1e-7);
  };
  testWithSize(1);
  testWithSize(16);
  testWithSize(37);
}
1325
// Dynamic-length addition where the output is a Compute tensor rather than a
// hand-written loop; the tensor itself is passed as a codegen argument.
TEST(LLVM, TensorDynamicShapeAdd) {
  auto testWithSize = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    auto add_elem = [&](const VarHandle& i) { return a.load(i) + b.load(i); };
    Tensor c = Compute("c", {n}, add_elem);

    LoopNest nest({c});
    StmtPtr root = nest.root_stmt();
    LLVMCodeGen cg(root, {a, b, c, n});

    // Host data: 1 + 2 is expected to give 3 in every element.
    std::vector<float> lhs(size, 1.0f);
    std::vector<float> rhs(size, 2.0f);
    std::vector<float> sum(size, 0.0f);
    cg.call({lhs, rhs, sum, size});
    ExpectAllNear(sum, std::vector<float>(size, 3.0f), 1e-7);
  };
  testWithSize(1);
  testWithSize(16);
  testWithSize(37);
}
1346
// Two-dimensional addition where both extents (m, n) are runtime kernel
// arguments: c[i, j] = a[i, j] + b[i, j].
TEST(LLVM, DynamicShape2D) {
  auto testWithSize = [](int32_t M, int32_t N) {
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    BufHandle a("a", {m, n}, kFloat);
    BufHandle b("b", {m, n}, kFloat);
    auto add_elem = [&](const VarHandle& i, const VarHandle& j) {
      return a.load(i, j) + b.load(i, j);
    };
    Tensor c = Compute("c", {m, n}, add_elem);

    LoopNest nest({c});
    nest.prepareForCodegen();
    StmtPtr root = nest.root_stmt();
    LLVMCodeGen cg(root, {a, b, c, m, n});

    // Host data: 1 + 2 is expected to give 3 in every element.
    std::vector<float> lhs(M * N, 1.0f);
    std::vector<float> rhs(M * N, 2.0f);
    std::vector<float> sum(M * N, 0.0f);
    cg.call({lhs, rhs, sum, M, N});
    ExpectAllNear(sum, std::vector<float>(M * N, 3.0f), 1e-7);
  };
  testWithSize(1, 8);
  testWithSize(16, 32);
  testWithSize(37, 11);
}
1371
// Compiles and calls a kernel whose body is an empty Block with no arguments.
// There is nothing to verify beyond the absence of a crash.
TEST(LLVM, EmptyStmt) {
  std::vector<StmtPtr> no_stmts;
  StmtPtr empty_block = alloc<Block>(no_stmts);

  LLVMCodeGen cg(empty_block, {});
  cg.call({});
}
1379
// Builds a Compute tensor over a zero-extent dimension, simplifies the
// resulting statement, and checks that codegen and a call with an empty
// output buffer still go through. No output values are verified.
TEST(LLVM, EliminatedStmt) {
  BufHandle a("a", {1}, kFloat);

  auto identity = [&](const VarHandle& m) { return m; };
  Tensor c = Compute("c", {0}, identity);

  LoopNest nest({c});
  nest.prepareForCodegen();
  StmtPtr root = IRSimplifier::simplify(nest.root_stmt());

  LLVMCodeGen cg(root, {a, c});

  std::vector<float> a_host(1, 1.0f);
  std::vector<float> c_host(0, 0.0f);
  cg.call({a_host, c_host});
}
1394
// Sums a 1 x M x N float buffer into a single element with Reduce/Sum and
// compares against a host-side accumulation of the same values.
TEST(LLVM, SimpleReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});

  LoopNest nest({b});
  nest.prepareForCodegen();
  StmtPtr root = IRSimplifier::simplify(nest.root_stmt());

  LLVMCodeGen cg(root, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  // Input element (0, r, c) is r + c; accumulate the expected sum alongside.
  b_ref(0) = 0;
  for (const auto r : c10::irange(M)) {
    for (const auto c : c10::irange(N)) {
      int elem = r + c;
      a_v(0, r, c) = elem;
      b_ref(0) += elem;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1427
// Same sum reduction as SimpleReduction, but the two reduction loops are
// swapped and the reduction is rfactor'ed before codegen. The result is
// compared against a host-side accumulation.
TEST(LLVM, RFactorReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
  LoopNest nest({b});

  // Swap the loops at positions 1 and 2 of the reduction's loop nest.
  std::vector<ForPtr> loops = nest.getLoopStmtsFor(b);
  ForPtr loop_m = loops.at(1);
  ForPtr loop_n = loops.at(2);
  nest.reorderAxis(loop_m, loop_n);

  // Re-query after the reorder. Only loop_n is used below; the at(2) access
  // is kept so the bounds check on the loop list still happens.
  loops = nest.getLoopStmtsFor(b);
  loop_m = loops.at(2);
  loop_n = loops.at(1);
  auto reduce_write = nest.getAllWritesToBuf(b.buf())[1];
  ASSERT_TRUE(nest.rfactor(reduce_write, loop_n));

  nest.prepareForCodegen();
  StmtPtr root = IRSimplifier::simplify(nest.root_stmt());

  LLVMCodeGen cg(root, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  // Input element (0, r, c) is r + c; accumulate the expected sum alongside.
  b_ref(0) = 0;
  for (const auto r : c10::irange(M)) {
    for (const auto c : c10::irange(N)) {
      int elem = r + c;
      a_v(0, r, c) = elem;
      b_ref(0) += elem;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1471
// Sum reduction that is reordered, rfactor'ed, distributed and then
// vectorized (both the initializer and the producer of the rfactor buffer)
// before codegen. The result is compared against a host-side accumulation.
TEST(LLVM, RFactorVectorizedReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
  LoopNest nest({b});

  // Reorder n and m loops
  std::vector<ForPtr> reduce_loops = nest.getLoopStmtsFor(b);
  nest.reorderAxis(reduce_loops.at(1), reduce_loops.at(2));

  auto reduce_write = nest.getAllWritesToBuf(b.buf()).at(1);
  auto writer_nests = nest.getAllLoopNestsWritingToBuf(b.buf());
  // Two nests are expected to write the buffer, the second being three deep.
  ASSERT_EQ(writer_nests.size(), 2u);
  ASSERT_EQ(writer_nests[1].size(), 3u);

  ForPtr rfactor_axis = writer_nests[1][1];
  ASSERT_TRUE(nest.rfactor(reduce_write, rfactor_axis));
  auto split_loops = nest.distributeLoop(rfactor_axis);

  // Vectorize initializer of rfac_buf
  ASSERT_TRUE(LoopNest::vectorize(split_loops[0]));
  // Vectorize producer of rfac_buf
  ASSERT_TRUE(LoopNest::vectorize(split_loops[1]));
  nest.simplify();

  nest.prepareForCodegen();

  StmtPtr root = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(root, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  // Input element (0, r, c) is r + c; accumulate the expected sum alongside.
  b_ref(0) = 0;
  for (const auto r : c10::irange(M)) {
    for (const auto c : c10::irange(N)) {
      int elem = r + c;
      a_v(0, r, c) = elem;
      b_ref(0) += elem;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1517
1518 template <bool outer, bool inner>
testSimpleParallel()1519 static void testSimpleParallel() {
1520 // Compute a simple operation, and try all loop-axis combination to be
1521 // parallel or sequential.
1522 const int M = 4;
1523 const int N = 6;
1524 Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) {
1525 return cast<float>(m + n);
1526 });
1527 LoopNest loop_nest({f});
1528 auto const& loops = loop_nest.getLoopStmtsFor(f);
1529 ForPtr m = loops[0];
1530 ForPtr n = loops[1];
1531 if (outer) {
1532 m->set_parallel();
1533 }
1534 if (inner) {
1535 n->set_parallel();
1536 }
1537 loop_nest.prepareForCodegen();
1538 StmtPtr stmt = loop_nest.root_stmt();
1539 LLVMCodeGen cg(stmt, {f});
1540
1541 PaddedBuffer<float> f_v(M, N, "f_v");
1542 std::vector<void*> args({f_v.data()});
1543 int value = cg.value<int>(args);
1544 ASSERT_EQ(value, 0);
1545 PaddedBuffer<float> f_ref(M, N, "f_ref");
1546 for (const auto m : c10::irange(M)) {
1547 for (const auto n : c10::irange(N)) {
1548 f_ref(m, n) = m + n;
1549 }
1550 }
1551 ExpectAllNear(f_v, f_ref, 1e-5);
1552 }
1553
// Neither loop is marked parallel.
TEST(LLVM, SimpleParallelSS) {
  constexpr bool kOuterParallel = false;
  constexpr bool kInnerParallel = false;
  testSimpleParallel<kOuterParallel, kInnerParallel>();
}
// Only the inner loop is marked parallel.
TEST(LLVM, SimpleParallelSP) {
  constexpr bool kOuterParallel = false;
  constexpr bool kInnerParallel = true;
  testSimpleParallel<kOuterParallel, kInnerParallel>();
}
// Only the outer loop is marked parallel.
TEST(LLVM, SimpleParallelPS) {
  constexpr bool kOuterParallel = true;
  constexpr bool kInnerParallel = false;
  testSimpleParallel<kOuterParallel, kInnerParallel>();
}
// Both loops are marked parallel.
TEST(LLVM, SimpleParallelPP) {
  constexpr bool kOuterParallel = true;
  constexpr bool kInnerParallel = true;
  testSimpleParallel<kOuterParallel, kInnerParallel>();
}
1566
// Builds a chain of four tensors (t1, t2 -> t3 -> t4) and, for every one of
// the 2^6 subsets of their six loops, marks that subset parallel, compiles,
// runs, and compares t4 with a host-computed reference.
TEST(LLVM, CompositeParallel) {
  int loop_count = 6;
  int test_count = 1 << loop_count;
  for (const auto test_cfg : c10::irange(test_count)) {
    int M = 5;
    int N = 7;
    Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; });
    Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; });
    Tensor t3 =
        Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
          return t1.load(m) * t2.load(n);
        });
    Tensor t4 =
        Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
          return t3.load(m, n) + m + n;
        });
    LoopNest loop_nest({t4}, {t1, t2, t3, t4});

    // Gather the loops in a fixed order: one each for t1 and t2, then the
    // first two of t3 and of t4. Bit i of test_cfg refers to loop_list[i].
    std::vector<ForPtr> loop_list;
    auto take_loops = [&](const Tensor& t, int how_many) {
      auto const& loops = loop_nest.getLoopStmtsFor(t);
      for (const auto depth : c10::irange(how_many)) {
        loop_list.push_back(loops[depth]);
      }
    };
    take_loops(t1, 1);
    take_loops(t2, 1);
    take_loops(t3, 2);
    take_loops(t4, 2);
    ASSERT_EQ(loop_list.size(), loop_count);

    for (const auto bit : c10::irange(loop_count)) {
      if (test_cfg & (1 << bit)) {
        loop_list[bit]->set_parallel();
      }
    }

    loop_nest.prepareForCodegen();
    StmtPtr root = loop_nest.root_stmt();
    LLVMCodeGen cg(root, {t4});

    PaddedBuffer<float> t4_v(M, N, "t4_v");
    std::vector<void*> raw_args{t4_v.data()};
    ASSERT_EQ(cg.value<int>(raw_args), 0);

    // Host-side reference: t4[m, n] = (m + 1) * (n + 2) + m + n.
    PaddedBuffer<float> t4_ref(M, N, "t4_ref");
    for (const auto row : c10::irange(M)) {
      for (const auto col : c10::irange(N)) {
        t4_ref(row, col) = (row + 1) * (col + 2) + row + col;
      }
    }
    ExpectAllNear(t4_v, t4_ref, 1e-5);
  }
}
1628
// Schedules a 32x48 by 48x32 matrix multiply (split, reorder, vectorize),
// compiles it with LLVM and compares the result with a naive host GEMM.
TEST(LLVM, VectorizedGEMM) {
  int M = 32;
  int N = 32;
  int K = 48;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);
  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  LoopNest loop({CT});

  // m, n, k -> mo, mi, n, k
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr m = loops[0];
    loop.splitWithMask(m, 16);
  }
  // mo, mi, n, k -> mo, mi, no, ni, k
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr n = loops[2];
    loop.splitWithMask(n, 16);
  }
  // mo, mi, no, ni, k ->
  // mo, no, mi, ni, k
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr mi = loops[1];
    ForPtr no = loops[2];
    loop.reorderAxis(mi, no);
  }
  // mo, no, mi, ni, k ->
  // mo, no, mi, k, ni
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr ni = loops[3];
    ForPtr k = loops[4];
    loop.reorderAxis(ni, k);
  }
  // mo, no, mi, k, ni ->
  // mo, no, k, mi, ni
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr mi = loops[2];
    ForPtr k = loops[3];
    loop.reorderAxis(mi, k);
  }
  // Vectorize two of the For nodes found in the whole statement: the one at
  // index 3 and the last one.
  {
    auto loops = NodeFinder<For>::find(loop.root_stmt());
    ASSERT_TRUE(LoopNest::vectorize(loops[3]));
    ASSERT_TRUE(LoopNest::vectorize(loops.back()));
  }

  loop.prepareForCodegen();

  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);
  LLVMCodeGen cg(s, {AP, BP, CT});

  PaddedBuffer<float> a_v(M, K, "a_v");
  PaddedBuffer<float> b_v(K, N, "b_v");
  PaddedBuffer<float> c_v(M, N, "c_v");
  PaddedBuffer<float> c_ref(M, N, "c_ref");

  // Give both operands position-dependent contents. They were previously
  // never written, so the comparison only saw whatever PaddedBuffer's
  // constructor leaves behind and an indexing mistake in the schedule could
  // go unnoticed. Small integer values keep every product and partial sum
  // exactly representable in float (at most 48 * 6 * 4 = 1152), so the
  // result does not depend on accumulation order or fused multiply-add.
  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = static_cast<float>((m + 2 * k) % 7);
    }
  }
  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = static_cast<float>((3 * k + n) % 5);
    }
  }

  // Naive host-side reference GEMM.
  for (const auto m : c10::irange(M)) {
    for (const auto n : c10::irange(N)) {
      c_ref(m, n) = 0.f;
      for (const auto k : c10::irange(K)) {
        c_ref(m, n) += a_v(m, k) * b_v(k, n);
      }
    }
  }

  cg.call({a_v, b_v, c_v});

  ExpectAllNear(c_v, c_ref, 1e-5);
}
1710
// Runs the same broadcast-add statement through call_raw on both the LLVM
// codegen and SimpleIREvaluator, with the second extent N supplied at run
// time, and checks c[i, j] = a[i, j] + b[j] after each run.
TEST(LLVM, CallRaw) {
  const int M = 32;
  VarHandle N("N", kInt);
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
    return a.load(i, j) + b.load(j);
  });

  LoopNest l({c});
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  // Inputs are 0, 1, 2, ... in storage order; the output starts zeroed. The
  // runtime extent is passed as a pointer in the raw argument list.
  int32_t N_value = 1024;
  std::vector<float> av(M * N_value);
  std::iota(av.begin(), av.end(), 0);
  std::vector<float> bv(N_value);
  std::iota(bv.begin(), bv.end(), 0);
  std::vector<float> cv(M * N_value, 0);
  std::vector<void*> args({av.data(), bv.data(), cv.data(), &N_value});

  LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N});
  cg.call_raw(args);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N_value)) {
      ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);
    }
  }

  // Clear the output before the second backend runs. Without this, the
  // values left by the LLVM kernel would satisfy the check below even if the
  // evaluator wrote nothing at all.
  cv.assign(cv.size(), 0.0f);

  SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N});
  eval.call_raw(args);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N_value)) {
      ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);
    }
  }
}
1750
// Builds d[m] = a[m] * b[m] + c[m] for an explicit target ("i686-elf" triple,
// "i386" cpu) via LLVMCodeGenBuilder and inspects the generated assembly
// text: it must contain "fadds" followed by "fmuls" and must not contain
// "vfmadd".
TEST(LLVM, CustomTarget) {
  constexpr int M = 16;
  BufHandle a("a", {M}, kFloat);
  BufHandle b("b", {M}, kFloat);
  BufHandle c("c", {M}, kFloat);
  auto mul_add = [&](const VarHandle& m) {
    return a.load(m) * b.load(m) + c.load(m);
  };
  Tensor d = Compute("d", {M}, mul_add);

  LoopNest nest({d});
  nest.prepareForCodegen();
  auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d})
                .triple("i686-elf")
                .cpu("i386")
                .build();

  std::ostringstream asm_text;
  asm_text << cg->getCodeText("asm");

  torch::jit::testing::FileCheck checker;
  checker.check("fadds")
      ->check("fmuls")
      ->check_not("vfmadd")
      ->run(asm_text.str());
}
1773
// Checks the kernel function name reported by LLVMCodeGen: non-empty when no
// name is requested, and equal to the requested name otherwise. The kernel
// is only compiled, never run, so no host buffers are needed (the unused
// a_buffer/b_buffer locals were removed).
TEST(LLVM, CodeGenKernelFuncName) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  // B[0] = A[0]
  auto store = b.store({0}, a.load(0));

  {
    LLVMCodeGen cg(store, {a, b});
    // Check that the kernel function name used by LLVMCodeGen
    // is not empty.
    ASSERT_NE(cg.kernel_func_name(), "");
  }

  {
    LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func");
    // Check that the kernel function name used by LLVMCodeGen
    // is the one that was given above.
    ASSERT_EQ(cg.kernel_func_name(), "new_func");
  }
}
1795
1796 } // namespace jit
1797 } // namespace torch
1798
1799 #endif // TORCH_ENABLE_LLVM
1800