xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_subgraph_rewriter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 #include <torch/csrc/jit/ir/subgraph_matcher.h>
5 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
6 #include <torch/csrc/jit/testing/file_check.h>
7 
8 namespace torch {
9 namespace jit {
10 using namespace testing;
11 
// Checks that SubgraphRewriter applies a rewrite only when every supplied
// match filter accepts the candidate match.
TEST(SubgraphRewriterTest, FilterMatch) {
  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b : int = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";

  // Each filter resolves the value named "b" through the vmap argument it is
  // handed, then follows match.values_map to the value it was matched to.
  // The filters capture nothing, so no locally parsed copy of the pattern is
  // required here.

  // Accepts the match when the value matched as %b is produced by a
  // prim::Constant node.
  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };

  // Accepts the match when the value matched as %b converts to an IValue
  // holding the integer 1.
  auto b_is_one = [](const Match& match,
                     const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_val = toIValue(match_vmap.at(vmap.at("b")));
    return b_val && b_val->isInt() && b_val->toInt() == 1;
  };

  // Accepts the match when the value matched as %b converts to an IValue
  // holding the integer 2 (never true for the graph above).
  auto b_is_two = [](const Match& match,
                     const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_val = toIValue(match_vmap.at(vmap.at("b")));
    return b_val && b_val->isInt() && b_val->toInt() == 2;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);

  // Every case below rewrites a fresh copy so the cases stay independent.

  // b is constant, so the filter accepts the match and c::ccc is replaced.
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, b_is_constant);
    FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
  }

  // b is constant and its value is one: both filters accept, so c::ccc is
  // replaced.
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_one});
    FileCheck().check("d::ddd")->check_not("c::ccc")->run(*g);
  }

  // b is constant but its value is not two: b_is_two rejects the match, so
  // the graph must still contain c::ccc and no d::ddd.
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_two});
    FileCheck().check("c::ccc")->check_not("d::ddd")->run(*g);
  }
}
83 
// Checks that a match filter returning false prevents the rewrite, leaving
// the original graph untouched.
TEST(SubgraphRewriterTest, FilterNoMatch) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = prim::Constant[value=1]()
  %c = c::ccc(%a, %b)
  return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";

  // The filter looks up the value named "b" through the vmap argument it
  // receives and inspects the node producing its matched counterpart. It
  // captures nothing, so no locally parsed copy of the pattern is required.
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    // b_node is not prim::Assign, so this won't match and we'll skip the
    // rewrite
    return b_node->kind() == prim::Assign;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph, filter);

  // The rewrite was filtered out: c::ccc survives and d::ddd never appears.
  FileCheck().check("c::ccc")->check_not("d::ddd")->run(*graph);
}
124 
// Exercises rewriting of patterns that return more than one value.
TEST(SubgraphRewriterTest, MultiOutput) {
  // Parses `input_ir`, registers the pattern -> replacement rewrite, runs the
  // rewriter on a copy of the parsed graph and hands that copy back.
  const auto run_rewrite = [](const std::string& input_ir,
                              const std::string& pattern_ir,
                              const std::string& replacement_ir) {
    auto source = std::make_shared<Graph>();
    parseIR(input_ir, source.get());

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(pattern_ir, replacement_ir);

    auto rewritten = source->copy();
    rewriter.runOnGraph(rewritten);
    return rewritten;
  };

  // Case 1: basic multi-output pattern rewriting.
  {
    const std::string input_ir = R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  %c = c::ccc(%b)

  %x1, %x2 = a::aaa(%c, %a2)
  %y = b::bbb(%x1)
  %z = d::ddd(%y)
  return (%z))IR";

    const std::string pattern_ir = R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  return (%b, %a2))IR";

    const std::string replacement_ir = R"IR(
