xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_subgraph_matcher.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include "test/cpp/jit/test_utils.h"
4 #include "torch/csrc/jit/ir/subgraph_matcher.h"
5 
6 namespace torch {
7 namespace jit {
8 
// A single-node pattern must be found in a graph that consists of exactly
// the same single node (only the value names differ).
TEST(SubgraphMatcherTest, Trivial1) {
  Graph graph;
  Graph pattern;
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  return (%a))IR",
      &graph);
  parseIR(
      R"IR(
graph(%0):
  %x = a::aaa(%0)
  return (%x))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
25 
// Same scenario as a single-node match, but both graphs are built through the
// Graph API so that the exact value/node correspondence recorded in the match
// can be checked.
TEST(SubgraphMatcherTest, Trivial2) {
  Graph graph;
  auto* g_in = graph.addInput();
  auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
  g_tanh->addInput(g_in);
  graph.registerOutput(g_tanh->output());

  Graph pattern;
  auto* p_in = pattern.addInput();
  auto* p_tanh =
      pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
  p_tanh->addInput(p_in);
  pattern.registerOutput(p_tanh->output());

  const auto matches = findPatternMatches(pattern, graph);
  // Fatal assertion: guarantees front() below is valid.
  ASSERT_EQ(matches.size(), 1u);

  const Match& match = matches.front();
  ASSERT_EQ(match.values_map.at(p_in), g_in);
  ASSERT_EQ(match.values_map.at(p_tanh->output()), g_tanh->output());
  ASSERT_EQ(match.nodes_map.at(p_tanh), g_tanh);
}
48 
// A two-input, single-node pattern must match a node whose inputs are
// produced by other nodes of the graph (here a::a and a::b feeding a::c).
TEST(SubgraphMatcherTest, Trivial3) {
  Graph graph;
  Graph pattern;
  parseIR(
      R"IR(
graph(%0):
  %a = a::a(%0)
  %b = a::b(%0)
  %c = a::c(%a, %b)
  return (%c))IR",
      &graph);
  parseIR(
      R"IR(
graph(%a, %b):
  %c = a::c(%a, %b)
  return (%c))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
67 
// A two-input aten::mul pattern matched against an identical graph; checks
// that each pattern input, the output value and the node itself are mapped to
// their counterparts in the graph.
TEST(SubgraphMatcherTest, Trivial4) {
  Graph graph;
  auto* g_in0 = graph.addInput();
  auto* g_in1 = graph.addInput();
  auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1));
  g_mul->addInput(g_in0);
  g_mul->addInput(g_in1);
  graph.registerOutput(g_mul->output());

  Graph pattern;
  auto* p_in0 = pattern.addInput();
  auto* p_in1 = pattern.addInput();
  auto* p_mul =
      pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1));
  p_mul->addInput(p_in0);
  p_mul->addInput(p_in1);
  pattern.registerOutput(p_mul->output());

  const auto matches = findPatternMatches(pattern, graph);
  // Fatal assertion: guarantees front() below is valid.
  ASSERT_EQ(matches.size(), 1u);

  const Match& match = matches.front();
  ASSERT_EQ(match.values_map.at(p_in0), g_in0);
  ASSERT_EQ(match.values_map.at(p_in1), g_in1);
  ASSERT_EQ(match.values_map.at(p_mul->output()), g_mul->output());
  ASSERT_EQ(match.nodes_map.at(p_mul), g_mul);
}
95 
// A linear two-node pattern (bbb -> ccc) must be found in the middle of a
// longer linear chain of nodes.
TEST(SubgraphMatcherTest, Linear1) {
  Graph graph;
  Graph pattern;
  parseIR(
      R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  %c = c::ccc(%b)
  %d = d::ddd(%c)
  %a = a::aaa(%0)
  return (%d))IR",
      &graph);
  parseIR(
      R"IR(
graph(%0):
  %x = b::bbb(%0)
  %y = c::ccc(%x)
  return (%y))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
