#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

static void verifyConstBounds(
    const TensorAccessBoundsInfo& access_info,
    const std::vector<std::pair<int, int>>& ref) {
  size_t ndim = ref.size();
  ASSERT_EQ(access_info.start.size(), ndim);
  ASSERT_EQ(access_info.stop.size(), ndim);
  for (const auto i : c10::irange(ndim)) {
    if (ref[i].first >= 0) { // Negative values are used to skip the check
      ASSERT_TRUE(access_info.start[i]->isConstant());
      int start_i = immediateAs<int>(access_info.start[i]);
      ASSERT_EQ(start_i, ref[i].first);
    }
    if (ref[i].second >= 0) {
      ASSERT_TRUE(access_info.stop[i]->isConstant());
      int stop_i = immediateAs<int>(access_info.stop[i]);
      ASSERT_EQ(stop_i, ref[i].second);
    }
  }
}

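// A note on the helper above: inferBounds maps each buffer to a vector of
// TensorAccessBoundsInfo entries, and the expected per-dimension
// [start, stop] values are then checked with calls like
//   verifyConstBounds(bounds_info.at(buf)[0], {{0, 99}});
// where a negative value in a reference pair skips that half of the check.
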
TEST(BoundsInference, _1) {
  // Verify that bounds inference works for the following example:
  // for i in 0..100:
  //   b[i] = a[i]
  // For this loop bounds inference should yield the following:
  // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}}
  ExprHandle n(100);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}});

  ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}});
}

TEST(BoundsInference, _2) {
  // Verify that bounds inference works for the following example:
  // for i in 0..n:
  //   b[i] = a[i]
  // For this loop bounds inference should yield the following:
  // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}}
  VarHandle n("n", kInt);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
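  // n is symbolic here, so the stop bound (n - 1) is not a constant; the -1
  // entries below skip the stop checks (see verifyConstBounds above).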
  verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}});

  ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}});
}

TEST(BoundsInference, _3) {
  // Verify that bounds inference works for the following example:
  // for i in 0..100:
  //   b[i] = a[i] * a[i+10]
  // For this loop bounds inference should yield the following:
  // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}}
  ExprHandle n(100);
  BufHandle a("a", {n + 10}, kFloat);
  Tensor b = Compute(
      "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}});

  ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}});
}

TEST(BoundsInference, _4) {
  // Verify that bounds inference works for the following example:
  //
  // for y in 0..200:
  //   for x in 0..320:
  //     b[y,x] = x*y
  // for y in 0..200:
  //   for x in 0..320:
  //     c[y,x] = a[y,x] * b[y,x]
  ExprHandle W(320);
  ExprHandle H(200);
  BufHandle a("a", {H, W}, kFloat);
  Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
    return x * y;
  });
  Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
    return a.load(y, x) * b.load(y, x);
  });
  LoopNest l({c});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
  StmtPtr body = l.getLoopBodyFor(c);
  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}});
  }
  {
    // Infer bounds on the inner loop scope
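    // Within the inner loop, the outer index y is a free variable, so the
    // bounds for the first dimension are not constant and are skipped with -1.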
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}});
  }
  {
    // Infer bounds on the inner loop body's scope
    auto bounds_info = inferBounds(body);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}});
  }
}

TEST(BoundsInference, _5) {
  // Verify that bounds inference works for the following example:
  // for i in 0..100:
  //   b[i] = a[i]
  //
  // ==> split ==>
  //
  // for i_outer in 0..100/16:
  //   for i_inner in 0..16:
  //     b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner]
  // for i_tail in 0..100%16:
  //   b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16];
  ExprHandle n(100);
  BufHandle a("a", {n}, kFloat);
  Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); });
  LoopNest l({b});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr tail;
  std::vector<ForPtr> loops = l.getLoopStmtsFor(b);
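  // splitWithTail splits the loop in place: loops[0] ends up as the outer loop
  // of the main part, and the newly created inner and tail loops are returned
  // through the out-parameters.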
  LoopNest::splitWithTail(loops[0], 16, &inner, &tail);
  ForPtr outer = loops[0];

  {
    // Verify inferred bounds for the outer loop
    auto bounds_info = inferBounds(outer);
    ASSERT_EQ(bounds_info.size(), 2);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}});
  }
  {
    // Verify inferred bounds for the tail loop
    auto bounds_info = inferBounds(tail);
    ASSERT_EQ(bounds_info.size(), 2);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}});
  }
}

TEST(BoundsInference, _6) {
  // Verify that bounds inference works for the following example:
  //
  // for y in 0..200:
  //   for x in 0..320:
  //     b[y,x] = x*y
  // for y in 0..20:
  //   for x in 0..32:
  //     c[y,x] = a[y+100,x+100] * b[y*2,x*5]
  ExprHandle W(320);
  ExprHandle H(200);
  ExprHandle CW(32);
  ExprHandle CH(20);
  BufHandle a("a", {H, W}, kFloat);
  Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) {
    return x * y;
  });
  Tensor c =
      Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) {
        return a.load(y + 100, x + 100) * b.load(y * 2, x * 5);
      });
  LoopNest l({c});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
  StmtPtr body = l.getLoopBodyFor(c);
  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 3);
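    // With y in [0, 20) and x in [0, 32): a is read at [y + 100, x + 100],
    // i.e. rows 100..119 and columns 100..131, while b is read at
    // [y * 2, x * 5], i.e. rows 0..38 and columns 0..155.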

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}});
  }
  {
    // Infer bounds on the inner loop body's scope
    auto bounds_info = inferBounds(body);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}});
  }
}

TEST(BoundsInference, Adjacent) {
  ExprHandle H(6);
  BufHandle a("a", {20}, kFloat);
  Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); });
  Tensor c =
      Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); });
  LoopNest l({b, c});
  std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());

  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0:5], writes to b[0:5]
    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0+6:5+6], writes to c[0:5]
    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}});
  }
  {
    // Infer bounds on the high level program.
    auto bounds_info = inferBounds(l.root_stmt());
    ASSERT_EQ(bounds_info.size(), 3);

    // Should be the union of the two bounds above, but this time the bounds
    // of a can be merged.
    ASSERT_EQ(bounds_info.at(a.node()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}});

    ASSERT_EQ(bounds_info.at(b.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}});

    ASSERT_EQ(bounds_info.at(c.buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}});
  }
}

TEST(BoundsInference, MultipleTopLoopLoad) {
  BufHandle a("a", {100}, kFloat);
  Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); });
  Tensor c =
      Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); });
  Tensor d =
      Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); });
  LoopNest l({b, c, d});

  auto bounds_info = inferBounds(l.root_stmt());

  ASSERT_EQ(bounds_info.size(), 4);

  // a only read.
  {
    auto bounds = bounds_info[a.node()];
    ASSERT_EQ(bounds.size(), 1);
    // One dimension.
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kLoad);
    // Bounds:
    // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b).
    // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 =
    //       96 + 2 - 1 (d).
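    // Concretely: b reads a[0..63], c reads a[10..41], d reads a[2..97], and
    // the union of those ranges is [0, 97].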
    verifyConstBounds(bound, {{0, 97}});
  }

  // b, c, d only written.
  {
    auto bounds = bounds_info[b.buf()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // Just the loop extents for b.
    verifyConstBounds(bound, {{0, 63}});
  }
  {
    auto bounds = bounds_info[c.buf()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // Just the loop extents for c.
    verifyConstBounds(bound, {{0, 31}});
  }
  {
    auto bounds = bounds_info[d.buf()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // Just the loop extents for d.
    verifyConstBounds(bound, {{0, 95}});
  }
}

TEST(BoundsInference, MultipleTopLoopStore) {
  BufHandle a("a", {100}, kFloat);
  BufHandle b("b", {100}, kFloat);
  BufHandle c("c", {100}, kFloat);
  BufHandle d("d", {100}, kFloat);
  VarHandle x("x", kInt);

  // Same as above, but the offsets are now on the Store. This can't be
  // expressed through the Compute API without transforms we don't have yet.
  StmtPtr stmt = Block::make(
      {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))),
       For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))),
       For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))});

  auto bounds_info = inferBounds(stmt);

  ASSERT_EQ(bounds_info.size(), 4);

  // a only read.
  {
    auto bounds = bounds_info[a.node()];
    ASSERT_EQ(bounds.size(), 1);
    // One dimension.
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kLoad);
    // Bounds: there are no offsets, so this is just the max loop bounds.
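    // Concretely: the three loops read a[0..63], a[0..31], and a[0..95], so
    // the union is [0, 95].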
    verifyConstBounds(bound, {{0, 95}});
  }

  // b, c, d only written.
  {
    auto bounds = bounds_info[b.node()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // This should be equivalent to {offset, offset + extent - 1} for the b
    // loop. The b loop has no offset, so it's just the loop extents.
    verifyConstBounds(bound, {{0, 63}});
  }
  {
    auto bounds = bounds_info[c.node()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // This should be equivalent to {offset, offset + extent - 1} for the c
    // loop. The offset is 10 and the extent is 32, so the stop is 41.
    verifyConstBounds(bound, {{10, 41}});
  }
  {
    auto bounds = bounds_info[d.node()];
    ASSERT_EQ(bounds.size(), 1);
    auto bound = bounds[0];
    ASSERT_EQ(bound.kind, TensorAccessKind::kStore);
    // This should be equivalent to {offset, offset + extent - 1} for the d
    // loop. The offset is 2 and the extent is 96, so the stop is 97.
    verifyConstBounds(bound, {{2, 97}});
  }
}

TEST(BoundsInference, CacheReads) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 30, j + 3);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  LoopNest l({B, C});
  auto bounds_info_before = inferBounds(l.root_stmt());

  StmtPtr j_loop = l.getLoopStmtsFor(B)[1];
  LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
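  // cacheAccesses is expected to introduce a temporary buffer "A_local" that
  // is filled from A and then read in place of A inside the j loop, so the new
  // bounds info should gain an A_local entry with both a store (the fill) and
  // a load (the use), while all pre-existing entries stay unchanged.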

  auto bounds_info_after = inferBounds(l.root_stmt());

  // cacheAccesses should not change existing bounds, but it should add a new
  // entry for the cache.
  for (auto& pair : bounds_info_after) {
    auto beforeIt = bounds_info_before.find(pair.first);
    if (beforeIt != bounds_info_before.end()) {
      // Same number of TensorAccessBoundsInfos.
      ASSERT_EQ(pair.second.size(), beforeIt->second.size());

      for (const auto i : c10::irange(pair.second.size())) {
        TensorAccessBoundsInfo& after = pair.second[i];
        TensorAccessBoundsInfo& before = beforeIt->second[i];
        // Same number of dimensions.
        ASSERT_EQ(before.start.size(), after.start.size());

        // Bounds are equal.
        for (const auto j : c10::irange(before.start.size())) {
          ASSERT_TRUE(exprEquals(before.start[j], after.start[j]));
          ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j]));
        }
      }
    } else {
      // This should be the cache.
      ASSERT_EQ(pair.first->name_hint(), "A_local");
      // Should have both a load and a store.
      ASSERT_EQ(pair.second.size(), 2);
      TensorAccessBoundsInfo& first = pair.second[0];
      TensorAccessBoundsInfo& second = pair.second[1];

      ASSERT_NE(first.kind, second.kind);
      // 2 dimensions.
      ASSERT_EQ(first.start.size(), second.start.size());
      ASSERT_EQ(first.start.size(), 2);

      // Bounds for load and store are equal.
      for (const auto j : c10::irange(first.start.size())) {
        ASSERT_TRUE(exprEquals(first.start[j], second.start[j]));
        ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j]));
      }
    }
  }
}

TEST(BoundsInference, Flattened) {
  Tensor b = Compute(
      "b",
      {3, 4, 5},
      [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) {
        return x * y + z;
      });

  LoopNest l({b});
  // Flatten indices.
  l.prepareForCodegen();
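  // After flattening, b should be indexed through a single dimension of extent
  // 3 * 4 * 5 = 60.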
  auto bounds_info = inferBounds(l.root_stmt());

  // There's only one buffer.
  ASSERT_EQ(bounds_info.size(), 1);
  auto& TABI = bounds_info[b.buf()][0];
  ASSERT_EQ(TABI.kind, TensorAccessKind::kStore);
  // Flattened bounds should have a single dimension.
  ASSERT_EQ(TABI.start.size(), 1);
  ASSERT_EQ(TABI.stop.size(), 1);

  // Bounds should be 0 -> (3*4*5)-1.
  ASSERT_TRUE(exprEquals(TABI.start[0], alloc<IntImm>(0)));
  ASSERT_TRUE(exprEquals(TABI.stop[0], alloc<IntImm>(3 * 4 * 5 - 1)));
}

TEST(BoundsInference, GetPotentialHazards) {
  BufHandle a("A", {5}, kInt);
  BufHandle b("B", {5}, kInt);
  BufHandle c("C", {5}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  using namespace analysis;

  {
    /*
     * A[0] = B[0];
     * B[0] = 3;      WAR on B
     * A[0] = B[0];   WAW on A, RAW on B
     * C[0] = 5;
     */
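    // For reference: RAW = the later statement reads what the earlier one
    // wrote, WAR = it writes what the earlier one read, WAW = both write the
    // same location.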

    StorePtr store1 = Store::make(a, {0}, Load::make(b, {0}));
    StorePtr store2 = Store::make(b, {0}, 3);
    StorePtr store3 = Store::make(a, {0}, Load::make(b, {0}));
    StorePtr store4 = Store::make(c, {0}, 5);
    StmtPtr stmt = Block::make({store1, store2, store3, store4});

    MemDependencyChecker analyzer;
    stmt->accept(&analyzer);

    ASSERT_EQ(
        HazardKind::WriteAfterRead,
        getPotentialHazards(analyzer, store1, store2));

    ASSERT_EQ(
        HazardKind::ReadAfterWrite,
        getPotentialHazards(analyzer, store2, store3));

    ASSERT_EQ(
        HazardKind::WriteAfterWrite,
        getPotentialHazards(analyzer, store1, store3));

    // The fourth store has no dependencies.
    ASSERT_EQ(
        HazardKind::NoDependency,
        getPotentialHazards(analyzer, store1, store4));
    ASSERT_EQ(
        HazardKind::NoDependency,
        getPotentialHazards(analyzer, store2, store4));
    ASSERT_EQ(
        HazardKind::NoDependency,
        getPotentialHazards(analyzer, store3, store4));
  }
}

TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return (i + 1) * (j + 1);
  });

  LoopNest l({A, B});

  using namespace analysis;

  MemDependencyChecker analyzer;
  l.root_stmt()->accept(&analyzer);

  ForPtr loopRootA = l.getLoopStmtsFor(A)[0];
  ForPtr loopRootB = l.getLoopStmtsFor(B)[0];

  // No dependencies between loops.
  ASSERT_EQ(
      HazardKind::NoDependency,
      getPotentialHazards(analyzer, loopRootA, loopRootB));
}

TEST(BoundsInference, GetPotentialHazardsLoopCall) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i, j) + 5;
      });

  LoopNest l({A, B});

  using namespace analysis;

  MemDependencyChecker analyzer;
  l.root_stmt()->accept(&analyzer);

  ForPtr loopRootA = l.getLoopStmtsFor(A)[0];
  ForPtr loopRootB = l.getLoopStmtsFor(B)[0];

  ASSERT_EQ(
      HazardKind::ReadAfterWrite,
      getPotentialHazards(analyzer, loopRootA, loopRootB));
}

TEST(BoundsInference, GetPotentialHazardsLoopSplit) {
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });

  LoopNest l({A});
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner, tail;

  // Splitting with a tail by a factor that doesn't evenly divide the loop
  // extent creates a tail loop which also writes to A.
  ForPtr outer = l.getLoopStmtsFor(A)[0];
  // After splitting, `outer` is transformed into the outer loop of the split.
  LoopNest::splitWithTail(outer, 5, &inner, &tail);
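  // 64 is not a multiple of 5, so the main split should cover rows 0..59 and
  // the tail loop rows 60..63; both write to A, hence the WAW hazard below.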

  using namespace analysis;

  MemDependencyChecker analyzer;
  l.root_stmt()->accept(&analyzer);

  ASSERT_EQ(
      HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     A[k-1] = 20 * k;
  //   }
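  // The j loop writes A[10..99] and the k loop writes A[9..98]; the ranges
  // overlap without being identical, so the loops conflict in both orders.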
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     A[k] = 20 * k;
  //   }
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     B[k] = A[k];
  //   }
  BufHandle a_buf("A", {200}, kInt);
  BufHandle b_buf("B", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k})));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) {
  // Input IR:
  //   for (const auto j : c10::irange(10, 100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(10, 100)) {
  //     A[k+100] = 20 * k;
  //   }
  BufHandle a_buf("A", {200}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK =
      For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k)));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) {
  // Input IR:
  //   for (const auto i : c10::irange(20)) {
  //     for (const auto j : c10::irange(100)) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (const auto m : c10::irange(20)) {
  //     for (const auto n : c10::irange(50)) {
  //       A[m+1,n] = m + n * 100;
  //     }
  //   }
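  // The first nest writes rows 0..19 of A (all columns), the second writes
  // rows 1..20 for columns 0..49; rows 1..19 overlap, so every pairing below
  // conflicts.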
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 =
      Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM));
}

TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) {
  // Input IR:
  //   for (const auto i : c10::irange(20)) {
  //     for (const auto j : c10::irange(100)) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (const auto m : c10::irange(20)) {
  //     for (const auto n : c10::irange(50)) {
  //       A[m+20,n+100] = m + n * 100;
  //     }
  //   }
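  // The second nest writes rows 20..39 and columns 100..149, entirely outside
  // the first nest's rows 0..19 and columns 0..99, so nothing below conflicts.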
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 =
      Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM));
}

TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) {
  // Input IR:
  //   for (const auto i : c10::irange(20)) {
  //     for (const auto j : c10::irange(100)) {
  //       A[i,j] = i * j * 500;
  //     }
  //   }
  //   for (const auto m : c10::irange(20)) {
  //     for (const auto n : c10::irange(50)) {
  //       B[m,n] = m + n * 100;
  //     }
  //   }
  BufHandle a_buf("A", {20, 100}, kInt);
  BufHandle b_buf("B", {20, 50}, kInt);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500));
  auto forJ = For::make(j, 0, 100, storeA1);
  auto forI = For::make(i, 0, 20, forJ);
  auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100)));
  auto forN = For::make(n, 0, 50, storeA2);
  auto forM = For::make(m, 0, 20, forN);
  auto par = Block::make({forI, forM});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM));
}

TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) {
  // Input IR:
  //   for (const auto j : c10::irange(100)) {
  //     A[j] = 10 * j;
  //   }
  //   for (const auto k : c10::irange(100)) {
  //     B[k] = 20 * A[99-k];
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto forK = For::make(
      k,
      0,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto par = Block::make({forJ, forK});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) {
  // Input IR:
  //   for (const auto k : c10::irange(100)) {
  //     B[k] = 20 * A[99-k];
  //   }
  //   for (const auto j : c10::irange(100)) {
  //     A[j] = 10 * j;
  //   }
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forK = For::make(
      k,
      0,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j)));
  auto par = Block::make({forK, forJ});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, HasConflictingOverlapWithLoads) {
  // Input IR:
  //   for (const auto k : c10::irange(10, 100)) {
  //     B[k] = 20 * A[99-k];
  //   }
  //   for (const auto j : c10::irange(10, 100)) {
  //     C[j] = 10 * A[j];
  //   }
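  // Both loops only read A (their stores go to different buffers, B and C),
  // and two reads of the same region never conflict.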
  BufHandle a_buf("A", {100}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  BufHandle c_buf("C", {100}, kInt);
  VarHandle j("j", kInt);
  VarHandle k("k", kInt);
  auto forK = For::make(
      k,
      10,
      100,
      Store::make(
          b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k}))));
  auto forJ = For::make(
      j,
      10,
      100,
      Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j}))));
  auto par = Block::make({forK, forJ});

  tensorexpr::analysis::MemDependencyChecker analyzer;
  par->accept(&analyzer);
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK));
  ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ));
}

TEST(BoundsInference, IsOverlapping) {
  // Input IR:
  //   for (const auto i : c10::irange(100)) {
  //     A[i] = i * 10;               // storeA1
  //     B[i] = A[99-i] * 20;         // loadA1
  //     C[i] = A[i + 100] * 10;      // loadA2
  //     A[i + 50] = i * 50;          // storeA2
  //     A[i + 150] = i * 150;        // storeA3
  //   }
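  // Over i in [0, 100): storeA1 touches A[0..99], loadA1 reads A[0..99],
  // loadA2 reads A[100..199], storeA2 touches A[50..149], and storeA3 touches
  // A[150..249]; so storeA1 overlaps loadA1 and storeA2 but not loadA2 or
  // storeA3.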
  BufHandle a_buf("A", {300}, kInt);
  BufHandle b_buf("B", {100}, kInt);
  BufHandle c_buf("C", {100}, kInt);
  VarHandle i("i", kInt);
  auto storeA1 = Store::make(a_buf, {i}, i * 10);
  auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i});
  auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20));
  auto loadA2 = Load::make(a_buf, {i + 100});
  auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10));
  auto storeA2 = Store::make(a_buf, {i + 50}, i * 50);
  auto storeA3 = Store::make(a_buf, {i + 150}, i * 150);
  auto forI = For::make(
      i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3}));
  tensorexpr::analysis::MemDependencyChecker analyzer;
  forI->accept(&analyzer);
  ASSERT_TRUE(isOverlapping(analyzer, storeA1, to<Load>(loadA1.node())));
  ASSERT_FALSE(isOverlapping(analyzer, storeA1, to<Load>(loadA2.node())));
  ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2));
  ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3));
}

} // namespace jit
} // namespace torch