1 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
2
3 #include <torch/csrc/jit/ir/irparser.h>
4 #include <torch/csrc/jit/ir/subgraph_matcher.h>
5
6 #include <c10/util/irange.h>
7
8 #include <utility>
9
10 namespace torch::jit {
11
12 namespace {
update_source_range_and_cs_ptr(const std::set<const Node * > & input_nodes,const Match & m,std::unordered_map<Node *,Node * > & pattern_node_map)13 void update_source_range_and_cs_ptr(
14 const std::set<const Node*>& input_nodes,
15 const Match& m,
16 std::unordered_map<Node*, Node*>& pattern_node_map) {
17 // pattern_node_map, maps nodes of the replacement graph
18 // to the nodes of the pattern graph.
19 // Now we iterate over each node of the replacement graph
20 // and find the corresponding pattern node in the match.
21 // The matched's node's source range and callstack is then
22 // used to update replacement node's source range and callstack
23 for (auto& it : pattern_node_map) {
24 Node* replacement_node = it.first;
25 Node* pattern_node = it.second;
26 if (!input_nodes.count(pattern_node)) {
27 Node* orig_node = m.nodes_map.at(pattern_node);
28 replacement_node->setSourceRange(orig_node->sourceRange());
29 if (orig_node->callstack()) {
30 replacement_node->setCallStack(orig_node->callstack().value());
31 }
32 }
33 }
34 }
35 } // namespace
36
RegisterDefaultPatterns()37 void SubgraphRewriter::RegisterDefaultPatterns() {
38 // TODO: Add actual patterns (like Conv-Relu).
39 RegisterRewritePattern(
40 R"IR(
41 graph(%x, %w, %b):
42 %c = aten::conv(%x, %w, %b)
43 %r = aten::relu(%c)
44 return (%r))IR",
45 R"IR(
46 graph(%x, %w, %b):
47 %r = aten::convrelu(%x, %w, %b)
48 return (%r))IR",
49 {{"r", "c"}});
50 }
51
RegisterRewritePattern(const std::string & pattern,const std::string & replacement,const std::vector<std::pair<std::string,std::string>> & value_name_pairs)52 void SubgraphRewriter::RegisterRewritePattern(
53 const std::string& pattern,
54 const std::string& replacement,
55 const std::vector<std::pair<std::string, std::string>>& value_name_pairs) {
56 std::unordered_map<std::string, std::string> value_name_map(
57 value_name_pairs.begin(), value_name_pairs.end());
58 RewritePatternDescr d = {pattern, replacement, std::move(value_name_map)};
59 patterns_.push_back(std::move(d));
60 }
61
// Applies every registered pattern to the graph of each method of `module`.
// The method graphs are rewritten in place, and `module` is returned as-is.
Module SubgraphRewriter::runOnModule(const Module& module) {
  nodes_to_delete_.clear();
  const auto methods = module.get_methods();
  for (const auto& method : methods) {
    std::shared_ptr<Graph> method_graph =
        toGraphFunction(method.function()).graph();
    runOnGraph(method_graph);
  }
  return module;
}
70
runOnGraph(std::shared_ptr<Graph> & graph,const std::vector<MatchFilter> & filters)71 void SubgraphRewriter::runOnGraph(
72 std::shared_ptr<Graph>& graph,
73 const std::vector<MatchFilter>& filters) {
74 for (const RewritePatternDescr& pattern : patterns_) {
75 rewriteSinglePatternOnGraph(graph, pattern, filters);
76 }
77 }
78
rewriteSinglePatternOnGraph(std::shared_ptr<Graph> & graph,const RewritePatternDescr & pattern,const std::vector<MatchFilter> & filters)79 void SubgraphRewriter::rewriteSinglePatternOnGraph(
80 std::shared_ptr<Graph>& graph,
81 const RewritePatternDescr& pattern,
82 const std::vector<MatchFilter>& filters) {
83 std::unordered_map<Value*, Value*> rewrite_map;
84 std::vector<Value*> values_to_rewrite;
85
86 Graph pattern_graph;
87 std::unordered_map<std::string, Value*> vmap;
88 parseIR(pattern.pattern, &pattern_graph, vmap);
89
90 Graph replacement_graph;
91 std::unordered_map<std::string, Value*> vmap_replacement;
92 parseIR(pattern.replacement, &replacement_graph, vmap_replacement);
93
94 // First construct map of Node*-to-Node*
95 // This maps Nodes in replacement graph to nodes in pattern graph
96 // given the value_name_map, which maps value names from replacement
97 // pattern to value name in pattern
98 std::unordered_map<Node*, Node*> pattern_node_map;
99 std::set<const Node*> pattern_input_nodes;
100 for (auto& it : vmap_replacement) {
101 const auto& replacement_value_name = it.first;
102 Node* replacement_value_node = it.second->node();
103 if (pattern.value_name_map.count(replacement_value_name)) {
104 const auto& pattern_value_name =
105 pattern.value_name_map.at(replacement_value_name);
106 TORCH_CHECK(
107 vmap.count(pattern_value_name),
108 "Value must be found in the replacement graph.");
109 Node* pattern_value_node = vmap.at(pattern_value_name)->node();
110 pattern_node_map.emplace(replacement_value_node, pattern_value_node);
111 }
112 }
113
114 const auto& matches = findPatternMatches(pattern_graph, *graph);
115 for (const Match& match : matches) {
116 if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
117 return f(match, vmap);
118 })) {
119 continue;
120 }
121 // Matches might overlap with each other, in that case some of the nodes in
122 // the current match might have already been used in another folded pattern.
123 // We need to skip such matches.
124 if (overlapsWithPreviousMatches(&match)) {
125 continue;
126 }
127
128 // Figure out what values we need to use as inputs and outputs for the
129 // replacement subgraph and where the replacement subgraph needs to be
130 // inserted.
131 Node* ins_point = nullptr;
132 std::vector<Value*> inputs, outputs;
133 for (Value* v : pattern_graph.inputs()) {
134 Value* input = match.values_map.at(v);
135 if (!ins_point || ins_point->isBefore(input->node())) {
136 ins_point = input->node();
137 }
138 inputs.push_back(input);
139 }
140 AT_ASSERT(ins_point);
141
142 // Check that the insertion point we've chosen precedes all the uses of the
143 // outputs - otherwise the replacement is incorrect and we have to skip it.
144 bool ins_point_before_uses = true;
145 for (Value* v : pattern_graph.outputs()) {
146 Value* output = match.values_map.at(v);
147 outputs.push_back(match.values_map.at(v));
148
149 for (const Use& u : output->uses()) {
150 if (u.user->isBefore(ins_point)) {
151 ins_point_before_uses = false;
152 break;
153 }
154 }
155 }
156
157 if (!ins_point_before_uses) {
158 continue;
159 }
160
161 // Before rewriting the graph, update source range and callstack
162 // info of the replacement pattern graph so that the rewritten graph
163 // has the updated info
164 update_source_range_and_cs_ptr(
165 pattern_input_nodes, match, pattern_node_map);
166 // Insert a clone of replacement subgraph.
167 // `inputs` vector holds values that we would use as incoming values to the
168 // new subgraph, and we will get `new_outputs` vector containing values
169 // produced by this new subgraph - we will then rewrite old outputs with the
170 // new ones.
171 WithInsertPoint insert_point(ins_point->next());
172 std::vector<Value*> new_outputs =
173 insertGraph(*graph, replacement_graph, inputs);
174
175 // Record all planned rewritings
176 AT_ASSERT(outputs.size() == new_outputs.size());
177 for (const auto idx : c10::irange(outputs.size())) {
178 values_to_rewrite.push_back(outputs[idx]);
179 rewrite_map[outputs[idx]] =
180 new_outputs[idx]->setType(outputs[idx]->type());
181 }
182 // Record all planned deletions
183 for (Node* pattern_n : pattern_graph.nodes()) {
184 if (match.nodes_map.count(pattern_n)) {
185 Node* n = match.nodes_map.at(pattern_n);
186 nodes_to_delete_.insert(n);
187 }
188 }
189 }
190
191 // Perform planned rewritings
192 for (auto v : values_to_rewrite) {
193 v->replaceAllUsesWith(rewrite_map.at(v));
194 }
195
196 // Perform planned deletions
197 for (auto n : nodes_to_delete_) {
198 n->removeAllInputs();
199 }
200 for (auto n : nodes_to_delete_) {
201 n->destroy();
202 }
203 nodes_to_delete_.clear();
204 }
205
overlapsWithPreviousMatches(const Match * match)206 bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
207 for (auto n : match->nodes_map) {
208 if (nodes_to_delete_.count(n.second)) {
209 return true;
210 }
211 }
212 return false;
213 }
214
// Convenience entry point: runs the default rewrite patterns over every
// method of `module` and returns the result of
// SubgraphRewriter::runOnModule.
Module PatternBasedRewrite(const Module& module) {
  // TODO: Deep-copy the module
  SubgraphRewriter rewriter;
  rewriter.RegisterDefaultPatterns();
  Module rewritten = rewriter.runOnModule(module);
  return rewritten;
}
221
222 } // namespace torch::jit
223