xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_memdependency.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <test/cpp/tensorexpr/test_base.h>
3 
4 #include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
5 #include <torch/csrc/jit/tensorexpr/ir.h>
6 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
7 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
8 #include <torch/csrc/jit/tensorexpr/loopnest.h>
9 #include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>
10 #include <torch/csrc/jit/tensorexpr/tensor.h>
11 
12 namespace torch {
13 namespace jit {
14 
15 using namespace torch::jit::tensorexpr;
16 
17 // Test helper function used to determine if two regions of a buffer have an
18 // overlap. No Overlap & partial overlap is obvious. Contains means A is
19 // larger and fully encloses B, while ContainedOrEqual is the reverse. Equal
20 // ranges are ContainedOrEqual.
TEST(MemDependency,BoundOverlap)21 TEST(MemDependency, BoundOverlap) {
22   using namespace analysis;
23 
24   auto CB = [](int s, int e) {
25     return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
26   };
27 
28   // Sanity check 3 overlap cases.
29   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0)));
30   ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5)));
31   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1)));
32 
33   // Partial overlap works in either order.
34   ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14)));
35   ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10)));
36 
37   // Total Overlap works when one bound encloses the other, and returns which.
38   ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9)));
39   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16)));
40 
41   // Total overlap works when the bounds are an identical range, returns
42   // ContainedOrEqual.
43   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15)));
44 
45   // Total overlap when only one end of the bound matches.
46   ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10)));
47   ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15)));
48   ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9)));
49   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15)));
50   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15)));
51 
52   // No overlap when a < b.
53   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10)));
54   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3)));
55   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130)));
56 
57   // No overlap when a > b.
58   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2)));
59   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2)));
60   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120)));
61 
62   // No overlap when adjacent.
63   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120)));
64   ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1)));
65 
66   // Partial overlap when middle bounds match.
67   ASSERT_EQ(
68       OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120)));
69   ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4)));
70   ASSERT_EQ(
71       OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100)));
72   ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2)));
73 
74   // Total overlap when one bound is single length over one end of the other.
75   ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15)));
76   ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2)));
77   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15)));
78   ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15)));
79 }
80 
TEST(MemDependency,BoundComparison)81 TEST(MemDependency, BoundComparison) {
82   using namespace analysis;
83 
84   auto CB = [](int s, int e) {
85     return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
86   };
87 
88   ASSERT_EQ(
89       CmpEvalResult::NotDetermined,
90       compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ));
91   ASSERT_EQ(
92       CmpEvalResult::True,
93       compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ));
94   ASSERT_EQ(
95       CmpEvalResult::False,
96       compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ));
97   ASSERT_EQ(
98       CmpEvalResult::False,
99       compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ));
100   ASSERT_EQ(
101       CmpEvalResult::NotDetermined,
102       compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ));
103   ASSERT_EQ(
104       CmpEvalResult::NotDetermined,
105       compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
106   ASSERT_EQ(
107       CmpEvalResult::NotDetermined,
108       compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ));
109 
110   ASSERT_EQ(
111       CmpEvalResult::NotDetermined,
112       compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE));
113   ASSERT_EQ(
114       CmpEvalResult::False,
115       compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE));
116   ASSERT_EQ(
117       CmpEvalResult::True,
118       compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE));
119   ASSERT_EQ(
120       CmpEvalResult::True,
121       compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE));
122   ASSERT_EQ(
123       CmpEvalResult::NotDetermined,
124       compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE));
125   ASSERT_EQ(
126       CmpEvalResult::NotDetermined,
127       compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
128   ASSERT_EQ(
129       CmpEvalResult::NotDetermined,
130       compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE));
131 
132   ASSERT_EQ(
133       CmpEvalResult::True,
134       compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT));
135   ASSERT_EQ(
136       CmpEvalResult::False,
137       compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT));
138   ASSERT_EQ(
139       CmpEvalResult::False,
140       compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT));
141   ASSERT_EQ(
142       CmpEvalResult::NotDetermined,
143       compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT));
144   ASSERT_EQ(
145       CmpEvalResult::NotDetermined,
146       compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT));
147   ASSERT_EQ(
148       CmpEvalResult::NotDetermined,
149       compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT));
150 
151   ASSERT_EQ(
152       CmpEvalResult::False,
153       compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE));
154   ASSERT_EQ(
155       CmpEvalResult::True,
156       compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE));
157   ASSERT_EQ(
158       CmpEvalResult::True,
159       compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE));
160   ASSERT_EQ(
161       CmpEvalResult::NotDetermined,
162       compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE));
163   ASSERT_EQ(
164       CmpEvalResult::NotDetermined,
165       compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE));
166   ASSERT_EQ(
167       CmpEvalResult::NotDetermined,
168       compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE));
169 
170   ASSERT_EQ(
171       CmpEvalResult::False,
172       compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT));
173   ASSERT_EQ(
174       CmpEvalResult::False,
175       compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT));
176   ASSERT_EQ(
177       CmpEvalResult::NotDetermined,
178       compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT));
179   ASSERT_EQ(
180       CmpEvalResult::True,
181       compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT));
182   ASSERT_EQ(
183       CmpEvalResult::NotDetermined,
184       compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT));
185   ASSERT_EQ(
186       CmpEvalResult::NotDetermined,
187       compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT));
188 
189   ASSERT_EQ(
190       CmpEvalResult::True,
191       compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE));
192   ASSERT_EQ(
193       CmpEvalResult::True,
194       compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE));
195   ASSERT_EQ(
196       CmpEvalResult::NotDetermined,
197       compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE));
198   ASSERT_EQ(
199       CmpEvalResult::False,
200       compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE));
201   ASSERT_EQ(
202       CmpEvalResult::NotDetermined,
203       compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE));
204   ASSERT_EQ(
205       CmpEvalResult::NotDetermined,
206       compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE));
207 }
208 
TEST(MemDependency, BoundOverlapSymbolic) {
  using namespace analysis;

  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);
  VarHandle w("w", kInt);

  // sym(lo, hi) builds a Bound whose endpoints are arbitrary expressions.
  auto sym = [](ExprHandle lo, ExprHandle hi) {
    return Bound(lo.node(), hi.node());
  };

  // Endpoints are symbolic, but the distance between them is a constant, so
  // the basic outcomes can still be decided.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(sym(x, x), sym(x, x)));
  ASSERT_EQ(
      OverlapKind::PartialOverlap,
      boundOverlap(sym(x, x + 3), sym(x + 2, x + 5)));
  ASSERT_EQ(
      OverlapKind::NoOverlap, boundOverlap(sym(x, x), sym(x + 1, x + 1)));

  // The sign of y is unknown, so x + y cannot be ordered against x + y / 2.
  ASSERT_EQ(
      OverlapKind::PartialOverlap,
      boundOverlap(sym(x, x + y), sym(x, x + y / 2)));

  // Nothing is known about these bounds; the conservative answer is that they
  // may overlap.
  ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(sym(x, y), sym(z, w)));

  // Arithmetic over opaque terms is handled...
  ASSERT_EQ(
      OverlapKind::ContainedOrEqual,
      boundOverlap(sym(x + w, y - z), sym(x + w, y - z)));
  // ...including when the two sides only match after simplification.
  ASSERT_EQ(
      OverlapKind::ContainedOrEqual,
      boundOverlap(sym(x - w - w, y), sym(x - w * 2, y)));
}
249 
250 // Tests the helper function for overlap of multi dimensional indices bounds.
251 // This uses boundOverlap on each dimension and return the "lowest" kind of
252 // overlap.
TEST(MemDependency,BoundOverlapMultiDim)253 TEST(MemDependency, BoundOverlapMultiDim) {
254   using namespace analysis;
255 
256   auto CB = [](int s, int e) {
257     return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
258   };
259 
260   // Sanity check one dimensional cases.
261   ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)}));
262   ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)}));
263   ASSERT_EQ(
264       OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)}));
265 
266   // Total overlap in 3 dims.
267   ASSERT_EQ(
268       OverlapKind::ContainedOrEqual,
269       overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)}));
270   ASSERT_EQ(
271       OverlapKind::ContainedOrEqual,
272       overlaps(
273           {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)}));
274 
275   // Total overlap in 2 dims, no overlap in another.
276   ASSERT_EQ(
277       OverlapKind::NoOverlap,
278       overlaps(
279           {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
280 
281   // Total overlap in 2 dims, partial overlap in another.
282   ASSERT_EQ(
283       OverlapKind::PartialOverlap,
284       overlaps(
285           {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
286   // This case is most important, so verify the overlap in any dim. (dim 2)
287   ASSERT_EQ(
288       OverlapKind::PartialOverlap,
289       overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)}));
290   // Dim 1.
291   ASSERT_EQ(
292       OverlapKind::PartialOverlap,
293       overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)}));
294   // Total overlap in 1 dim, partial in 2.
295   ASSERT_EQ(
296       OverlapKind::PartialOverlap,
297       overlaps(
298           {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)}));
299   // Total overlap, partial overlap, no overlap.
300   ASSERT_EQ(
301       OverlapKind::NoOverlap,
302       overlaps(
303           {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)}));
304 
305   // Total overlap (B) in 2 dims, total overlap (A) in another.
306   ASSERT_EQ(
307       OverlapKind::Contains,
308       overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)}));
309 
310   // Total overlap (A) in 2 dims, total overlap (B) in another.
311   ASSERT_EQ(
312       OverlapKind::Contains,
313       overlaps(
314           {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)}));
315 
316   // Total (B), No Overlap, Total (A).
317   ASSERT_EQ(
318       OverlapKind::NoOverlap,
319       overlaps(
320           {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)}));
321 }
322 
323 // Test the helper we use to subtract bounds: returns the regions(s) of A which
324 // remain after removing the region of B.
TEST(MemDependency, BoundSubtract) {
  using namespace analysis;

  // iv(lo, hi) builds the constant, inclusive interval [lo, hi].
  auto iv = [](int lo, int hi) {
    return Bound(alloc<IntImm>(lo), alloc<IntImm>(hi));
  };
  // True when both bound lists compare equal via indexBoundsEquals().
  auto same = [](const IndexBounds& lhs, const IndexBounds& rhs) {
    return indexBoundsEquals(lhs, rhs);
  };

  // Removing a single element from itself leaves nothing.
  ASSERT_EQ(subtractBound(iv(0, 0), iv(0, 0)).size(), 0);
  ASSERT_EQ(subtractBound(iv(5, 5), iv(5, 5)).size(), 0);

  // Disjoint ranges: A is returned untouched.
  ASSERT_TRUE(same(subtractBound(iv(5, 5), iv(2, 2)), {iv(5, 5)}));
  ASSERT_TRUE(same(subtractBound(iv(5, 5), iv(0, 4)), {iv(5, 5)}));

  // B clips one end of A.
  ASSERT_TRUE(same(subtractBound(iv(1, 5), iv(4, 7)), {iv(1, 3)}));
  ASSERT_TRUE(same(subtractBound(iv(0, 5), iv(5, 7)), {iv(0, 4)}));
  ASSERT_TRUE(same(subtractBound(iv(4, 5), iv(1, 4)), {iv(5, 5)}));
  ASSERT_TRUE(same(subtractBound(iv(1, 5), iv(0, 4)), {iv(5, 5)}));

  // B covers A at both ends, so nothing is left.
  ASSERT_TRUE(same(subtractBound(iv(1, 5), iv(0, 7)), {}));
  ASSERT_TRUE(same(subtractBound(iv(5, 5), iv(5, 7)), {}));

  // B lies strictly inside A, splitting it in two.
  ASSERT_TRUE(same(subtractBound(iv(1, 5), iv(2, 3)), {iv(1, 1), iv(4, 5)}));
  ASSERT_TRUE(same(subtractBound(iv(0, 5), iv(2, 4)), {iv(0, 1), iv(5, 5)}));
}
357 
TEST(MemDependency, BoundSubtractSymbolic) {
  using namespace analysis;

  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);
  VarHandle w("w", kInt);

  // sym(lo, hi) builds a Bound whose endpoints are arbitrary expressions.
  auto sym = [](ExprHandle lo, ExprHandle hi) {
    return Bound(lo.node(), hi.node());
  };
  // True when both bound lists compare equal via indexBoundsEquals().
  auto same = [](const IndexBounds& lhs, const IndexBounds& rhs) {
    return indexBoundsEquals(lhs, rhs);
  };

  // Removing a single symbolic element from itself leaves nothing.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  ASSERT_TRUE(same(subtractBound(sym(x, x), sym(x, x)), {}));
  ASSERT_TRUE(same(subtractBound(sym(x + 1, x + 1), sym(x + 1, x + 1)), {}));
  ASSERT_TRUE(same(subtractBound(sym(x * 2, x * 2), sym(x * 2, x * 2)), {}));

  // Constant-width range, low end removed.
  ASSERT_TRUE(same(
      subtractBound(sym(x, x + 10), sym(x, x + 4)), {sym(x + 5, x + 10)}));
  // Constant-width range, high end removed.
  ASSERT_TRUE(same(
      subtractBound(sym(x, x + 10), sym(x + 6, x + 12)), {sym(x, x + 5)}));
  // Constant-width range, completely covered.
  ASSERT_TRUE(same(subtractBound(sym(x, x + 10), sym(x, x + 10)), {}));
  ASSERT_TRUE(same(subtractBound(sym(x + 2, x + 10), sym(x, x + 12)), {}));
  // Constant-width range, middle removed.
  ASSERT_TRUE(same(
      subtractBound(sym(x, x + 10), sym(x + 3, x + 7)),
      {sym(x, x + 2), sym(x + 8, x + 10)}));

  // The size is not constant but can be inferred; this only works when a
  // single variable is involved.
  ASSERT_TRUE(same(subtractBound(sym(0, x), sym(0, x * 2)), {}));
  ASSERT_TRUE(
      same(subtractBound(sym(0, x * 2), sym(0, x - 1)), {sym(x, x * 2)}));

  // The size cannot be inferred: A comes back unchanged.
  ASSERT_TRUE(same(subtractBound(sym(x, y), sym(z, w)), {sym(x, y)}));
  ASSERT_TRUE(same(subtractBound(sym(x, y), sym(x, z)), {sym(x, y)}));
  ASSERT_TRUE(same(subtractBound(sym(x, y), sym(0, x)), {sym(x, y)}));
  ASSERT_TRUE(same(subtractBound(sym(x, x), sym(0, 0)), {sym(x, x)}));
}
403 
404 // Tests the helper function that does subtraction, but for multi dimensional
405 // indices bounds.
// Exercises subtractIndicesBounds() on constant multi-dimensional bounds:
// the result is the list of regions of A that remain after removing B.
TEST(MemDependency, BoundSubtractMultiDim) {
  using namespace analysis;

  // Builds a Bound covering the constant, inclusive range [s, e].
  auto CB = [](int s, int e) {
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
  };
  // Element-wise comparison of two region lists. The lists are taken by
  // const reference so comparing them does not copy every region.
  auto EQ = [](const std::vector<IndexBounds>& x,
               const std::vector<IndexBounds>& y) {
    if (x.size() != y.size()) {
      return false;
    }
    for (size_t i = 0; i < x.size(); ++i) {
      if (!indexBoundsEquals(x[i], y[i])) {
        return false;
      }
    }
    return true;
  };

  // sanity check one dimension.
  ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {}));
  ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {}));
  ASSERT_TRUE(
      EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}}));
  ASSERT_TRUE(
      EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}}));
  ASSERT_TRUE(EQ(
      subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}}));

  // Multi dim total overlap.
  ASSERT_TRUE(EQ(
      subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {}));
  ASSERT_TRUE(EQ(
      subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {}));

  // Multi dim one way partial in dim 1.
  ASSERT_TRUE(
      EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}),
         {{CB(4, 9), CB(0, 2)}}));

  // Multi dim one way partial in dim 2.
  ASSERT_TRUE(
      EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}),
         {{CB(0, 9), CB(11, 20)}}));

  // Partial overlap in 2 dims.
  ASSERT_TRUE(
      EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}),
         {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}}));

  // Partial overlap in 3 dims.
  ASSERT_TRUE(
      EQ(subtractIndicesBounds(
             {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}),
         {{CB(0, 1), CB(0, 5), CB(0, 5)},
          {CB(2, 5), CB(0, 1), CB(0, 5)},
          {CB(2, 5), CB(2, 5), CB(0, 1)}}));
}
463 
464 // Tests the multi dimensional subtraction code for bounds that cannot be fully
465 // materialized.
TEST(MemDependency,BoundSubtractMultiDimSymbolic)466 TEST(MemDependency, BoundSubtractMultiDimSymbolic) {
467   VarHandle x("x", kInt);
468   VarHandle y("y", kInt);
469 
470   using namespace analysis;
471 
472   auto CB = [](ExprHandle s, ExprHandle e) {
473     return Bound(s.node(), e.node());
474   };
475 
476   auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
477     if (x.size() != y.size()) {
478       return false;
479     }
480     for (auto i = 0U; i < x.size(); ++i) {
481       if (!indexBoundsEquals(x[i], y[i])) {
482         return false;
483       }
484     }
485     return true;
486   };
487 
488   // Cannot determine overlaps.
489   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
490   ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}}));
491 
492   // Various total Overlaps.
493   ASSERT_TRUE(EQ(
494       subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {}));
495   ASSERT_TRUE(EQ(
496       subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {}));
497   ASSERT_TRUE(EQ(
498       subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {}));
499   ASSERT_TRUE(EQ(
500       subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {}));
501 
502   // one-way overlap in first dim.
503   ASSERT_TRUE(
504       EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}),
505          {{CB(x - 4, x), CB(0, y)}}));
506   // second dim.
507   ASSERT_TRUE(
508       EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}),
509          {{CB(0, x), CB(0, 4)}}));
510 
511   // Internal overlap in first dim.
512   ASSERT_TRUE(
513       EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}),
514          {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}}));
515   // second dim.
516   ASSERT_TRUE(EQ(
517       subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}),
518       {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}}));
519 
520   // Overlap in both dimensions.
521   ASSERT_TRUE(
522       EQ(subtractIndicesBounds(
523              {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}),
524          {
525              {CB(0, 4), CB(0, y)},
526              {CB(x - 4, x), CB(0, y)},
527              {CB(0, x), CB(0, 9)},
528              {CB(0, x), CB(y - 9, y)},
529          }));
530 }
531 
532 // Simple check that the analyzer does anything at all...
TEST(MemDependency, MemDependencyCheckerSimple) {
  BufHandle bufA("A", {1}, kInt);
  BufHandle bufB("B", {1}, kInt);

  analysis::MemDependencyChecker checker;

  // Program under analysis:
  //   A[0] = 3;
  //   B[0] = A[0] + 1;
  StorePtr writeA = Store::make(bufA, {0}, 3);
  StorePtr writeB =
      Store::make(bufB, {0}, Add::make(Load::make(bufA, {0}), 1));

  StmtPtr program = Block::make({writeA, writeB});
  program->accept(&checker);

  // The write to B reads what the write to A produced, never the reverse.
  ASSERT_TRUE(checker.dependsDirectly(writeB, writeA));
  ASSERT_FALSE(checker.dependsIndirectly(writeA, writeB));
  // A direct dependency must also be reported as an indirect one.
  ASSERT_TRUE(checker.dependsIndirectly(writeB, writeA));
}
556 
557 // Check that there is a difference between direct and indirect dependence.
TEST(MemDependency, MemDependencyCheckerMultiStmt) {
  BufHandle bufA("A", {1}, kInt);
  BufHandle bufB("B", {1}, kInt);
  BufHandle bufC("C", {1}, kInt);

  analysis::MemDependencyChecker checker;

  // Program under analysis:
  //   A[0] = 3;
  //   B[0] = A[0];
  //   C[0] = B[0] + 1;
  StorePtr writeA = Store::make(bufA, {0}, 3);
  StorePtr writeB = Store::make(bufB, {0}, Load::make(bufA, {0}));
  StorePtr writeC =
      Store::make(bufC, {0}, Add::make(Load::make(bufB, {0}), 1));

  StmtPtr program = Block::make({writeA, writeB, writeC});
  program->accept(&checker);

  // C only reaches A through B, so that dependency is indirect.
  ASSERT_FALSE(checker.dependsDirectly(writeC, writeA));
  ASSERT_TRUE(checker.dependsIndirectly(writeC, writeA));

  // Each link of the chain C -> B -> A is a direct dependency.
  ASSERT_TRUE(checker.dependsDirectly(writeC, writeB));
  ASSERT_TRUE(checker.dependsDirectly(writeB, writeA));

  // Earlier statements never depend on later ones.
  ASSERT_FALSE(checker.dependsIndirectly(writeB, writeC));
  ASSERT_FALSE(checker.dependsIndirectly(writeA, writeB));
  ASSERT_FALSE(checker.dependsIndirectly(writeA, writeC));
}
592 
593 // Verify that we do filter writes that are totally overlapped by later writes.
TEST(MemDependency, MemDependencyCheckerOverlap) {
  BufHandle bufA("A", {1}, kInt);
  BufHandle bufB("B", {1}, kInt);

  analysis::MemDependencyChecker checker;

  // Program under analysis:
  //   A[0] = 3;
  //   A[0] = 6;
  //   B[0] = A[0] + 1;
  StorePtr firstWriteA = Store::make(bufA, {0}, 3);
  StorePtr secondWriteA = Store::make(bufA, {0}, 6);
  StorePtr writeB =
      Store::make(bufB, {0}, Add::make(Load::make(bufA, {0}), 1));

  StmtPtr program = Block::make({firstWriteA, secondWriteA, writeB});
  program->accept(&checker);

  // The second write to A fully overwrites the first, so B depends on the
  // second write only.
  ASSERT_TRUE(checker.dependsIndirectly(writeB, secondWriteA));
  ASSERT_FALSE(checker.dependsIndirectly(writeB, firstWriteA));

  // The two writes to A do not depend on each other in either direction.
  ASSERT_FALSE(checker.dependsIndirectly(firstWriteA, secondWriteA));
  ASSERT_FALSE(checker.dependsIndirectly(secondWriteA, firstWriteA));
}
623 
624 // Verify that bounds match loop iterations, and that dependencies progress
625 // across loop scopes.
// Builds a loop that writes A[x] for x in [0, 10) followed by a store to B
// that reads one element of A, then checks the dependencies between the
// store, the loop and the load, and the bounds recorded for the loop store.
TEST(MemDependency, MemDependencyCheckerLoop) {
  // NOTE(review): A is declared with a single element although the loop
  // below indexes 0..9; left as-is, confirm whether that is intentional.
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  VarHandle x("x", kInt);

  using namespace analysis;

  MemDependencyChecker analyzer;

  /*
   * for (int x = 0; x < 10; ++x) {
   *   A[x] = x;
   * }
   * B[0] = A[4] + 1;
   */

  StorePtr aStore = Store::make(a, {x}, x);
  StmtPtr loop = For::make(x, 0, 10, aStore);
  // Reads A[4], an element written by one of the loop iterations.
  StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));

  StmtPtr stmt = Block::make({loop, bStore});

  stmt->accept(&analyzer);

  // Same A->B dependency.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));

  // B depends on the loop.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
  // A is in the loop but does not depend on any loop iteration.
  ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop));

  auto aStoreAccess = analyzer.accessFor(aStore);
  ASSERT_NE(aStoreAccess, nullptr);

  // It should have bounds covering the range of x: 0 <= x < 10.
  ASSERT_TRUE(indexBoundsEquals(
      aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
665 
666 // Reductions should promote dependencies as well.
// Builds an initializer store, a loop containing a Sum reduction into A[0],
// and a store to B reading A[0]; then checks the dependencies between them
// and the bounds recorded for the loads inside the reduction.
TEST(MemDependency, MemDependencyCheckerLoopReduce) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);

  using namespace analysis;

  MemDependencyChecker analyzer;

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; ++x) {
   *   A[0] = A[x] + 1;
   * }
   * B[0] = A[0];
   */

  StorePtr aInit = Store::make(a, {0}, 0);
  ExprHandle reduce = Sum()(a, 1, {x}, {x});
  StorePtr aReduce = Store::make(a, {0}, reduce);
  StmtPtr loop = For::make(x, 0, 10, aReduce);
  StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));

  StmtPtr stmt = Block::make({aInit, loop, bStore});

  stmt->accept(&analyzer);

  // B -> A.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));

  // B depends indirectly on the initializer of A, since the reduction depends
  // on it.
  ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
  ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));

  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));

  // B depends on the loop.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
  // A is in the loop and depends on other iterations.
  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));

  // The loop contents depend on the initializer too.
  ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));

  // Find loads within the reduction:
  auto reduceLoads = NodeFinder<Load>::find(reduce.node());
  // Guard against the bounds check below passing vacuously.
  ASSERT_FALSE(reduceLoads.empty());
  // Pull out the access for the load inside the loop.
  for (const auto& load : reduceLoads) {
    auto loopLoad = analyzer.accessFor(load);
    // Fail cleanly instead of dereferencing a missing access.
    ASSERT_NE(loopLoad, nullptr);
    // It should have 10 element long bounds.
    ASSERT_TRUE(indexBoundsEquals(
        loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
  }
}
722 
723 // Lowering a reduction doesn't affect dependency analysis.
// Same program as the reduction test, but with the reduction written out as
// an explicit load + add + store; the same dependencies are expected.
TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);

  using namespace analysis;

  MemDependencyChecker analyzer;

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; ++x) {
   *   A[0] = A[x] + 1;
   * }
   * B[0] = A[0];
   */

  StorePtr aInit = Store::make(a, {0}, 0);
  ExprHandle aLoad = Load::make(a, {x});
  StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
  StmtPtr loop = For::make(x, 0, 10, aReduce);
  StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));

  StmtPtr stmt = Block::make({aInit, loop, bStore});

  stmt->accept(&analyzer);

  // B -> A.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));

  // B depends indirectly on the initializer of A, since the reduction depends
  // on it.
  ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
  ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));

  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));

  // B depends on the loop.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
  // A is in the loop and depends on other iterations.
  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));

  // The loop contents depend on the initializer too.
  ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));

  // Pull out the access for the load inside the loop.
  auto loopLoad = analyzer.accessFor(aLoad.node());
  // Fail cleanly instead of dereferencing a missing access.
  ASSERT_NE(loopLoad, nullptr);
  // It should have 10 element long bounds.
  ASSERT_TRUE(indexBoundsEquals(
      loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
775 
776 // Can determine dependencies of outputs, through to inputs.
// Registers A as an input and B as an output of the analyzer, runs it over a
// Relu-style loop, and checks the dependency chain from output to input via
// both the node-based and the AccessInfo-based query overloads.
TEST(MemDependency, MemDependencyCheckerInputsOutputs) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);

  // initialize analyzer with inputs and outputs.
  analysis::MemDependencyChecker analyzer({a}, {b});

  // Here's a Relu.
  /*
   * for (int x = 0; x < 10; ++x) {
   *   B[x] = Max(A[x], 0);
   * }
   */

  ExprHandle aLoad = Load::make(a, {x});
  StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
  StmtPtr loop = For::make(x, 0, 10, bStore);

  StmtPtr stmt = Block::make({loop});

  stmt->accept(&analyzer);

  // Output depends indirectly on input.
  ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
  // aLoad depends directly on the input A.
  ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node()));
  // bStore therefore depends directly on the input A.
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node()));
  // The output depends directly on the store.
  ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));

  // Check AccessInfo based overloads.
  auto input = analyzer.input(a.node());
  auto output = analyzer.output(b.node());
  // Fail cleanly if either buffer was not registered with the analyzer.
  ASSERT_NE(input, nullptr);
  ASSERT_NE(output, nullptr);

  // Output depends indirectly on input.
  ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
  // Not directly.
  ASSERT_FALSE(analyzer.dependsDirectly(output, input));
  // Not in reverse order.
  ASSERT_FALSE(analyzer.dependsIndirectly(input, output));

  // output -> bStore -> aLoad -> input.
  auto storeAccess = analyzer.accessFor(bStore);
  auto loadAccess = analyzer.accessFor(aLoad.node());
  ASSERT_NE(storeAccess, nullptr);
  ASSERT_NE(loadAccess, nullptr);

  ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess));
  ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input));
}
827 
828 // Can tell if an output does not depend on an input.
TEST(MemDependency,MemDependencyCheckerOutputDoesntDepend)829 TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) {
830   BufHandle a("A", {10}, kInt);
831   BufHandle b("B", {10}, kInt);
832   VarHandle x("x", kInt);
833 
834   // initialize analyzer with inputs and outputs.
835   analysis::MemDependencyChecker analyzer({a}, {b});
836 
837   // Here's a dumb Relu.
838   /*
839    * for (int x = 0; x < 10; ++x) {
840    *   B[x] = Max(x, 0);
841    * }
842    */
843 
844   StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true));
845   StmtPtr loop = For::make(x, 0, 10, bStore);
846 
847   StmtPtr stmt = Block::make({loop});
848 
849   stmt->accept(&analyzer);
850 
851   // Output does not depend indirectly on input.
852   ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node()));
853 
854   // The output still depends directly on the store.
855   ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
856 
857   // Check AccessInfo based overloads.
858   auto input = analyzer.input(a.node());
859   auto output = analyzer.output(b.node());
860 
861   // Output does not depend indirectly on input.
862   ASSERT_FALSE(analyzer.dependsIndirectly(output, input));
863 }
864 
865 // Verify different loop extents produce accesses with different bounds, and
866 // that later accesses find dependencies that overlap their entire bound range.
TEST(MemDependency,MemDependencyCheckerLoopBounds)867 TEST(MemDependency, MemDependencyCheckerLoopBounds) {
868   BufHandle a("A", {10}, kInt);
869   BufHandle b("B", {10}, kInt);
870   BufHandle c("C", {10}, kInt);
871   VarHandle x("x", kInt);
872   using namespace analysis;
873 
874   MemDependencyChecker analyzer({a}, {c});
875 
876   // This enables using the execution order of the loops to determine if some
877   // loops are self dependent or not.
878   analyzer.allowLoopExecutionOrderAnalysis();
879 
880   /*
881    * for (int x = 1; x < 10; ++x) {
882    *   B[x] = A[x];
883    * }
884    * for (int x = 1; x < 9; ++x) {
885    *   B[x] = B[x] * 2;
886    * }
887    * for (int x = 3; x < 4; ++x) {
888    *   C[x] = A[x];
889    * }
890    * for (int x = 0; x < 10; ++x) {
891    *   C[x] = B[x];
892    * }
893    */
894 
895   std::vector<StmtPtr> stmts(
896       {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))),
897        For::make(
898            x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))),
899        For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))),
900        For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))});
901 
902   StmtPtr stmt = Block::make(stmts);
903 
904   stmt->accept(&analyzer);
905 
906   auto input = analyzer.input(a.node());
907   auto output = analyzer.output(c.node());
908 
909   // sanity check Output -> Input.
910   ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
911 
912   // Check the For loop dependencies:
913 
914   // Last write to C depends on both writes to B since they contain the last
915   // write to at least one element.
916   ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1]));
917   ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0]));
918 
919   // The last write to C does not depend on the other write to C.
920   ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2]));
921 
922   auto CB = [](int s, int e) {
923     return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
924   };
925   auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
926     return indexBoundsEquals(x, y);
927   };
928 
929   /*  0. Input: A[(0, 9)] - dependents: 1 5
930    *  1. Load: A[(1, 9)] - depends on: 0  - dependents: 2
931    *  2. Store: B[(1, 9)] - depends on: 1  - dependents: 3 7
932    *  3. Load: B[(1, 8)] - depends on: 2  - dependents: 4
933    *  4. Store: B[(1, 8)] - depends on: 3  - dependents: 7
934    *  5. Load: A[(3, 3)] - depends on: 0  - dependents: 6
935    *  6. Store: C[(3, 3)] - depends on: 5
936    *  7. Load: B[(0, 9)] - depends on: 2 4  - dependents: 8
937    *  8. Store: C[(0, 9)] - depends on: 7  - dependents: 9
938    *  9. Output: C[(0, 9)] - depends on: 8
939    */
940 
941   // Now let's look at the bounds of each access.
942   // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this
943   // much.
944   auto history = analyzer.getHistory();
945   ASSERT_EQ(history.size(), 10);
946   VarPtr aVar = a.node()->base_handle();
947   VarPtr bVar = b.node()->base_handle();
948   VarPtr cVar = c.node()->base_handle();
949 
950   // The first access is the input A.
951   ASSERT_EQ(history[0]->type(), AccessType::Input);
952   ASSERT_EQ(history[0]->var(), aVar);
953   // It has the bounds of the producing Input.
954   ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
955   // sanity check the input we retrieved earlier matches.
956   ASSERT_EQ(history[0], input);
957 
958   // The second access is the load of A in the first loop.
959   ASSERT_EQ(history[1]->type(), AccessType::Load);
960   ASSERT_EQ(history[1]->var(), aVar);
961   // It has the bounds of the loop, i.e. start == 1.
962   ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)}));
963   // It reads from A, so it should have a dependency on the last write to this
964   // range - with is the input.
965   ASSERT_EQ(history[1]->dependencies().size(), 1);
966   ASSERT_TRUE(history[1]->hasDependency(history[0]));
967 
968   // The third access is the store into B in the first loop.
969   ASSERT_EQ(history[2]->type(), AccessType::Store);
970   ASSERT_EQ(history[2]->var(), bVar);
971   // It also has the bounds of the loop, i.e. start == 1.
972   ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
973   // The previous load is in its RHS, so it depends on it.
974   ASSERT_EQ(history[2]->dependencies().size(), 1);
975   ASSERT_TRUE(history[2]->hasDependency(history[1]));
976 
977   // The third access is the load from B in the second loop.
978   ASSERT_EQ(history[3]->type(), AccessType::Load);
979   ASSERT_EQ(history[3]->var(), bVar);
980   // It has the bounds of the second loop, i.e. >= 1 < 9.
981   ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)}));
982   // It reads from B in a smaller range, so should depend on the previous
983   // store.
984   ASSERT_EQ(history[3]->dependencies().size(), 1);
985   ASSERT_TRUE(history[3]->hasDependency(history[2]));
986 
987   // The fourth: the store to B in the second loop.
988   ASSERT_EQ(history[4]->type(), AccessType::Store);
989   ASSERT_EQ(history[4]->var(), bVar);
990   // It also has the bounds of the second loop.
991   ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)}));
992   // The previous load is in its RHS, so it depends on it as before.
993   ASSERT_EQ(history[4]->dependencies().size(), 1);
994   ASSERT_TRUE(history[4]->hasDependency(history[3]));
995 
996   // The fifth access is the load is from the 3rd loop, and skips previous B
997   // accesses.
998   ASSERT_EQ(history[5]->type(), AccessType::Load);
999   ASSERT_EQ(history[5]->var(), aVar);
1000   // It has the bounds of the third loop: >= 3 < 4.
1001   ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)}));
1002   // It depends on the last thing to write to A, which is the A input.
1003   ASSERT_EQ(history[5]->dependencies().size(), 1);
1004   ASSERT_TRUE(history[5]->hasDependency(history[0]));
1005 
1006   // Sixth: the store into the output C.
1007   ASSERT_EQ(history[6]->type(), AccessType::Store);
1008   ASSERT_EQ(history[6]->var(), cVar);
1009   // It also has the bounds of the third loop.
1010   ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)}));
1011   // The previous load is in its RHS, so it depends on it as always.
1012   ASSERT_EQ(history[6]->dependencies().size(), 1);
1013   ASSERT_TRUE(history[6]->hasDependency(history[5]));
1014 
1015   // The seventh access is the load of B in the fourth loop.
1016   ASSERT_EQ(history[7]->type(), AccessType::Load);
1017   ASSERT_EQ(history[7]->var(), bVar);
1018   // It has the bounds of the final loop, >= 0 < 10
1019   ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1020   // The bounds of this read are larger than the bounds of the previous write,
1021   // so it depends on both previous Stores to B.
1022   ASSERT_EQ(history[7]->dependencies().size(), 2);
1023   ASSERT_TRUE(history[7]->hasDependency(history[2]));
1024   ASSERT_TRUE(history[7]->hasDependency(history[4]));
1025 
1026   // Eight: the final store into the output C.
1027   ASSERT_EQ(history[8]->type(), AccessType::Store);
1028   ASSERT_EQ(history[8]->var(), cVar);
1029   // It also has the bounds of the final loop.
1030   ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1031   // The previous load is in its RHS, so it depends on it as always.
1032   ASSERT_EQ(history[8]->dependencies().size(), 1);
1033   ASSERT_TRUE(history[8]->hasDependency(history[7]));
1034 
1035   // The last access represents the output Buf.
1036   ASSERT_EQ(history[9]->type(), AccessType::Output);
1037   ASSERT_EQ(history[9]->var(), cVar);
1038   // It has the bounds of the output Buf.
1039   ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)}));
1040   // sanity check the input we retrieved earlier matches.
1041   ASSERT_EQ(history[9], output);
1042   // It depends on the last write to C only.
1043   ASSERT_EQ(history[9]->dependencies().size(), 1);
1044   ASSERT_TRUE(history[9]->hasDependency(history[8]));
1045 }
1046 
1047 // Verify that we can still infer bounds when the loop var is offset.
TEST(MemDependency,MemDependencyCheckerLoopBoundsIndexShift)1048 TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
1049   BufHandle a("A", {10}, kInt);
1050   BufHandle b("B", {10}, kInt);
1051   VarHandle x("x", kInt);
1052 
1053   using namespace analysis;
1054 
1055   MemDependencyChecker analyzer({a}, {b});
1056 
1057   // This enables using the execution order of the loops to determine if some
1058   // loops are self dependent or not.
1059   analyzer.allowLoopExecutionOrderAnalysis();
1060 
1061   /*
1062    * for (int x = 1; x < 10; x++) {
1063    *   A[x] = A[x - 1];
1064    * }
1065    * for (int x = 0; x < 9; x++) {
1066    *   A[x] = A[x + 1];
1067    * }
1068    * for (int x = 0; x < 9; x++) {
1069    *   A[9 - x] = A[8 - x];
1070    * }
1071    * for (int x = 0; x < 10; x++) {
1072    *   A[x] = A[9 - x];
1073    * }
1074    * for (int x = 0; x < 10; x++) {
1075    *   B[x] = A[x];
1076    * }
1077    */
1078 
1079   StmtPtr stmt = Block::make(
1080       {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
1081        For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))),
1082        For::make(
1083            x,
1084            0,
1085            9,
1086            Store::make(
1087                a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))),
1088        For::make(
1089            x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))),
1090        For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))});
1091 
1092   stmt->accept(&analyzer);
1093 
1094   // Sanity check output depends on Input.
1095   ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1096 
1097   auto CB = [](int s, int e) {
1098     return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
1099   };
1100   auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
1101     return indexBoundsEquals(x, y);
1102   };
1103 
1104   /*  0. Input: A[(0, 9)] - dependents: 1
1105    *  1. Load: A[(0, 8)] - depends on: 0 2  - dependents: 2
1106    *  2. Store: A[(1, 9)] - depends on: 1  - dependents: 1 3
1107    *  3. Load: A[(1, 9)] - depends on: 2  - dependents: 4
1108    *  4. Store: A[(0, 8)] - depends on: 3  - dependents: 5 7
1109    *  5. Load: A[(0, 8)] - depends on: 4  - dependents: 6
1110    *  6. Store: A[(1, 9)] - depends on: 5  - dependents: 7
1111    *  7. Load: A[(0, 9)] - depends on: 4 6 8  - dependents: 8
1112    *  8. Store: A[(0, 9)] - depends on: 7  - dependents: 7 9
1113    *  9. Load: A[(0, 9)] - depends on: 8  - dependents: 10
1114    *  10. Store: B[(0, 9)] - depends on: 9  - dependents: 11
1115    *  11. Output: B[(0, 9)] - depends on: 10
1116    */
1117 
1118   // Now let's look at the bounds of each access.
1119   auto history = analyzer.getHistory();
1120   ASSERT_EQ(history.size(), 12);
1121   VarPtr aVar = a.node()->base_handle();
1122   VarPtr bVar = b.node()->base_handle();
1123 
1124   // The first access is the input A.
1125   ASSERT_EQ(history[0]->type(), AccessType::Input);
1126   ASSERT_EQ(history[0]->var(), aVar);
1127   // It has the bounds of the producing Input.
1128   ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
1129 
1130   // The second access is the load A[x-1].
1131   ASSERT_EQ(history[1]->type(), AccessType::Load);
1132   ASSERT_EQ(history[1]->var(), aVar);
1133   // It has the bounds of the loop modified by the offset of each index, in
1134   // this case -1.
1135   ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)}));
1136   // It depends on the input, but also the store in the same loop, since
1137   // different interations of the loop depend on each other.
1138   ASSERT_EQ(history[1]->dependencies().size(), 2);
1139   ASSERT_TRUE(history[1]->hasDependency(history[0]));
1140   ASSERT_TRUE(history[1]->hasDependency(history[2]));
1141 
1142   // The third access is the Store to A[x] in the first loop.
1143   ASSERT_EQ(history[2]->type(), AccessType::Store);
1144   ASSERT_EQ(history[2]->var(), aVar);
1145   // It has no offset on x, so should have the same bounds as the loop.
1146   ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
1147 
1148   // The fourth access is the load A[x+1] in the second loop.
1149   ASSERT_EQ(history[3]->type(), AccessType::Load);
1150   ASSERT_EQ(history[3]->var(), aVar);
1151   // It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1152   // index, in this case 1.
1153   ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)}));
1154   // This load totally overlaps the previous write to A, so it depends only on
1155   // it and not the input.
1156   ASSERT_EQ(history[3]->dependencies().size(), 1);
1157   ASSERT_TRUE(history[3]->hasDependency(history[2]));
1158 
1159   // The fifth access is the store to A[x] in the second loop.
1160   ASSERT_EQ(history[4]->type(), AccessType::Store);
1161   ASSERT_EQ(history[4]->var(), aVar);
1162   // It has no offset on x, so should have the same bounds as the loop.
1163   ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)}));
1164 
1165   // The sixth access is the load to A[8 - x] in the third loop.
1166   ASSERT_EQ(history[5]->type(), AccessType::Load);
1167   ASSERT_EQ(history[5]->var(), aVar);
1168   // It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1169   // index, in this case 8 - x.
1170   // This access has a negative stride, which will be normalized.
1171   ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)}));
1172   // This load totally overlaps the most recent write to A, so it depends only
1173   // on it and not the input or the first write to A.
1174   ASSERT_EQ(history[5]->dependencies().size(), 1);
1175   ASSERT_TRUE(history[5]->hasDependency(history[4]));
1176 
1177   // The seventh access is the store to A[9 - x] in the third loop.
1178   ASSERT_EQ(history[6]->type(), AccessType::Store);
1179   ASSERT_EQ(history[6]->var(), aVar);
1180   // This store has a negative stride on it's indices, but is normalized
1181   // internally.
1182   ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)}));
1183 
1184   // The eighth access is the load A[9-x] in the second loop.
1185   ASSERT_EQ(history[7]->type(), AccessType::Load);
1186   ASSERT_EQ(history[7]->var(), aVar);
1187   // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x,
1188   // which essentially traverses the loop backwards.
1189   ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1190   // This Load has three write dependencies:
1191   ASSERT_EQ(history[7]->dependencies().size(), 3);
1192   //  * The previous store (#6) for elements 1-9
1193   ASSERT_TRUE(history[7]->hasDependency(history[6]));
1194   //  * An earlier store (#4) covering element 0
1195   ASSERT_TRUE(history[7]->hasDependency(history[4]));
1196   //  * A future store inside this loop, since this loop modifies the buffer
1197   //  in a non distinct way (due to the load and store having different access
1198   //  strides).
1199   ASSERT_TRUE(history[7]->hasDependency(history[8]));
1200 
1201   // The ninth access is the store to A[x] in the fourth loop.
1202   ASSERT_EQ(history[8]->type(), AccessType::Store);
1203   ASSERT_EQ(history[8]->var(), aVar);
1204   // This store has a negative stride on it's indices, but is normalized
1205   // internally.
1206   ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1207 
1208   // The tenth and 11th accesses are the copy from A[x] to B[x].
1209   ASSERT_EQ(history[9]->type(), AccessType::Load);
1210   ASSERT_EQ(history[9]->var(), aVar);
1211   ASSERT_EQ(history[10]->type(), AccessType::Store);
1212   ASSERT_EQ(history[10]->var(), bVar);
1213 
1214   // The last access represents the output Buf.
1215   ASSERT_EQ(history[11]->type(), AccessType::Output);
1216   ASSERT_EQ(history[11]->var(), bVar);
1217   // It has the bounds of the output Buf.
1218   ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)}));
1219   // It depends on the last write to B only.
1220   ASSERT_EQ(history[11]->dependencies().size(), 1);
1221   ASSERT_TRUE(history[11]->hasDependency(history[10]));
1222 
1223   // ok that's enough of that.
1224 }
1225 
1226 // Check many different cases of loop self dependency - when a load within a
1227 // loop is dependent on a Store later in the same loop but in different
1228 // iteration. This is affected by whether or not we can trust the execution
1229 // order of the loop.
TEST(MemDependency,MemDependencyCheckerLoopSelfDependency)1230 TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) {
1231   BufHandle a("A", {5}, kInt);
1232   BufHandle b("B", {5}, kInt);
1233   VarHandle x("x", kInt);
1234   VarHandle y("y", kInt);
1235   VarHandle z("z", kInt);
1236 
1237   using namespace analysis;
1238 
1239   // This check assumes that the Stmt has a single Store with a single Load on
1240   // the RHS.
1241   auto isSelfDependent =
1242       [](const std::vector<std::shared_ptr<AccessInfo>>& history) -> bool {
1243     return history.front()->hasDependency(history.back());
1244   };
1245 
1246   {
1247     /* for (int y = 0; y < 10; y++) {
1248      *   A[y] = (A[y]) + 1;
1249      * } */
1250 
1251     // Not self dependent since all loop iterations use a different y.
1252 
1253     MemDependencyChecker analyzer;
1254     StmtPtr stmt = For::make(
1255         y,
1256         0,
1257         10,
1258         Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))}));
1259 
1260     stmt->accept(&analyzer);
1261 
1262     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1263   }
1264 
1265   {
1266     /* for (int y = 0; y < 10; y++) {
1267      *   A[y + 1] = (A[y + 1]) + 1;
1268      * }
1269      */
1270 
1271     // Not self dependent due to different y (with offset).
1272 
1273     MemDependencyChecker analyzer;
1274     StmtPtr stmt = For::make(
1275         y,
1276         0,
1277         10,
1278         Block::make(
1279             {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))}));
1280 
1281     stmt->accept(&analyzer);
1282 
1283     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1284   }
1285 
1286   {
1287     /* for (int x = 0; x < 10; x++) {
1288      *   A[0] = (A[0]) + x;
1289      * }
1290      */
1291 
1292     // Is self dependent since all loops use a common constant element of A.
1293 
1294     MemDependencyChecker analyzer;
1295     StmtPtr stmt = For::make(
1296         x,
1297         0,
1298         10,
1299         Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));
1300     stmt->accept(&analyzer);
1301 
1302     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1303   }
1304 
1305   {
1306     /* for (int x = 0; x < 10; x++) {
1307      *   A[0] = (B[0]) + x;
1308      * }
1309      */
1310 
1311     // Is not self dependent because there is no store to the buffer that is
1312     // read.
1313 
1314     MemDependencyChecker analyzer;
1315     StmtPtr stmt = For::make(
1316         x,
1317         0,
1318         10,
1319         Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))}));
1320     stmt->accept(&analyzer);
1321 
1322     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1323   }
1324 
1325   {
1326     /* for (int x = 0; x < 10; x++) {
1327      *   A[y] = (A[y]) + x;
1328      * }
1329      */
1330 
1331     // Is self dependent since all loops use a common symbolic element of A.
1332 
1333     MemDependencyChecker analyzer;
1334     StmtPtr stmt = For::make(
1335         x,
1336         0,
1337         10,
1338         Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))}));
1339     stmt->accept(&analyzer);
1340 
1341     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1342   }
1343 
1344   {
1345     /* for (int x = 0; x < 10; x++) {
1346      *   A[x] = A[x + 1];
1347      * }
1348      */
1349 
1350     // In this case it depends if we are considering execution order.
1351 
1352     MemDependencyChecker analyzer;
1353 
1354     StmtPtr stmt =
1355         For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1356     stmt->accept(&analyzer);
1357 
1358     // With analysis of order disabled, this is self dependent since the read
1359     // from X+1 and the write to X+1 could be in reverse order.
1360     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1361   }
1362 
1363   {
1364     /* for (int x = 0; x < 10; x++) {
1365      *   A[x] = A[x + 1];
1366      * }
1367      */
1368 
1369     MemDependencyChecker analyzer;
1370     analyzer.allowLoopExecutionOrderAnalysis();
1371 
1372     StmtPtr stmt =
1373         For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1374     stmt->accept(&analyzer);
1375 
1376     // If order analysis is enabled, this is not dependent since the read for
1377     // each element occurs before the write to that element.
1378     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1379   }
1380 
1381   {
1382     /* for (int x = 1; x < 10; x++) {
1383      *   A[x] = A[x - 1];
1384      * }
1385      */
1386 
1387     MemDependencyChecker analyzer;
1388 
1389     StmtPtr stmt =
1390         For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1391     stmt->accept(&analyzer);
1392 
1393     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1394   }
1395 
1396   {
1397     /* for (int x = 1; x < 10; x++) {
1398      *   A[x] = A[x - 1];
1399      * }
1400      */
1401 
1402     MemDependencyChecker analyzer;
1403     analyzer.allowLoopExecutionOrderAnalysis();
1404 
1405     StmtPtr stmt =
1406         For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1407     stmt->accept(&analyzer);
1408 
1409     // In this case, even with order analysis the Load is dependent on the
1410     // Store, since the write to X occurs before the read from X.
1411     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1412   }
1413 
1414   {
1415     /* for (int x = 0; x < 9; x++) {
1416      *   A[9 - x] = A[8 - x];
1417      * }
1418      */
1419 
1420     // Still works if the execution order is reversed, so long as the read
1421     // comes before the write.
1422 
1423     MemDependencyChecker analyzer;
1424     analyzer.allowLoopExecutionOrderAnalysis();
1425 
1426     StmtPtr stmt = For::make(
1427         x,
1428         3,
1429         10,
1430         Store::make(
1431             a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1432     stmt->accept(&analyzer);
1433 
1434     // However here was can determine the A store is earlier in the order than
1435     // the load.
1436     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1437   }
1438 
1439   {
1440     /* for (int x = 0; x < 9; x++) {
1441      *   A[8 - x] = A[9 - x];
1442      * }
1443      */
1444 
1445     // But not if it doesn't.
1446 
1447     MemDependencyChecker analyzer;
1448     analyzer.allowLoopExecutionOrderAnalysis();
1449 
1450     StmtPtr stmt = For::make(
1451         x,
1452         3,
1453         10,
1454         Store::make(
1455             a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x})));
1456     stmt->accept(&analyzer);
1457 
1458     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1459   }
1460 
1461   {
1462     /* for (int x = 0; x < 9; x++) {
1463      *   A[9 - x] = A[8 - x];
1464      * }
1465      */
1466 
1467     // And not if we're not relying on execution order.
1468 
1469     MemDependencyChecker analyzer;
1470 
1471     StmtPtr stmt = For::make(
1472         x,
1473         3,
1474         10,
1475         Store::make(
1476             a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1477     stmt->accept(&analyzer);
1478 
1479     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1480   }
1481 
1482   {
1483     /* for (int x = 3; x < 10; x++) {
1484      *   A[x - 2] = A[x - 1];
1485      * }
1486      */
1487 
1488     // Forward order but negative indices.
1489 
1490     MemDependencyChecker analyzer;
1491     analyzer.allowLoopExecutionOrderAnalysis();
1492 
1493     StmtPtr stmt =
1494         For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1})));
1495     stmt->accept(&analyzer);
1496 
1497     // However here was can determine the A store is earlier in the order than
1498     // the load.
1499     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1500   }
1501 
1502   {
1503     /* for (int x = 0; x < 10; x++) {
1504      *   A[x * 2] = A[x * 2];
1505      * }
1506      */
1507 
1508     // With an access stride.
1509 
1510     MemDependencyChecker analyzer;
1511     // Execution order doesn't matter since the read and the write are totally
1512     // distinct.
1513 
1514     StmtPtr stmt =
1515         For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2})));
1516     stmt->accept(&analyzer);
1517 
1518     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1519   }
1520 
1521   {
1522     /* for (int x = 0; x < 10; x++) {
1523      *   A[x * 2] = A[x * 2 + 1];
1524      * }
1525      */
1526 
1527     // Here we can use the common stride of the accesses to determine they are
1528     // distinct.
1529     // Note, this is the only place (loop self dependency) we use this stride
1530     // to avoid unnecessary dependence.
1531 
1532     MemDependencyChecker analyzer;
1533     // Execution order doesn't matter since the read and the write are totally
1534     // distinct.
1535 
1536     StmtPtr stmt = For::make(
1537         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1})));
1538     stmt->accept(&analyzer);
1539 
1540     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1541   }
1542 
1543   {
1544     /* for (int x = 0; x < 10; x++) {
1545      *   A[x * 2] = A[x * 2 - 1];
1546      * }
1547      */
1548 
1549     // same if the read is behind the write so long as they are distinct.
1550 
1551     MemDependencyChecker analyzer;
1552     StmtPtr stmt = For::make(
1553         x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1})));
1554     stmt->accept(&analyzer);
1555 
1556     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1557   }
1558 
1559   {
1560     /* for (int x = 0; x < 10; x++) {
1561      *   A[x * 2] = A[x * 2 + 2];
1562      * }
1563      */
1564 
1565     // But not if the offset is in the stride.
1566 
1567     MemDependencyChecker analyzer;
1568     StmtPtr stmt = For::make(
1569         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2})));
1570     stmt->accept(&analyzer);
1571 
1572     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1573   }
1574 
1575   {
1576     /* for (int x = 0; x < 10; x++) {
1577      *   A[x * 2] = A[x * 2 - 2];
1578      * }
1579      */
1580 
1581     // Works with negative offsets too.
1582 
1583     MemDependencyChecker analyzer;
1584     StmtPtr stmt = For::make(
1585         x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2})));
1586     stmt->accept(&analyzer);
1587 
1588     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1589   }
1590 
1591   {
1592     /* for (int x = 0; x < 10; x++) {
1593      *   A[x * 2] = A[x * 2 + 7];
1594      * }
1595      */
1596 
1597     // Detects accesses are distinct when offset is large but not a multiple
1598     // of stride.
1599     MemDependencyChecker analyzer;
1600     StmtPtr stmt = For::make(
1601         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7})));
1602     stmt->accept(&analyzer);
1603 
1604     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1605   }
1606 
1607   {
1608     /* for (int x = 0; x < 10; x++) {
1609      *   A[x * 2] = A[x * 2 + 4];
1610      * }
1611      */
1612 
1613     // Works with offsets which are multiples of the stride.
1614     MemDependencyChecker analyzer;
1615     StmtPtr stmt = For::make(
1616         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4})));
1617     stmt->accept(&analyzer);
1618 
1619     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1620   }
1621 
1622   {
1623     /* for (int x = 0; x < 10; x++) {
1624      *   A[x * 6] = A[x * 6 + 5];
1625      * }
1626      */
1627 
1628     // detects accesses are distinct with large strides when the offset is
1629     // within.
1630 
1631     MemDependencyChecker analyzer;
1632     StmtPtr stmt = For::make(
1633         x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5})));
1634     stmt->accept(&analyzer);
1635 
1636     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1637   }
1638 
1639   {
1640     /* for (int x = 0; x < 10; x++) {
1641      *   A[x * 2] = A[x * 6];
1642      * }
1643      */
1644 
1645     // detects accesses are overlapping when stride is different but a
1646     // multiple.
1647 
1648     MemDependencyChecker analyzer;
1649     StmtPtr stmt =
1650         For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6})));
1651     stmt->accept(&analyzer);
1652 
1653     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1654   }
1655 
1656   {
1657     /* for (int x = 0; x < 10; x++) {
1658      *   A[x * 4] = A[x * 2];
1659      * }
1660      */
1661 
1662     // still works when the read axis is the smaller stride.
1663 
1664     MemDependencyChecker analyzer;
1665     StmtPtr stmt =
1666         For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2})));
1667     stmt->accept(&analyzer);
1668 
1669     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1670   }
1671 
1672   {
1673     /* for (int x = 0; x < 10; x++) {
1674      *   A[x * 2] = A[x * 6 + 1];
1675      * }
1676      */
1677 
1678     // detects accesses are distinct when stride is different but a multiple
1679     // and there is an offset.
1680 
1681     MemDependencyChecker analyzer;
1682     StmtPtr stmt = For::make(
1683         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1})));
1684     stmt->accept(&analyzer);
1685 
1686     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1687   }
1688 
1689   {
1690     /* for (int x = 0; x < 10; x++) {
1691      *   A[x * 2] = A[x * 6 + 4];
1692      * }
1693      */
1694 
1695     // The smaller stride determines whether there is overlap.
1696 
1697     MemDependencyChecker analyzer;
1698     StmtPtr stmt = For::make(
1699         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4})));
1700     stmt->accept(&analyzer);
1701 
1702     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1703   }
1704 
1705   {
1706     /* for (int x = 0; x < 10; x++) {
1707      *   A[x * 2 + 3] = A[x * 6];
1708      * }
1709      */
1710 
1711     // The smaller stride determines whether there is overlap, not the larger.
1712 
1713     MemDependencyChecker analyzer;
1714     StmtPtr stmt = For::make(
1715         x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6})));
1716     stmt->accept(&analyzer);
1717 
1718     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1719   }
1720 
1721   {
1722     /* for (int x = 0; x < 10; x++) {
1723      *   A[x * 2] = A[x * 3 + 1];
1724      * }
1725      */
1726 
1727     // If their strides have no common factor > 1, they overlap.
1728     MemDependencyChecker analyzer;
1729     StmtPtr stmt = For::make(
1730         x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1})));
1731     stmt->accept(&analyzer);
1732 
1733     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1734   }
1735 
1736   {
1737     /* for (int x = 0; x < 10; x++) {
1738      *   A[x] = A[x + 10];
1739      * }
1740      */
1741 
1742     // If the offset is greater than the size of the loop, they can't overlap.
1743 
1744     MemDependencyChecker analyzer;
1745     StmtPtr stmt =
1746         For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10})));
1747     stmt->accept(&analyzer);
1748 
1749     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1750   }
1751 
1752   {
1753     /* for (int x = 0; x < 10; x++) {
1754      *   A[x] = A[9 - x];
1755      * }
1756      */
1757 
1758     // If they have different execution orders they may overlap.
1759     MemDependencyChecker analyzer;
1760     StmtPtr stmt = For::make(
1761         x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x})));
1762     stmt->accept(&analyzer);
1763 
1764     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1765   }
1766 
1767   {
1768     /* for (int x = 0; x < 10; x++) {
1769      *   A[x * 2] = A[19 - x * 2];
1770      * }
1771      */
1772 
1773     // Or they may not, depending on their start offset and strides.
1774     MemDependencyChecker analyzer;
1775     StmtPtr stmt = For::make(
1776         x,
1777         0,
1778         10,
1779         Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2})));
1780     stmt->accept(&analyzer);
1781 
1782     ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1783   }
1784 
1785   {
1786     /* for (int x = 0; x < 10; x++) {
1787      *   A[x / 2] = A[x / 2];
1788      * }
1789      */
1790 
1791     // If the stride is not monotonic, they overlap.
1792 
1793     MemDependencyChecker analyzer;
1794     StmtPtr stmt =
1795         For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2})));
1796     stmt->accept(&analyzer);
1797 
1798     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1799   }
1800 
1801   {
1802     /* for (int x = 0; x < 10; x++) {
1803      *   A[x / 2] = A[x / 2 + 1];
1804      * }
1805      */
1806 
1807     // If the stride is not monotonic, they overlap - even with an offset.
1808     MemDependencyChecker analyzer;
1809     StmtPtr stmt = For::make(
1810         x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1})));
1811     stmt->accept(&analyzer);
1812 
1813     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1814   }
1815 
1816   {
1817     /* for (int x = 0; x < 10; x++) {
1818      *   A[x % 2] = A[x % 2];
1819      * }
1820      */
1821 
1822     // Mod too...
1823 
1824     analysis::MemDependencyChecker analyzer;
1825     StmtPtr stmt = For::make(
1826         x,
1827         0,
1828         10,
1829         Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)})));
1830     stmt->accept(&analyzer);
1831 
1832     ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1833   }
1834 
1835   {
1836     /* for (int x = y; x < z; x++) {
1837      *   A[x] = A[x + 1];
1838      * }
1839      */
1840 
1841     // Still works with symbolic loop extents.
1842 
1843     {
1844       MemDependencyChecker analyzer;
1845       StmtPtr stmt =
1846           For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1847       stmt->accept(&analyzer);
1848 
1849       ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1850     }
1851 
1852     {
1853       MemDependencyChecker analyzer;
1854       analyzer.allowLoopExecutionOrderAnalysis();
1855       StmtPtr stmt =
1856           For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1857       stmt->accept(&analyzer);
1858 
1859       ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1860     }
1861   }
1862 }
1863 
1864 // Verify that a strided access still works.
1865 // TODO: actually this only works because of the size of the ranges, revisit
1866 // this test after strided overlap is implemented.
TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
  BufHandle a("A", {20}, kInt);
  BufHandle b("B", {20}, kInt);
  VarHandle x("x", kInt);

  using namespace analysis;

  /* for (int x = 0; x < 10; x++) {
   *   B[x * 2 + 1] = A[x * 2 + 1];
   * }
   * for (int x = 0; x < 10; x++) {
   *   B[x * 2] = A[x * 2];
   * }
   */

  // A is the only input and B the only output of the analyzed program.
  MemDependencyChecker analyzer({a.node()}, {b.node()});
  StmtPtr stmt = Block::make(
      {For::make(
           x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
       For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2})))});
  stmt->accept(&analyzer);

  // Sanity check output depends on input.
  ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));

  // Output has 2 dependencies... the store in each loop.
  auto outputAccess = analyzer.output(b.node());
  // Guard the dereference below so a missing output access is reported as a
  // test failure rather than a crash.
  ASSERT_NE(outputAccess, nullptr);
  ASSERT_EQ(outputAccess->dependencies().size(), 2);
}
1890 
1891 /* TODO(nickg) - this test will fail due to the lack of stride math in Bound
1892 TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1893   BufHandle a("A", {20}, kInt);
1894   BufHandle b("B", {20}, kInt);
1895   BufHandle c("C", {10}, kInt);
1896   VarHandle x("x", kInt);
1897   VarHandle y("y", kInt);
1898 
1899   {
1900     analysis::MemDependencyChecker analyzer({a.node()}, {c.node()});
1901     StmtPtr stmt = Block::make(
1902         {For::make(
1903              x,
1904              0,
1905              10,
1906              Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1907          For::make(
1908              x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))),
1909          For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))
1910 
1911         });
1912     stmt->accept(&analyzer);
1913 
1914     std::cout << *stmt << "\n";
1915     for (auto& wi : analyzer.getHistory()) {
1916       wi->print();
1917     }
1918   }
1919 }*/
1920 
1921 // analysis on Stmts using Cond.
TEST(MemDependency,MemDependencyCheckerLoopBoundsCond)1922 TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) {
1923   BufHandle a("A", {10}, kInt);
1924   BufHandle b("B", {10}, kInt);
1925   BufHandle c("C", {10}, kInt);
1926   VarHandle x("x", kInt);
1927   VarHandle y("y", kInt);
1928 
1929   using namespace analysis;
1930 
1931   {
1932     /* for (int x = 0; x < 10; x++) {
1933      *   C[x] = A[x];
1934      * }
1935      * if (y<5 ? 1 : 0) {
1936      *   C[0] = (B[0]) + 1;
1937      * } else {
1938      *   C[0] = (B[1]) + 1;
1939      * }
1940      */
1941 
1942     // Future usages may depend on accesses in both branches of a condition.
1943 
1944     MemDependencyChecker analyzer({a, b}, {c});
1945     StmtPtr stmt = Block::make(
1946         {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1947          Cond::make(
1948              CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1949              Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)),
1950              Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))});
1951 
1952     stmt->accept(&analyzer);
1953 
1954     // Output C should have 3 dependencies, each of the three stores.
1955     auto outputAccess = analyzer.output(c.node());
1956     ASSERT_NE(outputAccess, nullptr);
1957     ASSERT_EQ(outputAccess->dependencies().size(), 3);
1958 
1959     // C depends indirectly on A and B.
1960     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
1961     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
1962   }
1963 
1964   {
1965     /* for (int x = 0; x < 10; x++) {
1966      *   C[x] = A[x];
1967      * }
1968      * if (y<5 ? 1 : 0) {
1969      *   for (int x = 0; x < 10; x++) {
1970      *     C[x] = B[x];
1971      *   }
1972      * } else {
1973      *   for (int x = 0; x < 10; x++) {
1974      *     C[x] = (B[x]) + 1;
1975      *   }
1976      * }
1977      */
1978 
1979     // Future usages may depend on accesses in both branches of a condition.
1980 
1981     MemDependencyChecker analyzer({a, b}, {c});
1982     StmtPtr stmt = Block::make(
1983         {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1984          Cond::make(
1985              CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1986              For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))),
1987              For::make(
1988                  x,
1989                  0,
1990                  10,
1991                  Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
1992 
1993     stmt->accept(&analyzer);
1994 
1995     // Output C should have 3 dependencies, each of the three stores.
1996     auto outputAccess = analyzer.output(c.node());
1997     ASSERT_NE(outputAccess, nullptr);
1998     ASSERT_EQ(outputAccess->dependencies().size(), 3);
1999 
2000     // TODO(nickg): actually since the true and false branch cover the total
2001     // range of the first store this should have 2 dependencies, but we don't
2002     // do that yet.
2003 
2004     // C depends indirectly on A and B.
2005     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2006     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2007   }
2008 
2009   {
2010     /* for (int x = 0; x < 10; x++) {
2011      *   C[x] = A[x];
2012      * }
2013      * if (y<5 ? 1 : 0) {
2014      *   for (int x = 0; x < 10; x++) {
2015      *     C[x] = (B[x]) + 1;
2016      *   }
2017      * }
2018      */
2019 
2020     // Only has true branch.
2021 
2022     MemDependencyChecker analyzer({a, b}, {c});
2023     StmtPtr stmt = Block::make(
2024         {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2025          Cond::make(
2026              CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2027              For::make(
2028                  x,
2029                  0,
2030                  10,
2031                  Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))),
2032              nullptr)});
2033 
2034     stmt->accept(&analyzer);
2035 
2036     // Output C should have 3 dependencies, each of the three stores.
2037     auto outputAccess = analyzer.output(c.node());
2038     ASSERT_NE(outputAccess, nullptr);
2039     ASSERT_EQ(outputAccess->dependencies().size(), 2);
2040 
2041     // C depends indirectly on A and B.
2042     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2043     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2044   }
2045 
2046   {
2047     /* for (int x = 0; x < 10; x++) {
2048      *   C[x] = A[x];
2049      * }
2050      * if (y<5 ? 1 : 0) {
2051      * } else {
2052      *   for (int x = 0; x < 10; x++) {
2053      *     C[x] = (B[x]) + 1;
2054      *   }
2055      * }
2056      */
2057 
2058     // Only has false branch.
2059 
2060     MemDependencyChecker analyzer({a, b}, {c});
2061     StmtPtr stmt = Block::make(
2062         {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2063          Cond::make(
2064              CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2065              nullptr,
2066              For::make(
2067                  x,
2068                  0,
2069                  10,
2070                  Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
2071 
2072     stmt->accept(&analyzer);
2073 
2074     // Output C should have 3 dependencies, each of the three stores.
2075     auto outputAccess = analyzer.output(c.node());
2076     ASSERT_NE(outputAccess, nullptr);
2077     ASSERT_EQ(outputAccess->dependencies().size(), 2);
2078 
2079     // C depends indirectly on A and B.
2080     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2081     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2082   }
2083 
2084   {
2085     /* for (int x = 0; x < 10; x++) {
2086      *   C[x] = A[x];
2087      * }
2088      * if (C[0]<5 ? 1 : 0) {
2089      *   C[0] = 5;
2090      * }
2091      */
2092 
2093     // Cond's Condition depends on a previous access.
2094 
2095     MemDependencyChecker analyzer({a}, {c});
2096     StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
2097     ExprHandle conditionalLoad = Load::make(c, {0});
2098     StmtPtr stmt = Block::make(
2099         {For::make(x, 0, 10, initStore),
2100          Cond::make(
2101              CompareSelect::make(
2102                  conditionalLoad, 5, CompareSelectOperation::kLT),
2103              Store::make(c, {0}, 5),
2104              nullptr)});
2105 
2106     stmt->accept(&analyzer);
2107 
2108     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2109 
2110     ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore));
2111     ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node()));
2112     ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node()));
2113   }
2114 }
2115 
2116 // Stmts using IfThenElse.
TEST(MemDependency,MemDependencyCheckerIfThenElse)2117 TEST(MemDependency, MemDependencyCheckerIfThenElse) {
2118   BufHandle a("A", {10}, kInt);
2119   BufHandle b("B", {10}, kInt);
2120   BufHandle c("C", {10}, kInt);
2121   VarHandle x("x", kInt);
2122   VarHandle y("y", kInt);
2123 
2124   using namespace analysis;
2125 
2126   {
2127     /* for (int x = 0; x < 10; x++) {
2128      *   C[x] = A[x];
2129      * }
2130      * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1;
2131      */
2132 
2133     // Future usages may depend on accesses in both branches of a condition.
2134 
2135     MemDependencyChecker analyzer({a, b}, {c});
2136     StorePtr ifStore = Store::make(
2137         c,
2138         {0},
2139         IfThenElse::make(
2140             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2141             Add::make(Load::make(b, {0}), 1),
2142             Add::make(Load::make(b, {1}), 1)));
2143     StmtPtr stmt = Block::make(
2144         {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2145          ifStore});
2146 
2147     stmt->accept(&analyzer);
2148 
2149     // Output C should have 2 dependencies, each of the two stores.
2150     auto outputAccess = analyzer.output(c.node());
2151     ASSERT_NE(outputAccess, nullptr);
2152     ASSERT_EQ(outputAccess->dependencies().size(), 2);
2153 
2154     // Now we need to check the Store containing the IfThenElse.
2155     auto ifStoreAccess = analyzer.accessFor(ifStore);
2156 
2157     // It should have 2 dependencies.
2158     ASSERT_EQ(ifStoreAccess->dependencies().size(), 2);
2159 
2160     // C depends indirectly on A and B.
2161     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2162     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2163   }
2164 
2165   {
2166     /* for (int x = 0; x < 10; x++) {
2167      *   C[x] = A[x];
2168      * }
2169      * C[0] = (y < 5 ? (B[0]) + 1 : 42;
2170      */
2171 
2172     // If the load appears in only one side of an IfThenElse the output may be
2173     // dependent on it.
2174 
2175     MemDependencyChecker analyzer({a, b}, {c});
2176     StorePtr ifStore = Store::make(
2177         c,
2178         {0},
2179         IfThenElse::make(
2180             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2181             Add::make(Load::make(b, {0}), 1),
2182             42));
2183     StmtPtr stmt = Block::make(
2184         {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2185          ifStore});
2186 
2187     stmt->accept(&analyzer);
2188 
2189     // C depends indirectly on A and B.
2190     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2191     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2192   }
2193 
2194   {
2195     /* for (int x = 0; x < 10; x++) {
2196      *   C[x] = (x < 5 ? B[x] : A[x];
2197      * }
2198      */
2199 
2200     // In this case C is dependent on both A and B.
2201 
2202     // TODO: in cases like this it would be possible to split the range of B
2203     // into two bounds, one dependent on A and one dependent on B. We'd need to
2204     // examine conditions relative to previously encountered loop variables. I'm
2205     // uncertain if this would be helpful.
2206 
2207     MemDependencyChecker analyzer({a, b}, {c});
2208     StorePtr ifStore = Store::make(
2209         c,
2210         {0},
2211         IfThenElse::make(
2212             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2213             Load::make(b, {x}),
2214             Load::make(a, {x})));
2215     StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)});
2216 
2217     stmt->accept(&analyzer);
2218 
2219     // C depends indirectly on A and B.
2220     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2221     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2222   }
2223 }
2224 
2225 // Cutting a loop with single elem writes
TEST(MemDependency,MemDependencyCheckerCutLoop)2226 TEST(MemDependency, MemDependencyCheckerCutLoop) {
2227   BufHandle a("A", {10}, kInt);
2228   BufHandle b("B", {10}, kInt);
2229   VarHandle x("x", kInt);
2230 
2231   using namespace analysis;
2232 
2233   {
2234     /* for (int x = 0; x < 10; x++) {
2235      *   B[x] = A[x];
2236      * }
2237      * B[5] = 100;
2238      */
2239 
2240     // Cutting a loop with single element writes.
2241 
2242     MemDependencyChecker analyzer({a}, {b});
2243     StmtPtr stmt = Block::make(
2244         {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))),
2245          Store::make(b, {5}, 100)});
2246 
2247     stmt->accept(&analyzer);
2248 
2249     // Output depends on input.
2250     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2251 
2252     // Output has 2 dependencies.
2253     auto outputAccess = analyzer.output(b.node());
2254     ASSERT_NE(outputAccess, nullptr);
2255     ASSERT_EQ(outputAccess->dependencies().size(), 2);
2256   }
2257 
2258   {
2259     /* for (int x = 0; x < 10; x++) {
2260      *   B[x] = A[x];
2261      * }
2262      * for (int x = 4; x < 7; x++) {
2263      *   B[x] = B[x] + 3;
2264      * }
2265      * B[5] = 100;
2266      * B[6] = 101;
2267      * B[7] = 102;
2268      */
2269 
2270     // Cutting a loop with a smaller loop but then totally overlap that second
2271     // loop with one element writes.
2272 
2273     MemDependencyChecker analyzer({a}, {b});
2274     ForPtr firstLoop =
2275         For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})));
2276     StorePtr secondStore =
2277         Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
2278     ForPtr secondLoop = For::make(x, 4, 7, secondStore);
2279 
2280     StmtPtr stmt = Block::make(
2281         {firstLoop,
2282          secondLoop,
2283          Store::make(b, {4}, 100),
2284          Store::make(b, {5}, 101),
2285          Store::make(b, {6}, 102)});
2286 
2287     stmt->accept(&analyzer);
2288 
2289     // Output depends on input.
2290     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2291 
2292     // Output has 4 dependencies.
2293     auto outputAccess = analyzer.output(b.node());
2294     ASSERT_NE(outputAccess, nullptr);
2295     ASSERT_EQ(outputAccess->dependencies().size(), 4);
2296 
2297     // Second loop depends on first loop.
2298     ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop));
2299 
2300     // Output does not depend on second loop or store.
2301     ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop));
2302     ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore));
2303   }
2304 }
2305 
2306 // Dynamic shapes (load in indices).
TEST(MemDependency,MemDependencyCheckerDynamicShapes)2307 TEST(MemDependency, MemDependencyCheckerDynamicShapes) {
2308   BufHandle a("A", {100}, kInt);
2309   BufHandle b("B", {100}, kInt);
2310   BufHandle c("C", {100}, kInt);
2311   VarHandle x("x", kInt);
2312 
2313   using namespace analysis;
2314 
2315   auto CB = [](ExprHandle s, ExprHandle e) {
2316     return Bound(s.node(), e.node());
2317   };
2318 
2319   auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2320     return indexBoundsEquals(x, y);
2321   };
2322 
2323   {
2324     /* for (int x = 0; x < B[0]; x++) {
2325      *   C[x] = A[x];
2326      * }
2327      */
2328     MemDependencyChecker analyzer({a, b}, {c});
2329     StmtPtr stmt = Block::make({For::make(
2330         x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))});
2331 
2332     stmt->accept(&analyzer);
2333 
2334     /*  0. Input: B[(0, 99)] - dependents: 2
2335      *  1. Input: A[(0, 99)] - dependents: 3
2336      *  2. Load: B[(0, 0)] - depends on: 0  - dependents: 3 4
2337      *  3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2  - dependents: 4
2338      *  4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3  - dependents: 5
2339      *  5. Output: C[(0, 99)] - depends on: 4
2340      */
2341 
2342     // Output dependent on A input.
2343     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2344     // Also dependent on B input to determine the size of the region written.
2345     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2346 
2347     auto history = analyzer.getHistory();
2348     ASSERT_EQ(history.size(), 6);
2349 
2350     // The accesses in the loop depend on the load in the stop condition.
2351     ASSERT_TRUE(history[4]->hasDependency(history[2]));
2352     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2353 
2354     // Make a load from B to compare against.
2355     ExprHandle loadFromB = Load::make(b, {0});
2356 
2357     ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)}));
2358     ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)}));
2359   }
2360 
2361   {
2362     /* for (int x = B[0]; x < B[1]; x++) {
2363      *   C[x] = A[x];
2364      * }
2365      */
2366     MemDependencyChecker analyzer({a, b}, {c});
2367     StmtPtr stmt = Block::make({For::make(
2368         x,
2369         Load::make(b, {0}),
2370         Load::make(b, {1}),
2371         Store::make(c, {x}, Load::make(a, {x})))});
2372 
2373     stmt->accept(&analyzer);
2374 
2375     /*  0. Input: B[(0, 99)] - dependents: 2 3
2376      *  1. Input: A[(0, 99)] - dependents: 4
2377      *  2. Load: B[(0, 0)] - depends on: 0  - dependents: 4 5
2378      *  3. Load: B[(1, 1)] - depends on: 0  - dependents: 4 5
2379      *  4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3  - dependents: 5
2380      *  5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4  - dependents: 6
2381      *  6. Output: C[(0, 99)] - depends on: 5
2382      */
2383 
2384     // Sanity check output depends on input.
2385     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2386     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2387 
2388     auto history = analyzer.getHistory();
2389     ASSERT_EQ(history.size(), 7);
2390 
2391     // The accesses in the loop depend on the load in the start condition.
2392     ASSERT_TRUE(history[5]->hasDependency(history[2]));
2393     ASSERT_TRUE(history[4]->hasDependency(history[2]));
2394 
2395     // also the stop condition.
2396     ASSERT_TRUE(history[5]->hasDependency(history[3]));
2397     ASSERT_TRUE(history[4]->hasDependency(history[3]));
2398 
2399     // Make loads from B to compare against.
2400     ExprHandle loadFromB0 = Load::make(b, {0});
2401     ExprHandle loadFromB1 = Load::make(b, {1});
2402     ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2403     ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2404   }
2405 
2406   {
2407     /* for (int x = 0; x < 10; x++) {
2408      *   C[x] = A[B[x]];
2409      * }
2410      */
2411     MemDependencyChecker analyzer({a, b}, {c});
2412     StmtPtr stmt = Block::make({For::make(
2413         x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))});
2414 
2415     stmt->accept(&analyzer);
2416 
2417     /*  0. Input: B[(0, 99)] - dependents: 2
2418      *  1. Input: A[(0, 99)] - dependents: 3
2419      *  2. Load: B[(0, 9)] - depends on: 0  - dependents: 3 4
2420      *  3. Load: A[(B[0], B[9])] - depends on: 1 2  - dependents: 4
2421      *  4. Store: C[(0, 9)] - depends on: 2 3  - dependents: 5
2422      *  5. Output: C[(0, 99)] - depends on: 4
2423      */
2424 
2425     // Sanity check output depends on input.
2426     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2427     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2428 
2429     auto history = analyzer.getHistory();
2430     ASSERT_EQ(history.size(), 6);
2431 
2432     // The store depends on both loads, the load of A depends on the load of B.
2433     ASSERT_TRUE(history[4]->hasDependency(history[2]));
2434     ASSERT_TRUE(history[4]->hasDependency(history[3]));
2435 
2436     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2437 
2438     // The loads in the indices depend on the relevant input buffer.
2439     ASSERT_TRUE(history[3]->hasDependency(history[1]));
2440     ASSERT_TRUE(history[2]->hasDependency(history[0]));
2441 
2442     // The load from B has the loop bounds.
2443     ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2444 
2445     // The load from A has bounds B[0] to B[9].
2446     ExprHandle loadFromB0 = Load::make(b, {0});
2447     ExprHandle loadFromB9 = Load::make(b, {9});
2448     ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)}));
2449   }
2450 
2451   {
2452     /* for (int x = 0; x < 10; x++) {
2453      *   C[B[x]] = A[x];
2454      * }
2455      */
2456     MemDependencyChecker analyzer({a, b}, {c});
2457     StmtPtr stmt = Block::make({For::make(
2458         x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))});
2459 
2460     stmt->accept(&analyzer);
2461 
2462     /*  0. Input: B[(0, 99)] - dependents: 3
2463      *  1. Input: A[(0, 99)] - dependents: 2
2464      *  2. Load: A[(0, 9)] - depends on: 1  - dependents: 4
2465      *  3. Load: B[(0, 9)] - depends on: 0  - dependents: 4
2466      *  4. Store: C[(B[0], B[9])] - depends on: 2 3  - dependents: 5
2467      *  5. Output: C[(0, 99)] - depends on: 4
2468      */
2469     // Sanity check output depends on input.
2470     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2471     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2472 
2473     auto history = analyzer.getHistory();
2474     ASSERT_EQ(history.size(), 6);
2475 
2476     // The store depends on both loads, neither load is dependent.
2477     ASSERT_TRUE(history[4]->hasDependency(history[2]));
2478     ASSERT_TRUE(history[4]->hasDependency(history[3]));
2479 
2480     ASSERT_FALSE(history[3]->hasDependency(history[2]));
2481     ASSERT_FALSE(history[2]->hasDependency(history[3]));
2482 
2483     // The loads each depend on their relevant input. (but accesses are in a
2484     // different order than the last case).
2485     ASSERT_TRUE(history[3]->hasDependency(history[0]));
2486     ASSERT_TRUE(history[2]->hasDependency(history[1]));
2487 
2488     // The load from B has the loop bounds.
2489     ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)}));
2490 
2491     // And so does the load from A.
2492     ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2493   }
2494 
2495   {
2496     /* for (int x = 0; x < 10; x++) {
2497      *   C[B[A[x]]] = x;
2498      * }
2499      */
2500     MemDependencyChecker analyzer({a, b}, {c});
2501     StmtPtr stmt = Block::make({For::make(
2502         x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))});
2503 
2504     stmt->accept(&analyzer);
2505 
2506     /*  0. Input: B[(0, 99)] - dependents: 3
2507      *  1. Input: A[(0, 99)] - dependents: 2
2508      *  2. Load: A[(0, 9)] - depends on: 1  - dependents: 3 4
2509      *  3. Load: B[(A[0], A[9])] - depends on: 0 2  - dependents: 4
2510      *  4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3  - dependents: 5
2511      *  5. Output: C[(0, 99)] - depends on: 4
2512      */
2513 
2514     // Sanity check output depends on input.
2515     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2516     ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2517 
2518     auto history = analyzer.getHistory();
2519     ASSERT_EQ(history.size(), 6);
2520 
2521     // The store depends on both loads.
2522     ASSERT_TRUE(history[4]->hasDependency(history[2]));
2523     ASSERT_TRUE(history[4]->hasDependency(history[3]));
2524 
2525     // The outer load depends on the inner.
2526     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2527 
2528     // The loads each depend on their relevant input. (but accesses are in a
2529     // different order than the last case).
2530     ASSERT_TRUE(history[3]->hasDependency(history[0]));
2531     ASSERT_TRUE(history[2]->hasDependency(history[1]));
2532 
2533     // The load from A has the loop bounds.
2534     ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2535     // The load from B as bounds A[0] to A[9].
2536     ExprHandle loadFromA0 = Load::make(a, {0});
2537     ExprHandle loadFromA9 = Load::make(a, {9});
2538     ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)}));
2539 
2540     // The store has bounds of B[A[0]] to B[A[9]].
2541     ExprHandle loadFromBA0 = Load::make(b, {loadFromA0});
2542     ExprHandle loadFromBA9 = Load::make(b, {loadFromA9});
2543     ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)}));
2544   }
2545 }
2546 
2547 // Verify multi dimensional bounds work.
TEST(MemDependency,MemDependencyCheckerMultiDim)2548 TEST(MemDependency, MemDependencyCheckerMultiDim) {
2549   int M = 10, N = 9, K = 12;
2550   BufHandle a("A", {M, N, K}, kInt);
2551   BufHandle b("B", {M, N, K}, kInt);
2552   BufHandle c("C", {M, K}, kInt);
2553   VarHandle x("x", kInt);
2554   VarHandle y("y", kInt);
2555   VarHandle z("z", kInt);
2556 
2557   using namespace analysis;
2558 
2559   auto CB = [](ExprHandle s, ExprHandle e) {
2560     return Bound(s.node(), e.node());
2561   };
2562 
2563   auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2564     return indexBoundsEquals(x, y);
2565   };
2566 
2567   {
2568     /* for (int x = 0; x < 10; x++) {
2569      *   for (int y = 0; y < 9; y++) {
2570      *     for (int z = 0; z < 12; z++) {
2571      *       B[x, y, z] = A[x, y, z];
2572      *     }
2573      *   }
2574      * }
2575      */
2576     // Full range.
2577 
2578     MemDependencyChecker analyzer({a}, {b});
2579     StmtPtr stmt = Block::make({For::make(
2580         x,
2581         0,
2582         M,
2583         For::make(
2584             y,
2585             0,
2586             N,
2587             For::make(
2588                 z,
2589                 0,
2590                 K,
2591                 Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2592 
2593     stmt->accept(&analyzer);
2594 
2595     // Sanity test: Output depends on input.
2596     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2597 
2598     // 4 accesses: input, load, store, output.
2599     auto history = analyzer.getHistory();
2600     ASSERT_EQ(history.size(), 4);
2601 
2602     // Simple chain from input to output.
2603     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2604     ASSERT_TRUE(history[2]->hasDependency(history[1]));
2605     ASSERT_TRUE(history[1]->hasDependency(history[0]));
2606 
2607     ASSERT_TRUE(
2608         EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2609     ASSERT_TRUE(
2610         EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2611   }
2612 
2613   {
2614     /* for (int x = 0; x < 5; x++) {
2615      *   for (int y = 0; y < 5; y++) {
2616      *     for (int z = 0; z < 5; z++) {
2617      *       B[x, y, z] = A[x, y, z];
2618      *     }
2619      *   }
2620      * }
2621      */
2622     // Partial range.
2623 
2624     MemDependencyChecker analyzer({a}, {b});
2625     StmtPtr stmt = Block::make({For::make(
2626         x,
2627         0,
2628         5,
2629         For::make(
2630             y,
2631             0,
2632             5,
2633             For::make(
2634                 z,
2635                 0,
2636                 5,
2637                 Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2638 
2639     stmt->accept(&analyzer);
2640 
2641     // Sanity test: Output depends on input.
2642     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2643 
2644     // 4 accesses: input, load, store, output.
2645     auto history = analyzer.getHistory();
2646     ASSERT_EQ(history.size(), 4);
2647 
2648     // Simple chain from input to output.
2649     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2650     ASSERT_TRUE(history[2]->hasDependency(history[1]));
2651     ASSERT_TRUE(history[1]->hasDependency(history[0]));
2652 
2653     ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2654     ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2655   }
2656 
2657   {
2658     /* for (int x = 0; x < 10; x++) {
2659      *   for (int y = 0; y < 12; y++) {
2660      *     B[x, 0, y] = A[x, 0, y];
2661      *   }
2662      * }
2663      */
2664 
2665     // Partial loops.
2666 
2667     MemDependencyChecker analyzer({a}, {b});
2668     StmtPtr stmt = Block::make({For::make(
2669         x,
2670         0,
2671         N,
2672         For::make(
2673             y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))});
2674 
2675     stmt->accept(&analyzer);
2676 
2677     // Sanity test: Output depends on input.
2678     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2679 
2680     // 4 accesses: input, load, store, output.
2681     auto history = analyzer.getHistory();
2682     ASSERT_EQ(history.size(), 4);
2683 
2684     // Simple chain from input to output.
2685     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2686     ASSERT_TRUE(history[2]->hasDependency(history[1]));
2687     ASSERT_TRUE(history[1]->hasDependency(history[0]));
2688 
2689     ASSERT_TRUE(
2690         EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2691     ASSERT_TRUE(
2692         EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2693   }
2694 
2695   {
2696     /* for (int x = 0; x < 10; x++) {
2697      *   for (int y = 0; y < 100; y++) {
2698      *     for (int z = 0; z < 12; z++) {
2699      *       B[x, 0, z] = (A[x, 0, z]) + (C[x, z]);
2700      *     }
2701      *   }
2702      * }
2703      */
2704 
2705     // Loops that don't correspond to an index, bufs with different
2706     // dimensionality.
2707 
2708     MemDependencyChecker analyzer({a, c}, {b});
2709     StmtPtr stmt = Block::make({For::make(
2710         x,
2711         0,
2712         M,
2713         For::make(
2714             y,
2715             0,
2716             100,
2717             For::make(
2718                 z,
2719                 0,
2720                 K,
2721                 Store::make(
2722                     b,
2723                     {x, 0, z},
2724                     Add::make(
2725                         Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))});
2726 
2727     stmt->accept(&analyzer);
2728 
2729     // Sanity test: Output depends on both inputs.
2730     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2731     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node()));
2732 
2733     // 6 accesses: 2 inputs, 2 loads, store, output.
2734     auto history = analyzer.getHistory();
2735     ASSERT_EQ(history.size(), 6);
2736 
2737     // Simple chain from input to output over the A buf.
2738     // history[0] is the C input, history[3] is the load from C.
2739     ASSERT_TRUE(history[5]->hasDependency(history[4]));
2740     ASSERT_TRUE(history[4]->hasDependency(history[2]));
2741     ASSERT_TRUE(history[2]->hasDependency(history[1]));
2742     // The store also depends on the load from the C input.
2743     ASSERT_TRUE(history[4]->hasDependency(history[3]));
2744     ASSERT_TRUE(history[3]->hasDependency(history[0]));
2745 
2746     // A Buf accesses.
2747     ASSERT_TRUE(
2748         EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2749     ASSERT_TRUE(
2750         EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2751 
2752     // C buf access.
2753     ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)}));
2754   }
2755 
2756   {
2757     /* for (int x = 0; x < 9; x++) {
2758      *   for (int y = 0; y < 10; y++) {
2759      *     for (int z = 0; z < 12; z++) {
2760      *       B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]);
2761      *     }
2762      *   }
2763      * }
2764      */
2765     // Multi-dim reductions.
2766 
2767     MemDependencyChecker analyzer({a}, {b});
2768     StmtPtr stmt = Block::make({For::make(
2769         x,
2770         0,
2771         M,
2772         For::make(
2773             y,
2774             0,
2775             N,
2776             For::make(
2777                 z,
2778                 0,
2779                 K,
2780                 Store::make(
2781                     b,
2782                     {x, 0, 0},
2783                     Add::make(
2784                         Load::make(b, {x, y, z}),
2785                         Load::make(a, {x, y, z}))))))});
2786 
2787     stmt->accept(&analyzer);
2788 
2789     // Sanity test: Output depends on input.
2790     ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2791 
2792     // 4 accesses: input, 2 loads, store, output.
2793     auto history = analyzer.getHistory();
2794     ASSERT_EQ(history.size(), 5);
2795 
2796     // Simple chain from input to output.
2797     ASSERT_TRUE(history[4]->hasDependency(history[3]));
2798     ASSERT_TRUE(history[3]->hasDependency(history[2]));
2799     ASSERT_TRUE(history[3]->hasDependency(history[1]));
2800     ASSERT_TRUE(history[2]->hasDependency(history[0]));
2801 
2802     // The load from B depends on the store to B.
2803     ASSERT_TRUE(history[1]->hasDependency(history[3]));
2804 
2805     ASSERT_TRUE(
2806         EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2807     ASSERT_TRUE(
2808         EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2809     ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)}));
2810   }
2811 }
2812 
2813 // Various tests using the external Compute/Reduce API.
// Checks that two loop nests produced by Compute() are recognised as
// dependent when the second one reads the tensor written by the first.
TEST(MemDependency, MemDependencyCheckerComputeAPI) {
  using namespace analysis;

  // Shape of the IR under test:
  //
  //   for (int m = 0; m < 4; m++)
  //     for (int n = 0; n < 5; n++)
  //       for (int k = 0; k < 6; k++)
  //         broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]);
  //
  //   for (int m_1 = 0; m_1 < 4; m_1++)
  //     for (int n_1 = 0; n_1 < 5; n_1++)
  //       for (int k_1 = 0; k_1 < 6; k_1++)
  //         d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1);

  BufHandle lhs_in("a", {4, 5}, kFloat);
  BufHandle rhs_in("b", {5, 6}, kFloat);

  // Producer stage: broadcast add of the two inputs.
  Tensor producer = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return lhs_in.load(i, j) + rhs_in.load(j, p);
      });

  // Consumer stage: reads every element of the producer.
  Tensor consumer = Compute(
      "d",
      {4, 5, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return producer.load(i, j, p) + 1;
      });

  LoopNest nest({consumer}, {producer, consumer});

  MemDependencyChecker checker(
      {lhs_in.node(), rhs_in.node()}, {consumer.buf()});
  nest.root_stmt()->accept(&checker);

  // The output must be reachable from each input.
  ASSERT_TRUE(checker.dependsIndirectly(consumer.buf(), lhs_in.node()));
  ASSERT_TRUE(checker.dependsIndirectly(consumer.buf(), rhs_in.node()));

  // The outermost consumer loop depends directly on the outermost producer
  // loop.
  auto producer_loop = nest.getLoopStmtsFor(producer)[0];
  auto consumer_loop = nest.getLoopStmtsFor(consumer)[0];
  ASSERT_TRUE(checker.dependsDirectly(consumer_loop, producer_loop));
}
2864 
// Checks that inlining an intermediate tensor removes every access to it
// from the checker's history.
TEST(MemDependency, MemDependencyCheckerComputeInline) {
  using namespace analysis;

  // IR after inlining `broadcast_add` into `d`:
  //
  //   for (int m = 0; m < 4; m++)
  //     for (int n = 0; n < 5; n++)
  //       for (int k = 0; k < 6; k++)
  //         d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1);

  BufHandle lhs_in("a", {4, 5}, kFloat);
  BufHandle rhs_in("b", {5, 6}, kFloat);

  Tensor intermediate = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return lhs_in.load(i, j) + rhs_in.load(j, p);
      });
  Tensor result = Compute(
      "d",
      {4, 5, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return intermediate.load(i, j, p) + 1;
      });

  LoopNest nest({result}, {intermediate, result});
  nest.computeInline(intermediate.buf());

  MemDependencyChecker checker({lhs_in.node(), rhs_in.node()}, {result.buf()});
  nest.root_stmt()->accept(&checker);

  // The output must still be reachable from each input.
  ASSERT_TRUE(checker.dependsIndirectly(result.buf(), lhs_in.node()));
  ASSERT_TRUE(checker.dependsIndirectly(result.buf(), rhs_in.node()));

  // No recorded access may refer to the inlined broadcast_add buffer.
  const auto accesses = checker.getHistory();
  for (const auto& access : accesses) {
    ASSERT_NE(access->var(), intermediate.buf()->base_handle());
  }
}
2909 
// Splits a loop axis (so loop count != dimension count) and checks that the
// recorded accesses are the same as before the split.
TEST(MemDependency, MemDependencyCheckerComputeSplit) {
  using namespace analysis;

  BufHandle lhs_in("a", {4, 5}, kFloat);
  BufHandle rhs_in("b", {5, 6}, kFloat);
  Tensor sum = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return lhs_in.load(i, j) + rhs_in.load(j, p);
      });

  LoopNest nest({sum});

  // Baseline analysis on the untouched loop nest.
  MemDependencyChecker pre_split({lhs_in.node(), rhs_in.node()}, {sum.buf()});
  nest.root_stmt()->accept(&pre_split);

  // Split the outermost loop by a factor of 2, keeping a tail loop.
  auto outermost = nest.getLoopStmtsFor(sum)[0];
  nest.splitWithTail(outermost, 2);

  // Re-analyse the simplified, transformed nest with a fresh checker.
  MemDependencyChecker post_split({lhs_in.node(), rhs_in.node()}, {sum.buf()});
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
  simplified->accept(&post_split);

  // Splitting should not change accesses at all.
  auto before = pre_split.getHistory();
  auto after = post_split.getHistory();

  ASSERT_EQ(before.size(), after.size());

  for (size_t idx = 0; idx < before.size(); ++idx) {
    const auto& old_access = before[idx];
    const auto& new_access = after[idx];

    // Same kind of access on the same variable...
    ASSERT_EQ(old_access->type(), new_access->type());
    ASSERT_EQ(old_access->var(), new_access->var());

    // ...over the same region...
    ASSERT_EQ(old_access->bounds().size(), new_access->bounds().size());
    ASSERT_TRUE(indexBoundsEquals(old_access->bounds(), new_access->bounds()));

    // ...with the same number of edges in the dependency graph.
    ASSERT_EQ(
        old_access->dependencies().size(), new_access->dependencies().size());
    ASSERT_EQ(old_access->dependents().size(), new_access->dependents().size());
  }
}
2955 
// Reorders two loop axes (so loop order != indexing order) and checks that
// the recorded accesses are the same as before the reorder.
TEST(MemDependency, MemDependencyCheckerComputeReorder) {
  using namespace analysis;

  BufHandle lhs_in("a", {4, 5}, kFloat);
  BufHandle rhs_in("b", {5, 6}, kFloat);
  Tensor sum = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return lhs_in.load(i, j) + rhs_in.load(j, p);
      });

  LoopNest nest({sum});

  // Baseline analysis on the untouched loop nest.
  MemDependencyChecker pre_reorder({lhs_in.node(), rhs_in.node()}, {sum.buf()});
  nest.root_stmt()->accept(&pre_reorder);

  // Swap the two outermost loops.
  auto sum_loops = nest.getLoopStmtsFor(sum);
  nest.reorderAxis(sum_loops[0], sum_loops[1]);

  // Re-analyse the simplified, transformed nest with a fresh checker.
  MemDependencyChecker post_reorder(
      {lhs_in.node(), rhs_in.node()}, {sum.buf()});
  StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
  simplified->accept(&post_reorder);

  // Reordering should not change accesses at all.
  auto before = pre_reorder.getHistory();
  auto after = post_reorder.getHistory();

  ASSERT_EQ(before.size(), after.size());

  for (size_t idx = 0; idx < before.size(); ++idx) {
    const auto& old_access = before[idx];
    const auto& new_access = after[idx];

    // Same kind of access on the same variable...
    ASSERT_EQ(old_access->type(), new_access->type());
    ASSERT_EQ(old_access->var(), new_access->var());

    // ...over the same region...
    ASSERT_EQ(old_access->bounds().size(), new_access->bounds().size());
    ASSERT_TRUE(indexBoundsEquals(old_access->bounds(), new_access->bounds()));

    // ...with the same number of edges in the dependency graph.
    ASSERT_EQ(
        old_access->dependencies().size(), new_access->dependencies().size());
    ASSERT_EQ(old_access->dependents().size(), new_access->dependents().size());
  }
}
3002 
// Checks that dependencies are found through a Reduce() stage that consumes
// a Compute() stage.
TEST(MemDependency, MemDependencyCheckerComputeReduce) {
  using namespace analysis;

  // Shape of the IR under test:
  //
  //   for (int l2 = 0; l2 < 2; l2++)
  //     for (int n1 = 0; n1 < 3; n1++)
  //       for (int m1 = 0; m1 < 6; m1++)
  //         scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]);
  //
  //   for (int l1 = 0; l1 < 2; l1++) {
  //     sum[l1] = float(0);
  //     for (int n1_1 = 0; n1_1 < 3; n1_1++)
  //       for (int m1_1 = 0; m1_1 < 6; m1_1++)
  //         sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)),
  //                      out_args={l1}, reduce_args={n1, m1});
  //   }

  BufHandle a("a", {2, 3, 6}, kFloat);
  BufHandle b("b", {2, 3, 6}, kFloat);

  // Element-wise product of the two inputs.
  Tensor scaled = Compute(
      "scale",
      {2, 3, 6},
      [&](const VarHandle& i, const VarHandle& j, const VarHandle& p) {
        return b.load(i, j, p) * a.load(i, j, p);
      });
  // Sum-reduce the two trailing dimensions of the product.
  Tensor total = Reduce("sum", {2}, Sum(), scaled, {3, 6});

  LoopNest nest({total}, {scaled, total});

  MemDependencyChecker checker({a.node(), b.node()}, {total.buf()});
  nest.root_stmt()->accept(&checker);

  // The output must be reachable from each input.
  ASSERT_TRUE(checker.dependsIndirectly(total.buf(), a.node()));
  ASSERT_TRUE(checker.dependsIndirectly(total.buf(), b.node()));

  // The outermost reduction loop depends directly on the outermost loop of
  // the product.
  auto scaled_loop = nest.getLoopStmtsFor(scaled)[0];
  auto total_loop = nest.getLoopStmtsFor(total)[0];
  ASSERT_TRUE(checker.dependsDirectly(total_loop, scaled_loop));

  // The first ReduceOp found in the IR depends on both inputs.
  auto reduce_ops = NodeFinder<ReduceOp>::find(nest.root_stmt());
  ASSERT_TRUE(checker.dependsIndirectly(reduce_ops[0], a.node()));
  ASSERT_TRUE(checker.dependsIndirectly(reduce_ops[0], b.node()));
}
3055 
TEST(MemDependency,MemDependencyCheckerComputeGEMM)3056 TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
3057   int M = 1024;
3058   int N = 1024;
3059   int K = 2048;
3060   using namespace analysis;
3061 
3062   BufHandle AP("A", {M, K}, kFloat);
3063   BufHandle BP("B", {K, N}, kFloat);
3064   Tensor CT = Reduce(
3065       "gemm",
3066       {M, N},
3067       Sum(),
3068       [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
3069         return AP.load(m, k) * BP.load(k, n);
3070       },
3071       {K});
3072   LoopNest loop({CT});
3073 
3074   {
3075     auto const& loops = loop.getLoopStmtsFor(CT);
3076     ForPtr m = loops[0];
3077     loop.splitWithMask(m, 4);
3078   }
3079   {
3080     auto const& loops = loop.getLoopStmtsFor(CT);
3081     ForPtr n = loops[2];
3082     loop.splitWithMask(n, 16);
3083   }
3084   // mo, mi, no, ni, k ->
3085   // mo, no, mi, ni, k
3086   {
3087     auto const& loops = loop.getLoopStmtsFor(CT);
3088     ForPtr mi = loops[1];
3089     ForPtr no = loops[2];
3090     loop.reorderAxis(mi, no);
3091   }
3092   // mo, no, mi, ni, k ->
3093   // mo, no, mi, k, ni
3094   {
3095     auto const& loops = loop.getLoopStmtsFor(CT);
3096     ForPtr ni = loops[3];
3097     ForPtr k = loops[4];
3098     loop.reorderAxis(ni, k);
3099   }
3100   // mo, no, mi, k, ni ->
3101   // mo, no, k, mi, ni
3102   {
3103     auto const& loops = loop.getLoopStmtsFor(CT);
3104     ForPtr mi = loops[2];
3105     ForPtr k = loops[3];
3106     loop.reorderAxis(mi, k);
3107   }
3108   {
3109     auto const& loops = loop.getLoopStmtsFor(CT);
3110     loop.cacheAccesses(CT.buf(), "C_regs", loops[2]);
3111   }
3112 
3113   MemDependencyChecker analyzer_unlowered(
3114       loop.getInputBufs(), loop.getOutputBufs());
3115 
3116   MemDependencyChecker analyzer_lowered(
3117       loop.getInputBufs(), loop.getOutputBufs());
3118 
3119   // Test both unlowered and lowered form.
3120   {
3121     StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
3122     stmt->accept(&analyzer_unlowered);
3123 
3124     // Outputs depend on inputs.
3125     ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node()));
3126     ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node()));
3127 
3128     // The last write to gemm should cover the total bound of the output.
3129     std::shared_ptr<AccessInfo> outputAccess =
3130         analyzer_unlowered.output(CT.buf());
3131     // A single dependency.
3132     ASSERT_EQ(outputAccess->dependencies().size(), 1);
3133 
3134     // dependencies is a set with 1 element, so can just deref begin().
3135     std::shared_ptr<AccessInfo> gemmStore =
3136         outputAccess->dependencies().begin()->second;
3137     // Check its a store.
3138     ASSERT_EQ(gemmStore->type(), AccessType::Store);
3139 
3140     ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds()));
3141 
3142     // Likewise the first read from each input cover the entire range of the
3143     // input.
3144     auto aInput = analyzer_unlowered.input(AP.node());
3145     auto bInput = analyzer_unlowered.input(BP.node());
3146 
3147     // A single dependent each.
3148     ASSERT_EQ(aInput->dependents().size(), 1);
3149     ASSERT_EQ(bInput->dependents().size(), 1);
3150 
3151     // They're both loads.
3152     std::shared_ptr<AccessInfo> aLoad = aInput->dependents().begin()->second;
3153     std::shared_ptr<AccessInfo> bLoad = bInput->dependents().begin()->second;
3154     ASSERT_EQ(aLoad->type(), AccessType::Load);
3155     ASSERT_EQ(bLoad->type(), AccessType::Load);
3156 
3157     ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds()));
3158     ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds()));
3159   }
3160 
3161   loop.prepareForCodegen();
3162   SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT});
3163 
3164   // now check lowered dependency graph.
3165   {
3166     StmtPtr stmt = IRSimplifier::simplify(cg.stmt());
3167     stmt->accept(&analyzer_lowered);
3168 
3169     // Lowering will change the dimensionality of all bounds due to index
3170     // flattening and will insert Allocates and Frees.
3171 
3172     auto history_before = analyzer_unlowered.getHistory();
3173     auto history_after = analyzer_lowered.getHistory();
3174 
3175     ASSERT_EQ(history_before.size() + 2, history_after.size());
3176 
3177     // Filter out the alloc/free;
3178     auto isAllocFree = [](const auto& info) {
3179       return info->type() == AccessType::Alloc ||
3180           info->type() == AccessType::Free;
3181     };
3182     history_after.erase(
3183         std::remove_if(history_after.begin(), history_after.end(), isAllocFree),
3184         history_after.end());
3185 
3186     ASSERT_EQ(history_before.size(), history_after.size());
3187 
3188     for (size_t i = 0; i < history_before.size(); ++i) {
3189       ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
3190       ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
3191 
3192       if (history_before[i]->dependencies().size() !=
3193           history_after[i]->dependencies().size()) {
3194         // Must depend on an Alloc.
3195         ASSERT_TRUE(std::any_of(
3196             history_after[i]->dependencies().begin(),
3197             history_after[i]->dependencies().end(),
3198             [](const auto& pair) {
3199               return pair.second->type() == AccessType::Alloc;
3200             }));
3201 
3202         ASSERT_EQ(
3203             history_before[i]->dependencies().size() + 1,
3204             history_after[i]->dependencies().size());
3205       }
3206 
3207       if (history_before[i]->dependents().size() !=
3208           history_after[i]->dependents().size()) {
3209         // Must depend on an Free.
3210         ASSERT_TRUE(std::any_of(
3211             history_after[i]->dependents().begin(),
3212             history_after[i]->dependents().end(),
3213             [](const auto& pair) {
3214               return pair.second->type() == AccessType::Free;
3215             }));
3216 
3217         ASSERT_EQ(
3218             history_before[i]->dependents().size() + 1,
3219             history_after[i]->dependents().size());
3220       }
3221 
3222       // Inputs and outputs are not flattened, only accesses.
3223       if (history_before[i]->type() == AccessType::Input ||
3224           history_before[i]->type() == AccessType::Output) {
3225         ASSERT_EQ(
3226             history_before[i]->bounds().size(),
3227             history_after[i]->bounds().size());
3228         ASSERT_TRUE(indexBoundsEquals(
3229             history_before[i]->bounds(), history_after[i]->bounds()));
3230       } else {
3231         ASSERT_EQ(history_after[i]->bounds().size(), 1);
3232         ExprPtr flat_bounds = alloc<IntImm>(1);
3233 
3234         for (auto& b : history_before[i]->bounds()) {
3235           flat_bounds =
3236               alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1)));
3237 
3238           // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3239           ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start));
3240         }
3241 
3242         flat_bounds = IRSimplifier::simplify(flat_bounds);
3243         ExprPtr after_bounds = IRSimplifier::simplify(
3244             alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1)));
3245         ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
3246       }
3247     }
3248   }
3249 }
3250 
3251 } // namespace jit
3252 } // namespace torch
3253