xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_memory_dag.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/passes/utils/memory_dag.h>
5 
6 namespace torch {
7 namespace jit {
8 
TEST(MemoryDAGTest,Basic)9 TEST(MemoryDAGTest, Basic) {
10   auto graph = std::make_shared<Graph>();
11   const Value* aValue = graph->addInput();
12   const Value* bValue = graph->addInput();
13   const Value* cValue = graph->addInput();
14   const Value* dValue = graph->addInput();
15   const Value* eValue = graph->addInput();
16   const Value* fValue = graph->addInput();
17   const Value* gValue = graph->addInput();
18 
19   {
20     // a <- b <- c
21     //      b <- d
22     // a <- e
23     // f <- e
24     // g is by itself
25     auto t = std::make_unique<MemoryDAGBuilder>();
26     auto a = t->makeFreshValue(aValue);
27     auto b = t->makeFreshValue(bValue);
28     auto c = t->makeFreshValue(cValue);
29     auto d = t->makeFreshValue(dValue);
30     auto e = t->makeFreshValue(eValue);
31     auto f = t->makeFreshValue(fValue);
32     auto g = t->makeFreshValue(gValue);
33     t->makePointerTo(b, a);
34     t->makePointerTo(c, b);
35     t->makePointerTo(d, b);
36     t->makePointerTo(e, a);
37     t->makePointerTo(e, f);
38 
39     auto dag = std::move(*t).createMemoryDAG();
40 
41     /**
42      * Test mayAlias()
43      */
44     // Values should alias themselves
45     EXPECT_TRUE(dag->mayAlias(a, a));
46     EXPECT_TRUE(dag->mayAlias(g, g));
47 
48     // Values that point to the same location should alias
49     EXPECT_TRUE(dag->mayAlias(a, b));
50     EXPECT_TRUE(dag->mayAlias(a, c));
51     EXPECT_TRUE(dag->mayAlias(c, d));
52 
53     // e may point to a OR f
54     EXPECT_TRUE(dag->mayAlias(e, a));
55     EXPECT_TRUE(dag->mayAlias(e, f));
56     // But a and f don't alias
57     EXPECT_FALSE(dag->mayAlias(a, f));
58   }
59   {
60     // x(y) -> x contains y
61 
62     // b(a)
63     // c(a)
64     auto t = std::make_unique<MemoryDAGBuilder>();
65     auto a = t->makeFreshValue(aValue);
66     auto b = t->makeFreshValue(bValue);
67     t->addToContainedElements(a, b);
68 
69     auto c = t->makeFreshValue(cValue);
70     t->addToContainedElements(a, c);
71 
72     auto dag = std::move(*t).createMemoryDAG();
73     EXPECT_TRUE(dag->mayContainAlias(a, b));
74     EXPECT_TRUE(dag->mayContainAlias(b, a));
75 
76     EXPECT_TRUE(dag->mayContainAlias(a, c));
77     EXPECT_TRUE(dag->mayContainAlias(c, a));
78 
79     EXPECT_TRUE(dag->mayContainAlias(b, c));
80     EXPECT_TRUE(dag->mayContainAlias(c, b));
81 
82     // containers contain an element in themselves
83     EXPECT_TRUE(dag->mayContainAlias(b, b));
84     EXPECT_TRUE(dag->mayContainAlias(c, c));
85     EXPECT_TRUE(dag->mayContainAlias(a, a));
86   }
87   {
88     // b(a)
89     // c(a)
90     // d(b(a))
91     auto t = std::make_unique<MemoryDAGBuilder>();
92     auto a = t->makeFreshValue(aValue);
93     auto b = t->makeFreshValue(bValue);
94     t->addToContainedElements(a, b);
95 
96     auto c = t->makeFreshValue(cValue);
97     t->addToContainedElements(a, c);
98 
99     auto d = t->makeFreshValue(dValue);
100     t->addToContainedElements(b, d);
101 
102     auto dag = std::move(*t).createMemoryDAG();
103     EXPECT_TRUE(dag->mayContainAlias(b, d));
104     EXPECT_TRUE(dag->mayContainAlias(d, b));
105 
106     EXPECT_TRUE(dag->mayContainAlias(c, d));
107     EXPECT_TRUE(dag->mayContainAlias(d, c));
108 
109     EXPECT_TRUE(dag->mayContainAlias(a, d));
110   }
111   {
112     // f(e)
113     auto t = std::make_unique<MemoryDAGBuilder>();
114     auto a = t->makeFreshValue(aValue);
115     auto b = t->makeFreshValue(bValue);
116     t->addToContainedElements(a, b);
117 
118     auto c = t->makeFreshValue(cValue);
119     t->addToContainedElements(a, c);
120 
121     auto d = t->makeFreshValue(dValue);
122     t->addToContainedElements(b, d);
123 
124     auto f = t->makeFreshValue(aValue);
125     auto e = t->makeFreshValue(bValue);
126 
127     t->addToContainedElements(f, e);
128 
129     auto dag = std::move(*t).createMemoryDAG();
130     for (auto elem : {a, b, c, d}) {
131       EXPECT_FALSE(dag->mayContainAlias(f, elem));
132       EXPECT_FALSE(dag->mayContainAlias(e, elem));
133     }
134   }
135 }
136 
137 } // namespace jit
138 } // namespace torch
139