1 #include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>
2
3 #include <torch/csrc/jit/frontend/schema_matching.h>
4 #include <torch/csrc/jit/runtime/interpreter/can_emit_inline.h>
5
6 namespace torch::jit::interpreter {
7
8 namespace {
9
10 // Insert explicit prim::MethodCall nodes after prim::Enter nodes
11 // to actually call __enter__ on the object. All prim::Enter does
12 // is push the object onto the stack of currently entered objects.
13 // This is necessary because emitting two instructions for a
14 // prim::Enter nodes (one ENTER to push onto the entered objects
15 // stack and one CALL to call __enter__) does not work; the
16 // accounting that determines when to move a value out of a register
17 // is based on the number of uses it has in the IR.
insertEnterMethodCalls(Graph & g)18 void insertEnterMethodCalls(Graph& g) {
19 std::vector<Block*> block_queue;
20 std::vector<Node*> enter_nodes;
21 block_queue.emplace_back(g.block());
22
23 // Traverse the graph while drilling down into blocks belonging to
24 // a node and add all encountered prim::Enter nodes to enter_nodes.
25 while (!block_queue.empty()) {
26 Block* block = block_queue.back();
27 block_queue.pop_back();
28
29 for (auto node : block->nodes()) {
30 if (node->kind() == prim::Enter) {
31 enter_nodes.emplace_back(node);
32 continue;
33 }
34
35 for (auto& node_block : node->blocks()) {
36 block_queue.emplace_back(node_block);
37 }
38 }
39 }
40
41 // For each prim::Enter, emit a prim::MethodCall after it that actually
42 // calls __enter__ on the object.
43 for (auto& enter : enter_nodes) {
44 auto cls = enter->input(0)->type()->expect<ClassType>();
45
46 MatchedSchema enter_matched_schema = matchSchema(
47 cls->findMethod("__enter__")->getSchema(),
48 enter->input(0)->node()->sourceRange(),
49 g,
50 {enter->input(0)},
51 {});
52
53 Node* call = g.insertMethodCall("__enter__", enter_matched_schema)->node();
54 call->moveAfter(enter);
55 enter->replaceAllUsesWith(call);
56 }
57 }
58
59 // insert Drop nodes to kill references for anything unused:
60 // this can happen in a few places, e.g. when a node returns
61 // many values but only one is used
62 // a, b = foo()
63 // return a
dropUnused(Block * b)64 void dropUnused(Block* b) {
65 auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
66 std::vector<Value*> to_drop;
67 for (auto v : values) {
68 if (v->uses().empty() && v->node()->kind() != prim::Constant) {
69 to_drop.push_back(v);
70 }
71 }
72 if (to_drop.empty()) {
73 return nullptr;
74 }
75 return b->owningGraph()->create(prim::Drop, to_drop, 0);
76 };
77
78 if (auto d = createDropIfUnused(b->inputs())) {
79 b->prependNode(d);
80 }
81 for (auto n : b->nodes()) {
82 if (auto d = createDropIfUnused(n->outputs())) {
83 d->insertAfter(n);
84 }
85 for (auto block : n->blocks()) {
86 dropUnused(block);
87 }
88 }
89 }
90
91 // ensure every value has a final use in the same block where it is defined.
92 // This already true for most nodes. The exceptions are:
93 // 1. A value that is unused.
94 // 2. A value whose last use is nested in some control flow.
95 // For (1) we simply add a prim::Drop node that uses the value right after
96 // it is defined. For (2), we insert a prim::Drop right after the control
97 // flow node where the last use occurs
insertLastUses(Graph & g)98 void insertLastUses(Graph& g) {
99 // struct to share common data structures
100 struct InsertLastUses {
101 Graph& graph;
102 // have we seen this value, yet, if not, it is the last use of the value
103 std::unordered_set<Value*> seen;
104
105 // A map from an If or Loop node to the optional Drop block that
106 // occurs directly after it to release any tensors that go out of scope
107 // when the If/Loop exits. These are created and inserted on demand.
108 std::unordered_map<Node*, Node*> drop_for_node;
109
110 explicit InsertLastUses(Graph& g) : graph(g) {
111 scanBlock(graph.block());
112 }
113 void scanBlock(Block* b) {
114 scanNode(b->return_node());
115 for (auto n : b->nodes().reverse()) {
116 scanNode(n);
117 }
118 }
119 void scanNode(Node* n) {
120 for (auto b : n->blocks()) {
121 scanBlock(b);
122 }
123 // scan backwards so if a value is used twice in the list then it is a
124 // move
125 for (size_t i = n->inputs().size(); i > 0; --i) {
126 scanUse(n, i - 1);
127 }
128 }
129 void scanUse(Node* n, size_t i) {
130 auto v = n->inputs()[i];
131 auto inserted = seen.insert(v).second;
132 if (!inserted) {
133 return;
134 }
135
136 // the last use of v may be in a nested block of an If or Loop statement
137 // find the node 'same_depth_node' at the same depth as the definition of
138 // v, and consider that node to be the last use of v. This ensures we do
139 // not delete nodes in nested scopes that may be executed multiple times
140 // and that nodes used on one side of an if
141 // but not the other get deleted regardless of the branch
142 // e.g.
143 // a = 4
144 // while <...>:
145 // y = a + a
146 // drop(a)
147 // In other words, we find the first program point for v that
148 // _reverse_ dominates the definition of v, and add a drop point there.
149 Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
150 AT_ASSERT(
151 same_depth_node); // failure means v is not in scope for n, use lint!
152
153 // In the case where v and n are in the same block,
154 // we have a legit final use already.
155 if (same_depth_node == n) {
156 return;
157 }
158
159 // in the case where the use is nested in a block
160 // add a Drop node after that block which will drop 'v'.
161 addToDropIfNotExists(
162 findOrCreateDropInstructionForNode(same_depth_node), v);
163 }
164
165 // finds the node in block 'block' that contains in 'n'
166 // or nullptr if no such node exists, e.g.:
167 // n0: a = 4
168 // n1: if <cond>:
169 // n2: b = a + a
170 // findOwnerInBlock(n2, n0.block()) == n1
171 Node* findOwnerInBlock(Node* n, Block* block) {
172 while (n != nullptr && block != n->owningBlock()) {
173 n = n->owningBlock()->owningNode();
174 }
175 return n;
176 }
177
178 Node* findOrCreateDropInstructionForNode(Node* n) {
179 auto it = drop_for_node.find(n);
180 if (it == drop_for_node.end()) {
181 auto drop_node = graph.create(prim::Drop, 0);
182 drop_node->insertAfter(n);
183 it = drop_for_node.emplace(n, drop_node).first;
184 }
185 return it->second;
186 }
187
188 void addToDropIfNotExists(Node* drop, Value* v) {
189 if (v->node()->kind() == prim::Constant) {
190 return;
191 }
192 for (auto i : drop->inputs()) {
193 // we already accounted for this use
194 if (i == v) {
195 return;
196 }
197 }
198 drop->addInput(v);
199 }
200 };
201
202 InsertLastUses ilu(g);
203 }
204
205 } // namespace
206
PreprocessGraph(Graph & g)207 PreprocessGraph::PreprocessGraph(Graph& g) : graph(g.copy()) {
208 insertEnterMethodCalls(*graph);
209 dropUnused(graph->block());
210 // fill in move_flags by scanning blocks;
211 insertLastUses(*graph);
212 can_emit_inline = std::move(CanEmitInline(*graph.get()).can_emit_inline_);
213 }
214 } // namespace torch::jit::interpreter
215