xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_memplanning.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <test/cpp/tensorexpr/test_base.h>
3 
4 #include <c10/util/irange.h>
5 #include <test/cpp/tensorexpr/padded_buffer.h>
6 #include <torch/csrc/jit/tensorexpr/ir.h>
7 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
8 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
9 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
10 #include <torch/csrc/jit/tensorexpr/loopnest.h>
11 #include <torch/csrc/jit/tensorexpr/tensor.h>
12 
13 namespace torch {
14 namespace jit {
15 
16 using namespace torch::jit::tensorexpr;
17 
18 extern void checkIR(StmtPtr s, const std::string& pattern);
19 
TEST(BufLiveRange,SingleRangeLine)20 TEST(BufLiveRange, SingleRangeLine) {
21   VarHandle i("i", kInt), j("j", kInt);
22   BufHandle a("a", {32}, kFloat);
23   BufHandle b("b", {32, 32}, kFloat);
24 
25   // Construct Stmt:
26   // {
27   //   for (int i = 0; i < 32; i++) {
28   //     a[i] = 0;
29   //     for (int j = 0; j < 32; j++) {
30   //       a[i] = (a[i]) + (b[i, j]);
31   //     }
32   //   }
33   // }
34 
35   StorePtr aInit = Store::make(a, {i}, 0);
36   ExprHandle reduce = a.load({i}) + b.load({i, j});
37   StorePtr aReduce = Store::make(a, {i}, reduce);
38   StmtPtr loop =
39       For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)}));
40 
41   StmtPtr stmt = Block::make({loop});
42 
43   auto range = BufLiveRange::liveRange(stmt, a.node());
44   ASSERT_TRUE(std::get<0>(range) == 0);
45   ASSERT_TRUE(std::get<1>(range) == 0);
46 }
47 
TEST(BufLiveRange,MulRangeLine)48 TEST(BufLiveRange, MulRangeLine) {
49   VarHandle i("i", kInt);
50   BufHandle a("a", {32}, kFloat);
51   BufHandle b("b", {32}, kFloat);
52 
53   // Construct Stmt:
54   // {
55   //   for (int i = 0; i < 32; i++) {
56   //     if (i<10 ? 1 : 0) {
57   //       a[i] = i + i;
58   //       b[i] = i * i;
59   //     }
60   //   }
61   //   for (int i = 0; i < 32; i++) {
62   //     if (i>10 ? 1 : 0) {
63   //       a[i] = i * i;
64   //       b[i] = i + i;
65   //     }
66   //   }
67   // }
68 
69   StorePtr aStore_1 = Store::make(a, {i}, i + i);
70   StorePtr bStore_1 = Store::make(b, {i}, i * i);
71   StmtPtr loop_1 = For::make(
72       i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL));
73 
74   StorePtr aStore_2 = Store::make(a, {i}, i * i);
75   StorePtr bStore_2 = Store::make(b, {i}, i + i);
76   StmtPtr loop_2 = For::make(
77       i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL));
78 
79   StmtPtr stmt = Block::make({loop_1, loop_2});
80 
81   auto range_a = BufLiveRange::liveRange(stmt, a.node());
82   ASSERT_TRUE(std::get<0>(range_a) == 0);
83   ASSERT_TRUE(std::get<1>(range_a) == 1);
84 
85   auto range_b = BufLiveRange::liveRange(stmt, b.node());
86   ASSERT_TRUE(std::get<0>(range_b) == 0);
87   ASSERT_TRUE(std::get<1>(range_b) == 1);
88 }
89 
TEST(MemPlanning,MemReuseWithTypeCast)90 TEST(MemPlanning, MemReuseWithTypeCast) {
91   int M = 4;
92   int N = 4;
93   int K = 4;
94 
95   BufHandle AP("A", {M, K}, kFloat);
96   BufHandle BP("B", {K, N}, kFloat);
97 
98   Tensor CT = Reduce(
99       "gemm",
100       {M, N},
101       Sum(),
102       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
103         return AP.load(m, k) * BP.load(k, n);
104       },
105       {K});
106   Tensor DT =
107       Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
108         return CompareSelect::make(
109             CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
110       });
111   Tensor ET =
112       Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
113         return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
114       });
115   Tensor FT =
116       Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
117         return ET.load(m, n);
118       });
119   StmtPtr stmt =
120       tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
121 
122   // Constructed stmt:
123   // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
124   // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are
125   // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E'
126   // with typecasting.
127   //{
128   //  for (int i = 0; i < 4; i++) {
129   //    for (int i_1 = 0; i_1 < 4; i_1++) {
130   //      gemm[i, i_1] = float(0);
131   //      for (int i_2 = 0; i_2 < 4; i_2++) {
132   //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
133   //        i_1]), reduce_args={i_2});
134   //      }
135   //    }
136   //  }
137   //  for (int i_3 = 0; i_3 < 4; i_3++) {
138   //    for (int i_4 = 0; i_4 < 4; i_4++) {
139   //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]);
140   //    }
141   //  }
142   //  for (int i_5 = 0; i_5 < 4; i_5++) {
143   //    for (int i_6 = 0; i_6 < 4; i_6++) {
144   //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
145   //    }
146   //  }
147   //  for (int i_7 = 0; i_7 < 4; i_7++) {
148   //    for (int i_8 = 0; i_8 < 4; i_8++) {
149   //      F[i_7, i_8] = E[i_7, i_8];
150   //    }
151   //  }
152   //}
153 
154   LoopNest l(stmt, {FT.buf()});
155   l.prepareForCodegen();
156   SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});
157 
158   checkIR(cg.stmt(), R"IR(
159 # CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
160 # CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
161 # CHECK: Alias(E,gemm);
162 # CHECK: Free(relu);
163 # CHECK: Free(gemm))IR");
164 
165   PaddedBuffer<float> a_v(M, K, "a");
166   PaddedBuffer<float> b_v(K, N, "b");
167   PaddedBuffer<uint8_t> o1(M, N, "e_before");
168   PaddedBuffer<uint8_t> o2(M, N, "e_after");
169 
170   for (const auto m : c10::irange(M)) {
171     for (const auto k : c10::irange(K)) {
172       a_v(m, k) = at::randn({1}).item().to<float>();
173     }
174   }
175 
176   for (const auto k : c10::irange(K)) {
177     for (const auto n : c10::irange(N)) {
178       b_v(k, n) = at::randn({1}).item().to<float>();
179     }
180   }
181 
182   cg.call({a_v, b_v, o1});
183 
184 #ifdef TORCH_ENABLE_LLVM
185   LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});
186 
187   checkIR(cg_llvm.stmt(), R"IR(
188 # CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
189 # CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
190 # CHECK: Alias(E,gemm);
191 # CHECK: Free(relu);
192 # CHECK: Free(gemm))IR");
193 
194   cg_llvm.call({a_v, b_v, o2});
195 
196   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
197   ExpectAllNear(o1, o2, 1e-5);
198 #endif
199 }
200 
TEST(MemPlanning,NoMemReuseForLargerType)201 TEST(MemPlanning, NoMemReuseForLargerType) {
202   int M = 4;
203   int N = 4;
204   int K = 4;
205 
206   BufHandle AP("A", {M, K}, kShort);
207   BufHandle BP("B", {K, N}, kShort);
208 
209   Tensor CT = Reduce(
210       "gemm",
211       {M, N},
212       Sum(),
213       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
214         return AP.load(m, k) * BP.load(k, n);
215       },
216       {K});
217   auto zero = Cast::make(CT.buf()->dtype(), 0);
218   Tensor DT =
219       Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
220         return CompareSelect::make(
221             CT.load(m, n), zero, zero, CT.load(m, n), kLT);
222       });
223   Tensor ET =
224       Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
225         return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n));
226       });
227   Tensor FT =
228       Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
229         return ET.load(m, n);
230       });
231   StmtPtr stmt =
232       tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
233 
234   // Constructed stmt:
235   // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
236   // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are
237   // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for
238   // 'E'.
239   //{
240   //  for (int i = 0; i < 4; i++) {
241   //    for (int i_1 = 0; i_1 < 4; i_1++) {
242   //      gemm[i, i_1] = int16_t(0);
243   //      for (int i_2 = 0; i_2 < 4; i_2++) {
244   //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
245   //        i_1]), reduce_args={i_2});
246   //      }
247   //    }
248   //  }
249   //  for (int i_3 = 0; i_3 < 4; i_3++) {
250   //    for (int i_4 = 0; i_4 < 4; i_4++) {
251   //      relu[i_3, i_4] = (gemm[i_3, i_4])<int16_t(0) ? int16_t(0) : (gemm[i_3,
252   //      i_4]);
253   //    }
254   //  }
255   //  for (int i_5 = 0; i_5 < 4; i_5++) {
256   //    for (int i_6 = 0; i_6 < 4; i_6++) {
257   //      E[i_5, i_6] = float((relu[i_5, i_6]) + (relu[i_5, i_6]));
258   //    }
259   //  }
260   //  for (int i_7 = 0; i_7 < 4; i_7++) {
261   //    for (int i_8 = 0; i_8 < 4; i_8++) {
262   //      F[i_7, i_8] = E[i_7, i_8];
263   //    }
264   //  }
265   //}
266 
267   LoopNest l(stmt, {FT.buf()});
268   l.prepareForCodegen();
269   SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT.buf()});
270 
271   checkIR(cg.stmt(), R"IR(
272 # CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
273 # CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
274 # CHECK: Allocate(E); // dtype=float, dims=[4, 4]
275 # CHECK: Free(E);
276 # CHECK: Free(relu);
277 # CHECK: Free(gemm))IR");
278 
279   PaddedBuffer<short> a_v(M, K, "a");
280   PaddedBuffer<short> b_v(K, N, "b");
281   PaddedBuffer<float> o1(M, N, "e_before");
282   PaddedBuffer<float> o2(M, N, "e_after");
283 
284   for (const auto m : c10::irange(M)) {
285     for (const auto k : c10::irange(K)) {
286       a_v(m, k) = at::randn({1}).item().to<float>();
287     }
288   }
289 
290   for (const auto k : c10::irange(K)) {
291     for (const auto n : c10::irange(N)) {
292       b_v(k, n) = at::randn({1}).item().to<float>();
293     }
294   }
295 
296   cg.call({a_v, b_v, o1});
297 
298 #ifdef TORCH_ENABLE_LLVM
299   LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});
300 
301   checkIR(cg_llvm.stmt(), R"IR(
302 # CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
303 # CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
304 # CHECK: Allocate(E); // dtype=float, dims=[4, 4]
305 # CHECK: Free(E);
306 # CHECK: Free(relu);
307 # CHECK: Free(gemm))IR");
308 
309   cg_llvm.call({a_v, b_v, o2});
310 
311   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
312   ExpectAllNear(o1, o2, 1e-5);
313 #endif
314 }
315 
TEST(MemPlanning,SameBufSizeMemReuse)316 TEST(MemPlanning, SameBufSizeMemReuse) {
317   int M = 1024;
318   int N = 1024;
319   int K = 2048;
320 
321   BufHandle AP("A", {M, K}, kFloat);
322   BufHandle BP("B", {K, N}, kFloat);
323 
324   Tensor CT = Reduce(
325       "gemm",
326       {M, N},
327       Sum(),
328       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
329         return AP.load(m, k) * BP.load(k, n);
330       },
331       {K});
332   Tensor DT =
333       Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
334         auto zero = Cast::make(CT.buf()->dtype(), 0);
335         return CompareSelect::make(
336             CT.load(m, n), zero, zero, CT.load(m, n), kLT);
337       });
338   Tensor ET =
339       Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
340         return DT.load(m, n) + DT.load(m, n);
341       });
342   Tensor FT =
343       Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
344         return ET.load(m, n) * ET.load(m, n);
345       });
346   auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
347 
348   // Constructed stmt:
349   // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
350   // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm'
351   // for 'add'.
352   //{
353   //  for (int M = 0; M < 1024; M++) {
354   //    for (int N = 0; N < 1024; N++) {
355   //      gemm[M, N] = float(0);
356   //      for (int K = 0; K < 2048; K++) {
357   //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
358   //        reduce_args={K});
359   //      }
360   //    }
361   //  }
362   //  for (int M_1 = 0; M_1 < 1024; M_1++) {
363   //    for (int N_1 = 0; N_1 < 1024; N_1++) {
364   //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
365   //      N_1]);
366   //    }
367   //  }
368   //  for (int M_2 = 0; M_2 < 1024; M_2++) {
369   //    for (int N_2 = 0; N_2 < 1024; N_2++) {
370   //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
371   //    }
372   //  }
373   //  for (int M_3 = 0; M_3 < 1024; M_3++) {
374   //    for (int N_3 = 0; N_3 < 1024; N_3++) {
375   //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
376   //    }
377   //  }
378   //}
379 
380   SimpleIREvaluator cg(stmt, {AP, BP, FT});
381 
382   checkIR(cg.stmt(), R"IR(
383 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
384 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
385 # CHECK: Alias(add,gemm);
386 # CHECK: Free(relu);
387 # CHECK: Free(gemm))IR");
388 
389 #ifdef TORCH_ENABLE_LLVM
390   LoopNest loop(Stmt::clone(stmt), {FT.buf()});
391   loop.prepareForCodegen();
392   LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});
393 
394   checkIR(cg_llvm.stmt(), R"IR(
395 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
396 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
397 # CHECK: Alias(add,gemm);
398 # CHECK: Free(relu);
399 # CHECK: Free(gemm))IR");
400 #endif
401 }
402 
TEST(MemPlanning,SameBufSizeMultiMemReuses)403 TEST(MemPlanning, SameBufSizeMultiMemReuses) {
404   int M = 1024;
405   int N = 1024;
406   int K = 2048;
407 
408   BufHandle AP("A", {M, K}, kFloat);
409   BufHandle BP("B", {K, N}, kFloat);
410 
411   Tensor CT = Reduce(
412       "gemm",
413       {M, N},
414       Sum(),
415       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
416         return AP.load(m, k) * BP.load(k, n);
417       },
418       {K});
419   Tensor DT =
420       Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
421         auto zero = Cast::make(CT.buf()->dtype(), 0);
422         return CompareSelect::make(
423             CT.load(m, n), zero, zero, CT.load(m, n), kLT);
424       });
425   Tensor ET =
426       Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
427         return DT.load(m, n) + DT.load(m, n);
428       });
429   Tensor FT =
430       Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
431         return ET.load(m, n) * ET.load(m, n);
432       });
433   Tensor GT =
434       Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
435         return FT.load(m, n) - ET.load(m, n);
436       });
437 
438   auto stmt =
439       Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()});
440 
441   // Constructed stmt:
442   // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
443   // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same
444   // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul'
445   //{
446   //  for (int M = 0; M < 1024; M++) {
447   //    for (int N = 0; N < 1024; N++) {
448   //      gemm[M, N] = float(0);
449   //      for (int K = 0; K < 2048; K++) {
450   //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
451   //        reduce_args={K});
452   //      }
453   //    }
454   //  }
455   //  for (int M_1 = 0; M_1 < 1024; M_1++) {
456   //    for (int N_1 = 0; N_1 < 1024; N_1++) {
457   //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
458   //      N_1]);
459   //    }
460   //  }
461   //  for (int M_2 = 0; M_2 < 1024; M_2++) {
462   //    for (int N_2 = 0; N_2 < 1024; N_2++) {
463   //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
464   //    }
465   //  }
466   //  for (int M_3 = 0; M_3 < 1024; M_3++) {
467   //    for (int N_3 = 0; N_3 < 1024; N_3++) {
468   //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
469   //    }
470   //  }
471   //  for (int M_4 = 0; M_4 < 1024; M_4++) {
472   //    for (int N_4 = 0; N_4 < 1024; N_4++) {
473   //      sub[M_4, N_4] = (mul[M_4, N_4]) - (add[M_4, N_4]);
474   //    }
475   //  }
476   //}
477 
478   SimpleIREvaluator cg(stmt, {AP, BP, GT});
479 
480   checkIR(cg.stmt(), R"IR(
481 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
482 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
483 # CHECK: Alias(add,gemm);
484 # CHECK: Alias(mul,relu);
485 # CHECK: Free(relu);
486 # CHECK: Free(gemm))IR");
487 
488 #ifdef TORCH_ENABLE_LLVM
489   LoopNest loop(Stmt::clone(stmt), {FT.buf()});
490   loop.prepareForCodegen();
491   LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});
492 
493   checkIR(cg_llvm.stmt(), R"IR(
494 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
495 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
496 # CHECK: Alias(add,gemm);
497 # CHECK: Alias(mul,relu);
498 # CHECK: Free(relu);
499 # CHECK: Free(gemm))IR");
500 #endif
501 }
502 
TEST(MemPlanning,SameBufSizeMultiMemReusesOfOneBuf)503 TEST(MemPlanning, SameBufSizeMultiMemReusesOfOneBuf) {
504   int M = 1024;
505   int N = 1024;
506   int K = 2048;
507 
508   BufHandle AP("A", {M, K}, kFloat);
509   BufHandle BP("B", {K, N}, kFloat);
510 
511   Tensor CT = Reduce(
512       "gemm",
513       {M, N},
514       Sum(),
515       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
516         return AP.load(m, k) * BP.load(k, n);
517       },
518       {K});
519   Tensor DT =
520       Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
521         auto zero = Cast::make(CT.buf()->dtype(), 0);
522         return CompareSelect::make(
523             CT.load(m, n), zero, zero, CT.load(m, n), kLT);
524       });
525   Tensor ET =
526       Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
527         return DT.load(m, n) + DT.load(m, n);
528       });
529   Tensor FT =
530       Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
531         return ET.load(m, n) * ET.load(m, n);
532       });
533   Tensor GT =
534       Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
535         return FT.load(m, n) - 1;
536       });
537   Tensor HT =
538       Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
539         return GT.load(m, n) / 2;
540       });
541 
542   auto stmt = Block::make(
543       {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()});
544 
545   // Constructed stmt:
546   // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
547   // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and
548   // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for
549   // 'mul', and reuse 'gemm' for 'sub'.
550   //{
551   //  for (int M = 0; M < 1024; M++) {
552   //    for (int N = 0; N < 1024; N++) {
553   //      gemm[M, N] = float(0);
554   //      for (int K = 0; K < 2048; K++) {
555   //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
556   //        reduce_args={K});
557   //      }
558   //    }
559   //  }
560   //  for (int M_1 = 0; M_1 < 1024; M_1++) {
561   //    for (int N_1 = 0; N_1 < 1024; N_1++) {
562   //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
563   //      N_1]);
564   //    }
565   //  }
566   //  for (int M_2 = 0; M_2 < 1024; M_2++) {
567   //    for (int N_2 = 0; N_2 < 1024; N_2++) {
568   //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
569   //    }
570   //  }
571   //  for (int M_3 = 0; M_3 < 1024; M_3++) {
572   //    for (int N_3 = 0; N_3 < 1024; N_3++) {
573   //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
574   //    }
575   //  }
576   //  for (int M_4 = 0; M_4 < 1024; M_4++) {
577   //    for (int N_4 = 0; N_4 < 1024; N_4++) {
578   //      sub[M_4, N_4] = (mul[M_4, N_4]) - float(1);
579   //    }
580   //  }
581   //  for (int M_5 = 0; M_5 < 1024; M_5++) {
582   //    for (int N_5 = 0; N_5 < 1024; N_5++) {
583   //      div[M_5, N_5] = (sub[M_5, N_5]) / float(2);
584   //    }
585   //  }
586   //}
587 
588   SimpleIREvaluator cg(stmt, {AP, BP, HT});
589 
590   checkIR(cg.stmt(), R"IR(
591 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
592 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
593 # CHECK: Alias(add,gemm);
594 # CHECK: Alias(mul,relu);
595 # CHECK: Alias(sub,gemm);
596 # CHECK: Free(relu);
597 # CHECK: Free(gemm))IR");
598 
599 #ifdef TORCH_ENABLE_LLVM
600   LoopNest loop(Stmt::clone(stmt), {FT.buf()});
601   loop.prepareForCodegen();
602   LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});
603 
604   checkIR(cg_llvm.stmt(), R"IR(
605 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
606 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
607 # CHECK: Alias(add,gemm);
608 # CHECK: Alias(mul,relu);
609 # CHECK: Alias(sub,gemm);
610 # CHECK: Free(relu);
611 # CHECK: Free(gemm))IR");
612 #endif
613 }
614 
TEST(MemPlanning,SmallerBufSizeNonMemReuse)615 TEST(MemPlanning, SmallerBufSizeNonMemReuse) {
616   int M = 1024;
617   int N = 1024;
618   int K = 2048;
619 
620   BufHandle AP("A", {M, K}, kFloat);
621   BufHandle BP("B", {K, N}, kFloat);
622 
623   Tensor CT = Reduce(
624       "gemm",
625       {M, N},
626       Sum(),
627       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
628         return AP.load(m, k) * BP.load(k, n);
629       },
630       {K});
631   Tensor DT =
632       Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
633         auto zero = Cast::make(CT.buf()->dtype(), 0);
634         return CompareSelect::make(
635             CT.load(m, n), zero, zero, CT.load(m, n), kLT);
636       });
637   Tensor ET = Compute(
638       "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
639         return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
640       });
641   Tensor FT = Compute(
642       "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
643         return ET.load(fm, fn) * ET.load(fm, fn);
644       });
645   auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
646 
647   // Constructed stmt:
648   // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
649   // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of
650   // buffer 'gemm' is smaller.
651   //{
652   //  for (int M = 0; M < 1024; M++) {
653   //    for (int N = 0; N < 1024; N++) {
654   //      gemm[M, N] = float(0);
655   //      for (int K = 0; K < 2048; K++) {
656   //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
657   //        reduce_args={K});
658   //      }
659   //    }
660   //  }
661   //  for (int M_1 = 0; M_1 < 1024; M_1++) {
662   //    for (int N_1 = 0; N_1 < 1024; N_1++) {
663   //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
664   //      N_1]);
665   //    }
666   //  }
667   //  for (int EM = 0; EM < 2048; EM++) {
668   //    for (int EN = 0; EN < 2048; EN++) {
669   //      add[EM, EN] = (relu[EM / 2, EN / 2]) + (relu[EM / 2, EN / 2]);
670   //    }
671   //  }
672   //  for (int FM = 0; FM < 2048; FM++) {
673   //    for (int FN = 0; FN < 2048; FN++) {
674   //      mul[FM, FN] = (add[FM, FN]) * (add[FM, FN]);
675   //    }
676   //  }
677   //}
678   //
679 
680   SimpleIREvaluator cg(stmt, {AP, BP, FT});
681 
682   checkIR(cg.stmt(), R"IR(
683 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
684 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
685 # CHECK-NOT: Alias(add,gemm);
686 # CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
687 # CHECK: Free(add);
688 # CHECK: Free(relu);
689 # CHECK: Free(gemm))IR");
690 
691 #ifdef TORCH_ENABLE_LLVM
692   LoopNest loop(Stmt::clone(stmt), {FT.buf()});
693   loop.prepareForCodegen();
694   LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});
695 
696   checkIR(cg_llvm.stmt(), R"IR(
697 # CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
698 # CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
699 # CHECK-NOT: Alias(add,gemm);
700 # CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
701 # CHECK: Free(add);
702 # CHECK: Free(relu);
703 # CHECK: Free(gemm))IR");
704 #endif
705 }
706 
707 } // namespace jit
708 } // namespace torch
709