xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_subgraph_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include "test/cpp/jit/test_utils.h"
4 
5 #include <torch/csrc/jit/testing/file_check.h>
6 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
7 #include "torch/csrc/jit/passes/utils/subgraph_utils.h"
8 
9 namespace torch {
10 namespace jit {
11 
// Merges every node of the graph returned by build_lstm() into one
// prim::DifferentiableGraph subgraph node, unmerges it again, and checks that
// the graph ends up with the same number of nodes it started with. The merge
// is exercised walking the node list both backwards and forwards.
TEST(SubgraphUtilsTest, Basic) {
  auto graph = build_lstm();
  EliminateCommonSubexpression(graph);

  // Snapshot of the node list before any merging; only its size is compared
  // at the end of each round.
  std::vector<Node*> originalNodes(
      graph->nodes().begin(), graph->nodes().end());

  for (bool reverse_iterate : {true, false}) {
    // Iterator positioned one step past `node` in the traversal direction.
    auto next_after = [reverse_iterate](Node* node) {
      return reverse_iterate ? ++node->reverseIterator() : ++node->iterator();
    };

    auto it =
        reverse_iterate ? graph->nodes().rbegin() : graph->nodes().begin();
    const auto end =
        reverse_iterate ? graph->nodes().rend() : graph->nodes().end();

    // Without at least one node there is nothing to seed the subgraph with.
    ASSERT_TRUE(it != end) << "expected a non-empty graph";

    // Seed the subgraph with the first node, then merge everything into a
    // single subgraph. `it` is re-derived from the subgraph node after every
    // step because the node it pointed at has just been absorbed, and it is
    // checked against `end` before each merge so the sentinel is never merged.
    Node* subgraph = SubgraphUtils::createSingletonSubgraph(
        *it, prim::DifferentiableGraph);
    for (it = next_after(subgraph); it != end; it = next_after(subgraph)) {
      SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
    }

    // Unmerge and compare with original node listing
    SubgraphUtils::unmergeSubgraph(subgraph);
    EliminateCommonSubexpression(graph);

    std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
    ASSERT_EQ(originalNodes.size(), newNodes.size());
  }
}
50 
// Partitions a small parsed graph into two adjacent prim::DifferentiableGraph
// subgraph nodes, merges one into the other (once in each direction), and
// uses FileCheck to verify the op sequence both inside the merged subgraph and
// in the outer graph after unmerging.
TEST(SubgraphUtilsTest, MergeSubgraphs) {
  std::unordered_map<std::string, Value*> parse_map;
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor, %c : Tensor):
  %x : Tensor = aten::sigmoid(%a)
  %y : Tensor = aten::mul(%a, %b)
  %p : Tensor = aten::div(%c, %b)
  %q1 : Tensor = aten::mul(%p, %a)
  %q2 : Tensor = aten::tanh(%q1)
  %q3 : Tensor = aten::tanh(%q2)
  %q4 : Tensor = aten::tanh(%q3)
  %q5 : Tensor = aten::hardsigmoid(%q4)
  return (%x, %y, %q5))IR",
      graph.get(),
      parse_map);

  // Lints the given graph and checks that it lists the expected ops in order.
  const auto run_file_check = [](const std::shared_ptr<Graph>& g) {
    g->lint();
    testing::FileCheck()
        .check("aten::sigmoid")
        ->check("aten::mul")
        ->check("aten::div")
        ->check("aten::mul")
        ->check_count("aten::tanh", 3)
        ->check("aten::hardsigmoid")
        ->run(*g);
  };

  // Node listing before any merging; only its size is compared below.
  std::vector<Node*> originalNodes(
      graph->nodes().begin(), graph->nodes().end());

  for (bool reverse_merge : {true, false}) {
    // First subgraph: everything from the first node up to (but excluding)
    // the first aten::tanh.
    Node* graph1 = SubgraphUtils::createSingletonSubgraph(
        *graph->nodes().begin(), prim::DifferentiableGraph);
    for (Node* candidate = graph1->next(); candidate->kind() != aten::tanh;
         candidate = graph1->next()) {
      SubgraphUtils::mergeNodeIntoSubgraph(candidate, graph1);
    }

    // Second subgraph: every remaining node after the first subgraph.
    Node* graph2 = SubgraphUtils::createSingletonSubgraph(
        graph1->next(), prim::DifferentiableGraph);
    for (Node* candidate = graph2->next();
         candidate != *graph->nodes().end();
         candidate = graph2->next()) {
      SubgraphUtils::mergeNodeIntoSubgraph(candidate, graph2);
    }

    // Fold one subgraph node into the other; the direction decides which
    // node survives and holds the combined subgraph.
    Node* subgraph = reverse_merge ? graph1 : graph2;
    Node* absorbed = reverse_merge ? graph2 : graph1;
    SubgraphUtils::mergeNodeIntoSubgraph(absorbed, subgraph);
    run_file_check(subgraph->g(attr::Subgraph));

    // Unmerge and compare with original node listing
    SubgraphUtils::unmergeSubgraph(subgraph);
    EliminateCommonSubexpression(graph);
    run_file_check(graph);

    std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
    ASSERT_EQ(originalNodes.size(), newNodes.size());
  }
}
118 
// Checks SubgraphUtils::generateNameForGraph on a small parsed graph: with a
// length limit of 80 the result must equal the prefix followed by the op names
// joined with underscores, and with a limit of 10 the result must not be
// longer than that full name.
TEST(SubgraphUtilsTest, GraphName) {
  std::unordered_map<std::string, Value*> parse_map;
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%a : Tensor, %b : Tensor, %c : Tensor):
  %x : Tensor = aten::tanh(%a)
  %y : Tensor = aten::mul(%a, %b)
  %p : Tensor = aten::div(%c, %b)
  %q1 : Tensor = aten::mul(%p, %a)
  %q2 : Tensor = aten::tanh(%q1)
  %q3 : Tensor = aten::tanh(%q2)
  %q4 : Tensor = aten::tanh(%q3)
  %q5 : Tensor = aten::tanh(%q4)
  return (%x, %y, %q5))IR",
      graph.get(),
      parse_map);

  // Limit large enough to hold every op name: expect the untruncated name.
  const std::string expected_name =
      "graph_tanh_mul_div_mul_tanh_tanh_tanh_tanh";
  const std::string unabridged =
      SubgraphUtils::generateNameForGraph(graph, 80, "graph");
  ASSERT_EQ(unabridged, expected_name);

  // Small limit: the generated name must not exceed the full name's length.
  const std::string abridged =
      SubgraphUtils::generateNameForGraph(graph, 10, "graph");
  ASSERT_LE(abridged.size(), expected_name.size());
}
147 
148 } // namespace jit
149 } // namespace torch
150