xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/subgraph_rewrite.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
2 
3 #include <torch/csrc/jit/ir/irparser.h>
4 #include <torch/csrc/jit/ir/subgraph_matcher.h>
5 
6 #include <c10/util/irange.h>
7 
8 #include <utility>
9 
10 namespace torch::jit {
11 
12 namespace {
update_source_range_and_cs_ptr(const std::set<const Node * > & input_nodes,const Match & m,std::unordered_map<Node *,Node * > & pattern_node_map)13 void update_source_range_and_cs_ptr(
14     const std::set<const Node*>& input_nodes,
15     const Match& m,
16     std::unordered_map<Node*, Node*>& pattern_node_map) {
17   // pattern_node_map, maps nodes of the replacement graph
18   // to the nodes of the pattern graph.
19   // Now we iterate over each node of the replacement graph
20   // and find the corresponding pattern node in the match.
21   // The matched's node's source range and callstack is then
22   // used to update replacement node's source range and callstack
23   for (auto& it : pattern_node_map) {
24     Node* replacement_node = it.first;
25     Node* pattern_node = it.second;
26     if (!input_nodes.count(pattern_node)) {
27       Node* orig_node = m.nodes_map.at(pattern_node);
28       replacement_node->setSourceRange(orig_node->sourceRange());
29       if (orig_node->callstack()) {
30         replacement_node->setCallStack(orig_node->callstack().value());
31       }
32     }
33   }
34 }
35 } // namespace
36 
RegisterDefaultPatterns()37 void SubgraphRewriter::RegisterDefaultPatterns() {
38   // TODO: Add actual patterns (like Conv-Relu).
39   RegisterRewritePattern(
40       R"IR(
41 graph(%x, %w, %b):
42   %c = aten::conv(%x, %w, %b)
43   %r = aten::relu(%c)
44   return (%r))IR",
45       R"IR(
46 graph(%x, %w, %b):
47   %r = aten::convrelu(%x, %w, %b)
48   return (%r))IR",
49       {{"r", "c"}});
50 }
51 
RegisterRewritePattern(const std::string & pattern,const std::string & replacement,const std::vector<std::pair<std::string,std::string>> & value_name_pairs)52 void SubgraphRewriter::RegisterRewritePattern(
53     const std::string& pattern,
54     const std::string& replacement,
55     const std::vector<std::pair<std::string, std::string>>& value_name_pairs) {
56   std::unordered_map<std::string, std::string> value_name_map(
57       value_name_pairs.begin(), value_name_pairs.end());
58   RewritePatternDescr d = {pattern, replacement, std::move(value_name_map)};
59   patterns_.push_back(std::move(d));
60 }
61 
// Runs every registered pattern over the graph of each method of `module`.
// The method graphs are rewritten in place, and `module` is returned.
Module SubgraphRewriter::runOnModule(const Module& module) {
  // Discard any deletion bookkeeping left over from an earlier run.
  nodes_to_delete_.clear();
  for (const auto& method : module.get_methods()) {
    std::shared_ptr<Graph> method_graph =
        toGraphFunction(method.function()).graph();
    runOnGraph(method_graph);
  }
  return module;
}
70 
runOnGraph(std::shared_ptr<Graph> & graph,const std::vector<MatchFilter> & filters)71 void SubgraphRewriter::runOnGraph(
72     std::shared_ptr<Graph>& graph,
73     const std::vector<MatchFilter>& filters) {
74   for (const RewritePatternDescr& pattern : patterns_) {
75     rewriteSinglePatternOnGraph(graph, pattern, filters);
76   }
77 }
78 
rewriteSinglePatternOnGraph(std::shared_ptr<Graph> & graph,const RewritePatternDescr & pattern,const std::vector<MatchFilter> & filters)79 void SubgraphRewriter::rewriteSinglePatternOnGraph(
80     std::shared_ptr<Graph>& graph,
81     const RewritePatternDescr& pattern,
82     const std::vector<MatchFilter>& filters) {
83   std::unordered_map<Value*, Value*> rewrite_map;
84   std::vector<Value*> values_to_rewrite;
85 
86   Graph pattern_graph;
87   std::unordered_map<std::string, Value*> vmap;
88   parseIR(pattern.pattern, &pattern_graph, vmap);
89 
90   Graph replacement_graph;
91   std::unordered_map<std::string, Value*> vmap_replacement;
92   parseIR(pattern.replacement, &replacement_graph, vmap_replacement);
93 
94   // First construct map of Node*-to-Node*
95   // This maps Nodes in replacement graph to nodes in pattern graph
96   // given the value_name_map, which maps value names from replacement
97   // pattern to value name in pattern
98   std::unordered_map<Node*, Node*> pattern_node_map;
99   std::set<const Node*> pattern_input_nodes;
100   for (auto& it : vmap_replacement) {
101     const auto& replacement_value_name = it.first;
102     Node* replacement_value_node = it.second->node();
103     if (pattern.value_name_map.count(replacement_value_name)) {
104       const auto& pattern_value_name =
105           pattern.value_name_map.at(replacement_value_name);
106       TORCH_CHECK(
107           vmap.count(pattern_value_name),
108           "Value must be found in the replacement graph.");
109       Node* pattern_value_node = vmap.at(pattern_value_name)->node();
110       pattern_node_map.emplace(replacement_value_node, pattern_value_node);
111     }
112   }
113 
114   const auto& matches = findPatternMatches(pattern_graph, *graph);
115   for (const Match& match : matches) {
116     if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
117           return f(match, vmap);
118         })) {
119       continue;
120     }
121     // Matches might overlap with each other, in that case some of the nodes in
122     // the current match might have already been used in another folded pattern.
123     // We need to skip such matches.
124     if (overlapsWithPreviousMatches(&match)) {
125       continue;
126     }
127 
128     // Figure out what values we need to use as inputs and outputs for the
129     // replacement subgraph and where the replacement subgraph needs to be
130     // inserted.
131     Node* ins_point = nullptr;
132     std::vector<Value*> inputs, outputs;
133     for (Value* v : pattern_graph.inputs()) {
134       Value* input = match.values_map.at(v);
135       if (!ins_point || ins_point->isBefore(input->node())) {
136         ins_point = input->node();
137       }
138       inputs.push_back(input);
139     }
140     AT_ASSERT(ins_point);
141 
142     // Check that the insertion point we've chosen precedes all the uses of the
143     // outputs - otherwise the replacement is incorrect and we have to skip it.
144     bool ins_point_before_uses = true;
145     for (Value* v : pattern_graph.outputs()) {
146       Value* output = match.values_map.at(v);
147       outputs.push_back(match.values_map.at(v));
148 
149       for (const Use& u : output->uses()) {
150         if (u.user->isBefore(ins_point)) {
151           ins_point_before_uses = false;
152           break;
153         }
154       }
155     }
156 
157     if (!ins_point_before_uses) {
158       continue;
159     }
160 
161     // Before rewriting the graph, update source range and callstack
162     // info of the replacement pattern graph so that the rewritten graph
163     // has the updated info
164     update_source_range_and_cs_ptr(
165         pattern_input_nodes, match, pattern_node_map);
166     // Insert a clone of replacement subgraph.
167     // `inputs` vector holds values that we would use as incoming values to the
168     // new subgraph, and we will get `new_outputs` vector containing values
169     // produced by this new subgraph - we will then rewrite old outputs with the
170     // new ones.
171     WithInsertPoint insert_point(ins_point->next());
172     std::vector<Value*> new_outputs =
173         insertGraph(*graph, replacement_graph, inputs);
174 
175     // Record all planned rewritings
176     AT_ASSERT(outputs.size() == new_outputs.size());
177     for (const auto idx : c10::irange(outputs.size())) {
178       values_to_rewrite.push_back(outputs[idx]);
179       rewrite_map[outputs[idx]] =
180           new_outputs[idx]->setType(outputs[idx]->type());
181     }
182     // Record all planned deletions
183     for (Node* pattern_n : pattern_graph.nodes()) {
184       if (match.nodes_map.count(pattern_n)) {
185         Node* n = match.nodes_map.at(pattern_n);
186         nodes_to_delete_.insert(n);
187       }
188     }
189   }
190 
191   // Perform planned rewritings
192   for (auto v : values_to_rewrite) {
193     v->replaceAllUsesWith(rewrite_map.at(v));
194   }
195 
196   // Perform planned deletions
197   for (auto n : nodes_to_delete_) {
198     n->removeAllInputs();
199   }
200   for (auto n : nodes_to_delete_) {
201     n->destroy();
202   }
203   nodes_to_delete_.clear();
204 }
205 
overlapsWithPreviousMatches(const Match * match)206 bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
207   for (auto n : match->nodes_map) {
208     if (nodes_to_delete_.count(n.second)) {
209       return true;
210     }
211   }
212   return false;
213 }
214 
// Convenience entry point: applies the default rewrite patterns to every
// method graph of `module` and returns the result of
// SubgraphRewriter::runOnModule.
Module PatternBasedRewrite(const Module& module) {
  // TODO: Deep-copy the module
  SubgraphRewriter rewriter;
  rewriter.RegisterDefaultPatterns();
  Module rewritten = rewriter.runOnModule(module);
  return rewritten;
}
221 
222 } // namespace torch::jit
223