graph(%a, %b):
  %x, %y = ab::ababab(%a, %b)
  return (%x, %y))IR";

    auto result = run_rewrite(input_ir, pattern_ir, replacement_ir);
    // The fused op is expected to show up (at least) twice, in order.
    FileCheck().check("ab::ababab")->check("ab::ababab")->run(*result);
  }

  // Case 2: mimic a real model case.
  {
    const std::string input_ir = R"IR(
    graph(%k, %m, %x1, %x2, %x3, %x4, %y1, %y2, %y3, %y4):
      %a1 = aa::aaa(%x1, %k)
      %b1_1, %b1_2 = bb::bbb(%y1, %a1)
      %a2 = aa::aaa(%x2, %k)
      %b2_1, %b2_2 = bb::bbb(%y2, %a2)
      %a3 = aa::aaa(%x3, %k)
      %b3_1, %b3_2 = bb::bbb(%y3, %a3)
      %a4 = aa::aaa(%x4, %k)
      %b4_1, %b4_2 = bb::bbb(%y4, %a4)
      %c = cc::ccc(%b4_1)
      %d1 = dd::ddd(%b1_2, %m)
      %e1 = ee::eee(%b1_1, %d1)
      %d2 = dd::ddd(%b2_2, %m)
      %e2 = ee::eee(%b2_1, %d2)
      %d3 = dd::ddd(%b3_2, %m)
      %e3 = ee::eee(%b3_1, %d3)
      %d4 = dd::ddd(%b4_2, %m)
      %e4 = ee::eee(%b4_1, %d4)
      return (%d1, %d2, %d3, %d4, %e1, %e2, %e3, %e4)
      )IR";

    const std::string pattern_ir = R"IR(
    graph(%a, %b, %c, %d):
        %y0 = aa::aaa(%b, %c)
        %y1, %y2 = bb::bbb(%a, %y0)
        %y3 = dd::ddd(%y2, %d)
        return (%y3, %y1))IR";

    const std::string replacement_ir = R"IR(
    graph(%a, %b, %c, %d):
      %x, %y = ab::ababab(%a, %b, %c, %d)
      return (%x, %y))IR";

    auto result = run_rewrite(input_ir, pattern_ir, replacement_ir);
    FileCheck().check("ab::ababab")->check("ab::ababab")->run(*result);
  }

  // Case 3: no rewriting should occur due to data dependencies.
  {
    const std::string input_ir = R"IR(
    graph(%x, %y):
      %a = aa::aaa(%x)
      %b = bb::bbb(%a)
      %e = ee::eee(%b)
      %c = cc::ccc(%y)
      %d = dd::ddd(%b, %c)
      %f = ff::fff(%b, %d)
      return (%f)
      )IR";

    const std::string pattern_ir = R"IR(
    graph(%a, %c):
        %b = bb::bbb(%a)
        %d = dd::ddd(%b, %c)
        return (%d, %b))IR";

    const std::string replacement_ir = R"IR(
    graph(%a, %c):
      %d, %b = db::fused(%a, %c)
      return (%d, %b))IR";

    auto result = run_rewrite(input_ir, pattern_ir, replacement_ir);
    // The replacement must not happen on this graph because of data
    // dependency constraints: the output %b is used in %e, which precedes one
    // def of the input %c.
    FileCheck().check_not("db::fused")->run(*result);
  }
}
247 
// Checks how the output of the replacement node is printed after a rewrite,
// depending on whether the matched output carried a shape annotation.
TEST(SubgraphRewriterTest, OutputType) {
  std::string pattern = R"IR(
graph(%a, %b):
  %c = c::ccc(%a, %b)
  return (%c))IR";

  // Accepts the match when the value matched as %b is produced by a
  // prim::Constant node. The "b" lookup goes through the vmap argument the
  // filter receives; the lambda captures nothing, so no locally parsed copy
  // of the pattern is required.
  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };

  std::string replacement = R"IR(
graph(%a, %b):
  %d = d::ddd(%a, %b)
  return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  {
    auto graph = std::make_shared<Graph>();

    parseIR(
        R"IR(
  graph(%0):
    %a : Float(10, 20) = a::aaa(%0)
    %b : int = prim::Constant[value=1]()
    %c : Float(10, 20) = c::ccc(%a, %b)
    return (%c))IR",
        graph.get());

    // The matched output has shape info, and the d::ddd output is expected to
    // be printed with the same Float(10, 20) annotation.
    rewriter.runOnGraph(graph, b_is_constant);
    FileCheck()
        .check("Float(10, 20) = d::ddd")
        ->check_not("c::ccc")
        ->run(*graph);
  }
  {
    auto graph = std::make_shared<Graph>();

    parseIR(
        R"IR(
  graph(%0):
    %a = a::aaa(%0)
    %b : int = prim::Constant[value=1]()
    %c = c::ccc(%a, %b)
    return (%c))IR",
        graph.get());

    // The matched output has no shape info, and the d::ddd output is expected
    // to be printed as a plain Tensor.
    rewriter.runOnGraph(graph, b_is_constant);
    FileCheck().check("Tensor = d::ddd")->check_not("c::ccc")->run(*graph);
  }
}
308 
309 } // namespace jit
310 } // namespace torch
311