#include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>

#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/runtime/interpreter/can_emit_inline.h>

namespace torch::jit::interpreter {

namespace {

// Insert explicit prim::MethodCall nodes after prim::Enter nodes
// to actually call __enter__ on the object. All prim::Enter does
// is push the object onto the stack of currently entered objects.
// This is necessary because emitting two instructions for a single
// prim::Enter node (one ENTER to push onto the entered-objects
// stack and one CALL to call __enter__) does not work; the
// accounting that determines when to move a value out of a register
// is based on the number of uses it has in the IR.
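//
// For example, the prim::Enter produced for a `with` statement such as
//   with ctx as c:
//     ...
// gets a prim::MethodCall to ctx.__enter__() inserted directly after it,
// and all uses of the prim::Enter output are redirected to that call.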
void insertEnterMethodCalls(Graph& g) {
  std::vector<Block*> block_queue;
  std::vector<Node*> enter_nodes;
  block_queue.emplace_back(g.block());

  // Traverse the graph while drilling down into blocks belonging to
  // a node and add all encountered prim::Enter nodes to enter_nodes.
  while (!block_queue.empty()) {
    Block* block = block_queue.back();
    block_queue.pop_back();

    for (auto node : block->nodes()) {
      if (node->kind() == prim::Enter) {
        enter_nodes.emplace_back(node);
        continue;
      }

      for (auto& node_block : node->blocks()) {
        block_queue.emplace_back(node_block);
      }
    }
  }

  // For each prim::Enter, emit a prim::MethodCall after it that actually
  // calls __enter__ on the object.
  for (auto& enter : enter_nodes) {
    auto cls = enter->input(0)->type()->expect<ClassType>();

    MatchedSchema enter_matched_schema = matchSchema(
        cls->findMethod("__enter__")->getSchema(),
        enter->input(0)->node()->sourceRange(),
        g,
        {enter->input(0)},
        {});

    Node* call = g.insertMethodCall("__enter__", enter_matched_schema)->node();
    call->moveAfter(enter);
    enter->replaceAllUsesWith(call);
  }
}

// Insert Drop nodes to kill references for anything unused.
// This can happen in a few places, e.g. when a node returns
// many values but only one is used:
//   a, b = foo()
//   return a
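// Here b has no uses, so a prim::Drop node consuming b is inserted
// directly after the call to foo(), releasing b's reference as soon as
// it is produced.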
void dropUnused(Block* b) {
  auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
    std::vector<Value*> to_drop;
    for (auto v : values) {
      if (v->uses().empty() && v->node()->kind() != prim::Constant) {
        to_drop.push_back(v);
      }
    }
    if (to_drop.empty()) {
      return nullptr;
    }
    return b->owningGraph()->create(prim::Drop, to_drop, 0);
  };

  if (auto d = createDropIfUnused(b->inputs())) {
    b->prependNode(d);
  }
  for (auto n : b->nodes()) {
    if (auto d = createDropIfUnused(n->outputs())) {
      d->insertAfter(n);
    }
    for (auto block : n->blocks()) {
      dropUnused(block);
    }
  }
}

// Ensure every value has a final use in the same block where it is defined.
// This is already true for most nodes. The exceptions are:
// 1. A value that is unused.
// 2. A value whose last use is nested in some control flow.
// For (1) we simply add a prim::Drop node that uses the value right after
// it is defined. For (2), we insert a prim::Drop right after the control
// flow node where the last use occurs.
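//
// e.g. for case (2):
//   x = foo()
//   if <cond>:
//     y = x + 1    # last use of x is nested in the If
//   drop(x)        # Drop inserted right after the If node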
void insertLastUses(Graph& g) {
  // struct to share common data structures
  struct InsertLastUses {
    Graph& graph;
    // values we have already seen while scanning backwards; if a value is
    // not yet in this set, the use currently being scanned is its last use
    std::unordered_set<Value*> seen;

    // A map from an If or Loop node to the optional Drop block that
    // occurs directly after it to release any tensors that go out of scope
    // when the If/Loop exits. These are created and inserted on demand.
    std::unordered_map<Node*, Node*> drop_for_node;

    explicit InsertLastUses(Graph& g) : graph(g) {
      scanBlock(graph.block());
    }
    void scanBlock(Block* b) {
      scanNode(b->return_node());
      for (auto n : b->nodes().reverse()) {
        scanNode(n);
      }
    }
    void scanNode(Node* n) {
      for (auto b : n->blocks()) {
        scanBlock(b);
      }
      // scan the inputs backwards so that, if a value appears more than once
      // in the list, its rightmost occurrence is the one treated as the last
      // use (and can therefore become a move)
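      // e.g. for a call like foo(%x, %x), scanning right to left makes the
      // second %x the occurrence recorded as the final use of %x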
      for (size_t i = n->inputs().size(); i > 0; --i) {
        scanUse(n, i - 1);
      }
    }
    void scanUse(Node* n, size_t i) {
      auto v = n->inputs()[i];
      auto inserted = seen.insert(v).second;
      if (!inserted) {
        return;
      }

      // The last use of v may be in a nested block of an If or Loop statement.
      // Find the node 'same_depth_node' at the same depth as the definition of
      // v, and consider that node to be the last use of v. This ensures we do
      // not free values inside nested scopes that may be executed multiple
      // times, and that values used on one side of an if but not the other
      // get dropped regardless of which branch is taken, e.g.
      //   a = 4
      //   while <...>:
      //     y = a + a
      //   drop(a)
      // In other words, we find the first program point for v that
      // _reverse_ dominates the definition of v, and add a drop point there.
      Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
      AT_ASSERT(
          same_depth_node); // failure means v is not in scope for n, use lint!

      // In the case where v and n are in the same block,
      // we have a legit final use already.
      if (same_depth_node == n) {
        return;
      }

      // in the case where the use is nested in a block
      // add a Drop node after that block which will drop 'v'.
      addToDropIfNotExists(
          findOrCreateDropInstructionForNode(same_depth_node), v);
    }

    // Finds the node in block 'block' that contains 'n',
    // or nullptr if no such node exists, e.g.:
    // n0: a = 4
    // n1: if <cond>:
    // n2:    b = a + a
    // findOwnerInBlock(n2, n0.block()) == n1
    Node* findOwnerInBlock(Node* n, Block* block) {
      while (n != nullptr && block != n->owningBlock()) {
        n = n->owningBlock()->owningNode();
      }
      return n;
    }

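    // Returns the Drop node that sits directly after 'n', creating an empty
    // one and inserting it after 'n' if it does not exist yet.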
    Node* findOrCreateDropInstructionForNode(Node* n) {
      auto it = drop_for_node.find(n);
      if (it == drop_for_node.end()) {
        auto drop_node = graph.create(prim::Drop, 0);
        drop_node->insertAfter(n);
        it = drop_for_node.emplace(n, drop_node).first;
      }
      return it->second;
    }

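    // Adds 'v' as an input of the Drop node 'drop', unless 'v' is a constant
    // or has already been added.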
    void addToDropIfNotExists(Node* drop, Value* v) {
      if (v->node()->kind() == prim::Constant) {
        return;
      }
      for (auto i : drop->inputs()) {
        // we already accounted for this use
        if (i == v) {
          return;
        }
      }
      drop->addInput(v);
    }
  };

  InsertLastUses ilu(g);
}

} // namespace

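// Preprocess a copy of the graph for the interpreter: insert calls to
// __enter__ after prim::Enter nodes, add Drop nodes for unused values,
// give every value a final use in its defining block, and record which
// nodes can be emitted inline.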
PreprocessGraph::PreprocessGraph(Graph& g) : graph(g.copy()) {
  insertEnterMethodCalls(*graph);
  dropUnused(graph->block());
  // ensure every value has a final use in the block where it is defined,
  // which is what later allows last uses to be emitted as moves
  insertLastUses(*graph);
  can_emit_inline = std::move(CanEmitInline(*graph.get()).can_emit_inline_);
}
} // namespace torch::jit::interpreter