// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_ir.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>

6 namespace torch {
7 namespace jit {
8 
// Exercises the Node attribute API: chained setters, typed getters,
// overwriting an attribute with a different type, presence checks,
// list-valued attributes, and copyAttributes() deep-copy semantics.
TEST(IRTest, Attributes) {
  Graph g;
  auto one = attr::alpha;
  auto two = attr::device;
  auto three = attr::end;
  auto four = attr::perm;
  Node* n = g.create(Symbol::fromQualString("foo::bar"));
  Node& attr = *n;
  // Setters return the node, so calls chain: float, int, string attributes.
  attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
  ASSERT_EQ(attr.f(one), 3.4);
  ASSERT_EQ(attr.s(three), "what");
  ASSERT_EQ(attr.i(two), 5);
  // Re-assigning the same name with a different type replaces the attribute.
  attr.s_(one, "no");
  ASSERT_EQ(attr.s(one), "no");
  ASSERT_TRUE(attr.hasAttribute(three));
  ASSERT_TRUE(!attr.hasAttribute(four));
  // String-list attribute.
  attr.ss_(two, {"hi", "now"});
  ASSERT_EQ(attr.ss(two).at(1), "now");

  // copyAttributes() must deep-copy: mutating the copy afterwards may not
  // affect the source node.
  Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
  Node& attr2 = *n2;
  attr2.copyAttributes(attr);
  ASSERT_EQ(attr2.s(one), "no");
  attr2.f_(one, 5);
  ASSERT_EQ(attr.s(one), "no");
  ASSERT_EQ(attr2.f(one), 5);
}
36 
// Parses a graph containing a prim::If with two blocks, then verifies
// IR structure via FileCheck, erases one block, and checks that a
// recursive Graph::copy() preserves the (now single-block) structure.
TEST(IRTest, Blocks) {
  auto g = std::make_shared<Graph>();
  const auto graph_string = R"IR(
    graph(%a : Tensor,
          %b : Tensor,
          %c : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::add(%a, %b, %2)
      %5 : Tensor = prim::If(%c)
        block0():
          %6 : int = prim::Constant[value=1]()
          %7 : Tensor = aten::add(%3, %3, %6)
          -> (%7)
        block1():
          %8 : int = prim::Constant[value=1]()
          %9 : Tensor = aten::add(%b, %3, %8)
          %10 : int = prim::Constant[value=1]()
          %11 : Tensor = aten::add(%9, %3, %10)
          -> (%11)
      %12 : int = prim::Constant[value=1]()
      %13 : Tensor = aten::add(%5, %3, %12)
      return (%13))IR";
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  // Both blocks present: 3 aten::add nodes after the first one matched.
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check("aten::add")
      ->check("block1")
      ->check_count("aten::add", 3)
      ->run(*g);

  // Removes block0 of the conditional.
  for (auto* node : g->block()->nodes()) {
    if (node->kind() == prim::If) {
      node->eraseBlock(0);
      break;
    }
  }

  // The former block1 is renumbered to block0; no second block remains.
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check_not("block")
      ->run(*g);
  g->lint();
  // Test that recursive copy of blocks works.
  auto g2 = g->copy();
  testing::FileCheck()
      .check("add")
      ->check("prim::If")
      ->check("block0")
      ->check_not("block")
      ->run(*g2);
}
95 
// Verifies Node::findCommonAncestorBlockWith() on a graph with nested
// prim::If blocks, checking the nesting depth of each pairwise common
// ancestor via Block::param_node()->blocksFromGraphBlock().
TEST(IRTest, CommonAncestor) {
  std::string input_str = R"(
graph(%x : Tensor,
      %a.1 : bool,
      %b.1 : bool,
      %c.1 : bool):
  %4 : int = prim::If(%a.1)
    block0():
      %5 : int = prim::If(%b.1)
        block0():
          %6 : int = prim::Constant[value=2]()
          -> (%6)
        block1():
          %7 : int = prim::Constant[value=3]()
          -> (%7)
      -> (%5)
    block1():
      %8 : int = prim::If(%c.1)
        block0():
          %9 : int = prim::Constant[value=4]()
          -> (%9)
        block1():
          %10 : int = prim::Constant[value=5]()
          -> (%10)
      -> (%8)
  return (%4)
)";

  torch::jit::Graph g;
  std::unordered_map<std::string, torch::jit::Value*> name_to_value;
  torch::jit::parseIR(input_str, &g, name_to_value);

  // The four constants live two blocks deep, one in each leaf block.
  std::vector<std::string> value_names{"6", "7", "9", "10"};

  /* clang-format off */
  // Expected depth (blocks from the graph's top-level block) of the common
  // ancestor block for each (row, column) pair of values above.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  int ref_blocks_from_graph[4][4] = {
    /* (6, 6), (6, 7), (6, 9), (6, 10) */
    {   2,     1,      0,      0        },
    /* (7, 6), (7, 7), (7, 9), (7, 10) */
    {   1,     2,      0,      0        },
    /* (9, 6), (9, 7), (9, 9), (9, 10) */
    {   0,     0,      2,      1,       },
    /* (10, 6),(10, 7),(10, 9),(10, 10) */
    {   0,     0,      1,      2        }
  };
  /* clang-format on */

  for (size_t i = 0; i < value_names.size(); ++i) {
    // at() rather than operator[]: a missing name should fail loudly
    // instead of silently inserting a null Value*.
    Value* i_val = name_to_value.at(value_names[i]);
    for (size_t j = 0; j < value_names.size(); ++j) {
      Value* j_val = name_to_value.at(value_names[j]);
      Block* common_ancestor =
          i_val->node()->findCommonAncestorBlockWith(j_val->node());
      int blocks_from_graph_block =
          common_ancestor->param_node()->blocksFromGraphBlock();
      ASSERT_EQ(blocks_from_graph_block, ref_blocks_from_graph[i][j]);
    }
  }
}
158 
// Exercises OperatorMap<int>: single and batched insert, contains(),
// erase(), re-insert, and find() returning an engaged/empty optional
// depending on whether the operator is currently mapped. Overloads of
// the same op name (bernoulli, normal) must be tracked independently.
TEST(IRTest, OperatorMap) {
  OperatorMap<int> op_map;
  const char* literal1 =
      "aten::dropout(Tensor input, float p, bool train) -> Tensor";
  const char* literal2 =
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor";
  const char* literal3 =
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor";
  const char* literal4 =
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor";
  const char* literal5 =
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor";
  const char* literal6 =
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor";
  std::shared_ptr<Operator> op1 = getOperatorForLiteral(literal1);
  std::shared_ptr<Operator> op2 = getOperatorForLiteral(literal2);
  std::shared_ptr<Operator> op3 = getOperatorForLiteral(literal3);
  std::shared_ptr<Operator> op4 = getOperatorForLiteral(literal4);
  std::shared_ptr<Operator> op5 = getOperatorForLiteral(literal5);
  std::shared_ptr<Operator> op6 = getOperatorForLiteral(literal6);
  // Mix single-entry and initializer-list inserts.
  op_map.insert(op1, 1);
  op_map.insert({{op2, 2}, {op3, 3}});
  op_map.insert({{op4, 4}, {op5, 5}});
  op_map.insert(op6, 6);
  ASSERT_TRUE(op_map.contains(*op1));
  ASSERT_TRUE(op_map.contains(*op2));
  ASSERT_TRUE(op_map.contains(*op3));
  ASSERT_TRUE(op_map.contains(*op4));
  ASSERT_TRUE(op_map.contains(*op5));
  ASSERT_TRUE(op_map.contains(*op6));
  // Erasing one overload must not remove its siblings (op2/op4/op5 stay).
  op_map.erase(op6);
  op_map.erase(op3);
  op_map.erase(op1);
  ASSERT_FALSE(op_map.contains(*op1));
  ASSERT_FALSE(op_map.contains(*op3));
  ASSERT_FALSE(op_map.contains(*op6));
  // Re-insertion after erase works.
  op_map.insert(op1, 1);
  ASSERT_TRUE(op_map.contains(*op1));
  // find(): engaged optional for mapped ops, empty for erased ones.
  std::optional<int> o1 = op_map.find(*op1);
  ASSERT_TRUE(o1.has_value());
  std::optional<int> o2 = op_map.find(*op2);
  ASSERT_TRUE(o2.has_value());
  std::optional<int> o3 = op_map.find(*op3);
  ASSERT_FALSE(o3.has_value());
  std::optional<int> o4 = op_map.find(*op4);
  ASSERT_TRUE(o4.has_value());
  std::optional<int> o5 = op_map.find(*op5);
  ASSERT_TRUE(o5.has_value());
  std::optional<int> o6 = op_map.find(*op6);
  ASSERT_FALSE(o6.has_value());
}
210 
211 } // namespace jit
212 } // namespace torch
213