117 
// A chain of two aten::tanh nodes matched against an identical chain; checks
// the mapping of the input, both intermediate/output values and both nodes.
TEST(SubgraphMatcherTest, Linear2) {
  Graph graph;
  auto* g_in = graph.addInput();

  auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
  g_tanh->addInput(g_in);

  auto* g_tanh2 =
      graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/1));
  g_tanh2->addInput(g_tanh->output());

  graph.registerOutput(g_tanh2->output());

  Graph pattern;
  auto* p_in = pattern.addInput();

  auto* p_tanh =
      pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
  p_tanh->addInput(p_in);

  auto* p_tanh2 =
      pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/1));
  p_tanh2->addInput(p_tanh->output());

  pattern.registerOutput(p_tanh2->output());

  const auto matches = findPatternMatches(pattern, graph);
  // Fatal assertion: guarantees front() below is valid.
  ASSERT_EQ(matches.size(), 1u);

  const Match& match = matches.front();
  ASSERT_EQ(match.values_map.at(p_in), g_in);
  ASSERT_EQ(match.values_map.at(p_tanh->output()), g_tanh->output());
  ASSERT_EQ(match.values_map.at(p_tanh2->output()), g_tanh2->output());
  ASSERT_EQ(match.nodes_map.at(p_tanh), g_tanh);
  ASSERT_EQ(match.nodes_map.at(p_tanh2), g_tanh2);
}
154 
155 /**
156  * Test diamond pattern:
157  *
158  *     ooo
159  *      |
160  *     aaa
161  *    /   \
162  *  bbb   ccc
163  *     \ /
164  *     ddd
165  *      |
166  *     eee
167  */
TEST(SubgraphMatcherTest,Diamond1)168 TEST(SubgraphMatcherTest, Diamond1) {
169   Graph graph, pattern1, pattern2;
170   parseIR(
171       R"IR(
172 graph(%0):
173   %o = o::ooo(%0)
174   %a = a::aaa(%o)
175   %b = b::bbb(%a)
176   %c = c::ccc(%a)
177   %d = d::ddd(%b, %c)
178   %e = e::eee(%d)
179   return (%e))IR",
180       &graph);
181 
182   parseIR(
183       R"IR(
184 graph(%0):
185   %a = a::aaa(%0)
186   %b = b::bbb(%a)
187   %c = c::ccc(%a)
188   %d = d::ddd(%b, %c)
189   return (%d))IR",
190       &pattern1);
191   AT_ASSERT(!findPatternMatches(pattern1, graph).empty());
192 
193   // Check that order of nodes inside the diamond does not affect the result
194   parseIR(
195       R"IR(
196 graph(%0):
197   %a = a::aaa(%0)
198   %c = c::ccc(%a)
199   %b = b::bbb(%a)
200   %d = d::ddd(%b, %c)
201   return (%d))IR",
202       &pattern2);
203   AT_ASSERT(!findPatternMatches(pattern2, graph).empty());
204 }
205 
206 /**
207  * Test diamond pattern:
208  *
209  *     i0
210  *      |
211  *    chunk
212  *    /   \
213  * os[0] os[1]
214  *     \ /
215  *      *
216  *      |
217  *      o1
218  */
TEST(SubgraphMatcherTest,Diamond2)219 TEST(SubgraphMatcherTest, Diamond2) {
220   Graph graph;
221   auto* g_in = graph.addInput();
222 
223   auto* g_chunk =
224       graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/2));
225   g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
226   g_chunk->addInput(g_in);
227 
228   auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/1));
229   g_mul->addInput(g_chunk->outputs()[0]);
230   g_mul->addInput(g_chunk->outputs()[1]);
231   graph.registerOutput(g_mul->output());
232 
233   Graph pattern;
234   auto* p_in = pattern.addInput();
235   auto* p_chunk = pattern.insertNode(
236       pattern.create(prim::ConstantChunk, /*num_outputs =*/2));
237   p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0);
238   p_chunk->addInput(p_in);
239 
240   auto* p_mul =
241       pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/1));
242   p_mul->addInput(p_chunk->outputs()[0]);
243   p_mul->addInput(p_chunk->outputs()[1]);
244   pattern.registerOutput(p_mul->output());
245 
246   auto matches = findPatternMatches(pattern, graph);
247   AT_ASSERT(matches.size() == 1);
248   for (const Match& m : matches) {
249     AT_ASSERT(m.values_map.at(p_in) == g_in);
250     AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]);
251     AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]);
252     AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output());
253     AT_ASSERT(m.nodes_map.at(p_mul) == g_mul);
254   }
255 }
256 
// "X"-shaped pattern: two independent inputs (bbb, ccc) merge in xxx, whose
// result fans out to eee and fff, which are merged again by ggg. The pattern
// is identical to the graph and must be found.
TEST(SubgraphMatcherTest, XPattern) {
  Graph graph;
  Graph pattern;
  parseIR(
      R"IR(
graph(%0, %1):
  %b = b::bbb(%0)
  %c = c::ccc(%1)
  %x = x::xxx(%b, %c)
  %e = e::eee(%x)
  %f = f::fff(%x)
  %g = g::ggg(%e, %f)
  return (%g))IR",
      &graph);
  parseIR(
      R"IR(
graph(%0, %1):
  %b = b::bbb(%0)
  %c = c::ccc(%1)
  %x = x::xxx(%b, %c)
  %e = e::eee(%x)
  %f = f::fff(%x)
  %g = g::ggg(%e, %f)
  return (%g))IR",
      &pattern);
  ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
}
283 
// A single-node pattern applied to a chain of four identical nodes must
// produce one match per node, i.e. four matches.
TEST(SubgraphMatcherTest, MultipleMatches) {
  Graph graph;
  Graph pattern;
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  %t2 = a::aaa(%t1)
  %t3 = a::aaa(%t2)
  %t4 = a::aaa(%t3)
  return (%t4))IR",
      &graph);
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  return (%t1))IR",
      &pattern);
  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 4u);
}
304 
// A two-node pattern applied to a chain of four identical nodes must produce
// three matches; consecutive matches share a node, so overlapping matches are
// all expected to be reported.
TEST(SubgraphMatcherTest, OverlappingMatches) {
  Graph graph;
  Graph pattern;
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  %t2 = a::aaa(%t1)
  %t3 = a::aaa(%t2)
  %t4 = a::aaa(%t3)
  return (%t4))IR",
      &graph);
  parseIR(
      R"IR(
graph(%t0):
  %t1 = a::aaa(%t0)
  %t2 = a::aaa(%t1)
  return (%t2))IR",
      &pattern);
  const auto matches = findPatternMatches(pattern, graph);
  ASSERT_EQ(matches.size(), 3u);
}
326 
// Matching in a graph with a prim::If node: nodes inside the sub-blocks must
// be matched as well, but a single match must not span several blocks.
TEST(SubgraphMatcherTest, MatchInBasicBlocks1) {
  Graph graph;
  parseIR(
      R"IR(
graph(%a, %b, %c):
  %d = aten::mul(%a, %b)
  %x = prim::If(%c)
    block0():
      %x1 = aten::mul(%a, %d)
      -> (%x1)
    block1():
      %x2 = aten::mul(%b, %d)
      -> (%x2)
  return (%x))IR",
      &graph);

  // Check that all three mul ops are matched: the top-level one and the one
  // in each branch of the If.
  Graph pattern0;
  parseIR(
      R"IR(
graph(%x, %y):
  %z = aten::mul(%x, %y)
  return (%z))IR",
      &pattern0);
  ASSERT_EQ(findPatternMatches(pattern0, graph).size(), 3u);

  // Ensure the matches don't cross basic block boundaries: the only mul -> mul
  // chains in the graph go from the top-level block into an If branch.
  Graph pattern1;
  parseIR(
      R"IR(
graph(%x, %y):
  %z1 = aten::mul(%x, %y)
  %z2 = aten::mul(%y, %z1)
  return (%z2))IR",
      &pattern1);
  ASSERT_TRUE(findPatternMatches(pattern1, graph).empty());
}
363 
// Matching in a graph with a custom node that owns a sub-block: the node in
// the sub-block must be matched, but a match must not span the outer block and
// the sub-block.
TEST(SubgraphMatcherTest, MatchInBasicBlocks2) {
  Graph graph;
  parseIR(
      R"IR(
graph(%a, %b):
  %x = my::mul(%a, %b)
  %y = my::node_with_subblock()
    block0():
      %z = my::mul(%b, %x)
      -> (%z)
  return (%y))IR",
      &graph);

  // Check that we can match both mul ops
  Graph pattern0;
  parseIR(
      R"IR(
graph(%x, %y):
  %z = my::mul(%x, %y)
  return (%z))IR",
      &pattern0);
  ASSERT_EQ(findPatternMatches(pattern0, graph).size(), 2u);

  // Ensure the matches don't cross basic block boundaries
  Graph pattern1;
  parseIR(
      R"IR(
graph(%x, %y):
  %u = my::mul(%x, %y)
  %v = my::mul(%y, %u)
  return (%v))IR",
      &pattern1);
  ASSERT_TRUE(findPatternMatches(pattern1, graph).empty());
}
398 
// Checks how node attributes written in the pattern influence matching. Every
// scoped block below parses a fresh pattern and runs it against the same
// graph, whose nodes carry list, int, float, complex and string attributes.
TEST(SubgraphMatcherTest, MatchesAttributes) {
  Graph graph;
  parseIR(
      R"IR(
graph(%0):
  %a = a::a[isattr=[1,2]](%0)
  %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0)
  %c = a::c[myattr="qqq"](%a, %b)
  return (%c))IR",
      &graph);

  {
    // Identical string attribute: match expected.
    Graph pattern;
    parseIR(
        R"IR(
graph(%a, %b):
  %c = a::c[myattr="qqq"](%a, %b)
  return (%c))IR",
        &pattern);
    ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Same attribute name, different string value: no match expected.
    Graph pattern;
    parseIR(
        R"IR(
graph(%a, %b):
  %c = a::c[myattr="zzz"](%a, %b)
  return (%c))IR",
        &pattern);
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Attribute that the graph node does not have: no match expected.
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %b = a::b[extraattr=10](%0)
  return (%b))IR",
        &pattern);
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    // Identical int, float and complex attributes: match expected.
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j](%0)
  return (%b))IR",
        &pattern);
    ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
  }
  {
    // All attributes of the graph node plus one it does not have: no match
    // expected.
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %b = a::b[intattr=10, floatattr=3.14, complexattr=-3.14j, strattr="rrr"](%0)
  return (%b))IR",
        &pattern);
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %a = a::a[isattr=[1,2]](%0)
  return (%a))IR",
        &pattern);
    // Lists are not supported yet, thus we shouldn't match for now.
    ASSERT_TRUE(findPatternMatches(pattern, graph).empty());
  }
  {
    // "q.*" is expected to match the graph's "qqq".
    // NOTE(review): presumably string attributes in a pattern are treated as
    // regular expressions by the matcher - confirm in subgraph_matcher.
    Graph pattern;
    parseIR(
        R"IR(
graph(%a, %b):
  %c = a::c[myattr="q.*"](%a, %b)
  return (%c))IR",
        &pattern);
    ASSERT_FALSE(findPatternMatches(pattern, graph).empty());
  }
}
482 
// Patterns that the matcher is expected to reject by throwing.
TEST(SubgraphMatcherTest, BadPattern) {
  Graph graph;
  parseIR(
      R"IR(
graph(%x):
  %y = my::op1(%x)
  %z = my::op2(%x)
  return (%y, %z))IR",
      &graph);

  // Case 1: the pattern contains a node with a subblock, which is not
  // supported.
  Graph subblock_pattern;
  parseIR(
      R"IR(
graph(%x):
  %y = my::node_with_subblock()
    block0():
      %z = my::op(%x)
      -> (%z)
  return (%y))IR",
      &subblock_pattern);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(findPatternMatches(subblock_pattern, graph));

  // Case 2: an unsupported multi-output pattern. A traversal up from the
  // first output does not cover the whole pattern (`%z = ...` is never
  // visited). See the note "Multi-output Patterns" in subgraph_matcher.h.
  Graph uncovered_pattern;
  parseIR(
      R"IR(
graph(%x):
  %y = my::op1(%x)
  %z = my::op2(%x)
  return (%y, %z))IR",
      &uncovered_pattern);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(findPatternMatches(uncovered_pattern, graph));
}
519 
// Supported multi-output patterns: every pattern node is reachable by walking
// up from the first returned value, and the second returned value is produced
// along that path. Each case expects the aaa -> bbb fragment to be found twice
// in its graph.
TEST(SubgraphMatcherTest, MultiOutput) {
  {
    // Second pattern output is the single output of the first pattern node.
    Graph graph;
    Graph pattern;
    parseIR(
        R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  %c = c::ccc(%a, %b)
  %x = a::aaa(%c)
  %y = b::bbb(%x)
  %z = d::ddd(%x, %y)
  return (%y))IR",
        &graph);
    parseIR(
        R"IR(
graph(%0):
  %a = a::aaa(%0)
  %b = b::bbb(%a)
  return (%b, %a))IR",
        &pattern);
    ASSERT_EQ(findPatternMatches(pattern, graph).size(), 2u);
  }
  {
    // Second pattern output is the second output of a two-output node whose
    // first output feeds the rest of the pattern.
    Graph graph;
    Graph pattern;
    parseIR(
        R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  %c = c::ccc(%b)

  %x1, %x2 = a::aaa(%c, %a2)
  %y = b::bbb(%x1)
  %z = d::ddd(%y)
  return (%z))IR",
        &graph);
    parseIR(
        R"IR(
graph(%0, %1):
  %a1, %a2 = a::aaa(%0, %1)
  %b = b::bbb(%a1)
  return (%b, %a2))IR",
        &pattern);
    ASSERT_EQ(findPatternMatches(pattern, graph).size(), 2u);
  }
}
567 
568 } // namespace jit
569 } // namespace torch
570