#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/mini_environment.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit {

// At the beginning of the pass the Graph has already undergone type checking,
// and reads and writes of a variable are emitted as Loads and Stores in the
// graph.
// a = 1
// print(a)
// is represented as:
// %a.1 : int = prim::Constant[value=1]()
// prim::Store[name="a"](%a.1)
// %a : int = prim::Load[name="a"]()
// prim::Print(%a)
//
// First, this pass recursively adds the Loads & Stores to control flow nodes.
// Then the graph is converted to SSA form.
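//
// For the example above, the end state after ConvertToSSA is roughly
// (value names are illustrative):
// %a.1 : int = prim::Constant[value=1]()
// prim::Print(%a.1)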

using ValueEnvironment = MiniEnvironment<Value*>;
using TypeEnvironment = MiniEnvironment<TypePtr>;

// Adds Loads & Stores to Loops & Ifs
struct ControlFlowLoadStores {
  static void addBlockInput(
      Block* b,
      const TypePtr& type,
      const std::string& name) {
    auto g = b->owningGraph();
    g->createStore(name, b->addInput(name)->setType(type))
        ->insertAfter(b->param_node());
  }

  static void addBlockOutput(
      Block* exit_block,
      const TypePtr& type,
      const std::string& name) {
    WithInsertPoint insert(exit_block);
    auto g = exit_block->owningGraph();
    auto block_exit = g->insertNode(g->createLoad(name, type))->output();
    exit_block->registerOutput(block_exit);
  }

  static void addNodeOutput(
      Node* n,
      const TypePtr& type,
      const std::string& name) {
    auto out = n->addOutput()->setType(type);
    if (meaningfulName(name)) {
      out->setDebugName(name);
    }
    auto g = n->owningGraph();
    g->createStore(name, out)->insertAfter(n);
  }

  static void addNodeInput(
      Node* n,
      const TypePtr& type,
      const std::string& name) {
    auto g = n->owningGraph();
    auto inp = g->createLoad(name, type)->insertBefore(n)->output();
    n->addInput(inp);
  }

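  // For an If where both branches store to a variable `x`, addIfLoadStores
  // appends a Load of `x` as an output of each block, adds a matching output
  // to the If node, and stores that output back to `x` after the node.
  // Schematically (value names are illustrative):
  //   %x : int = prim::If(%cond)
  //     block0():
  //       prim::Store[name="x"](%1)
  //       %x.1 : int = prim::Load[name="x"]()
  //       -> (%x.1)
  //     block1():
  //       prim::Store[name="x"](%2)
  //       %x.2 : int = prim::Load[name="x"]()
  //       -> (%x.2)
  //   prim::Store[name="x"](%x)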
  void addIfLoadStores(Node* n) {
    auto true_block = n->blocks().at(0);
    auto false_block = n->blocks().at(1);

    auto true_vars = addControlFlowLoadStores(true_block);
    auto false_vars = addControlFlowLoadStores(false_block);
    std::set<std::string> mutated_variables;

    for (auto& v : true_vars->definedVariables()) {
      if (false_vars->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }
    for (auto& v : false_vars->definedVariables()) {
      if (true_vars->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }

    // Following the same logic as emitIfElseBlocks in ir_emitter.cpp,
    // we emit a node output if the variable is defined in each block
    // and the types of each block can be unified
    for (const auto& x : mutated_variables) {
      auto true_type = true_vars->findInAnyFrame(x);
      auto false_type = false_vars->findInAnyFrame(x);
      auto unified =
          unifyTypes(true_type, false_type, /*default_to_union=*/true);

      addBlockOutput(true_block, true_type, x);
      addBlockOutput(false_block, false_type, x);
      addNodeOutput(n, *unified, x);
    }
  }

  // loop_carried_outputs* = Loop(max_trip_count, start_condition,
  //                              loop_carried_inputs*)
  //                    block0(loop_counter, loop_carried_block*) {
  //                       <body>
  //                       -> (continue_condition, loop_carried_block_outputs*)
  //                    }
  // all loop_carried_... lists are the same length and represent the value of
  // loop-carried variables whose definitions are updated as the loop executes
  // in a way that ensures static single assignment.
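  //
  // For example, a loop that updates a variable `i` defined before it is
  // threaded roughly as follows (value names are illustrative; the
  // continue-condition output is added later by InlineLoopCondition):
  //   %i.1 : int = prim::Load[name="i"]()
  //   %i.2 : int = prim::Loop(%max_trip_count, %cond, %i.1)
  //     block0(%trip : int, %i.3 : int):
  //       prim::Store[name="i"](%i.3)  // body Loads of `i` see this value
  //       <body>
  //       %i.4 : int = prim::Load[name="i"]()
  //       -> (%i.4)
  //   prim::Store[name="i"](%i.2)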
  void addLoopLoadStores(Node* n) {
    auto body_block = n->blocks().at(0);
    auto loop_vars = addControlFlowLoadStores(body_block);

    for (const auto& name : loop_vars->definedVariables()) {
      // if the variable is local to the loop body, then
      // we do not need a loop carried variable for it
      auto parent_type = environment_stack->findInAnyFrame(name);
      if (!parent_type) {
        continue;
      }

      // since the loop may execute 0 or many times, the output types
      // of the loop and the input loop carried dependencies are conservatively
      // the union of the output of the body and the input to the loop
      auto block_type = loop_vars->findInThisFrame(name);
      auto unified_type = unifyTypes(parent_type, block_type).value();

      // Insert a store at the beginning of the loop block, so that all
      // loads of the variable will use the loop carried value
      addNodeInput(n, parent_type, name);
      addBlockInput(body_block, unified_type, name);
      addBlockOutput(body_block, block_type, name);
      addNodeOutput(n, unified_type, name);
    }
  }

  std::shared_ptr<TypeEnvironment> addControlFlowLoadStores(Block* block) {
    pushFrame(block);
    for (Node* n : block->nodes()) {
      switch (n->kind()) {
        case prim::If: {
          addIfLoadStores(n);
        } break;
        case prim::Loop: {
          addLoopLoadStores(n);
        } break;
        case prim::Closure: {
          for (auto b : n->blocks()) {
            addControlFlowLoadStores(b);
          }
        } break;
        case prim::Store: {
          environment_stack->setVar(n->s(attr::name), n->input()->type());
        } break;
        case prim::ComprehensionScope: {
          addControlFlowLoadStores(n->blocks().at(0));
        } break;
      }
    }
    return popFrame();
  }

  void pushFrame(Block* b) {
    environment_stack = std::make_shared<TypeEnvironment>(b, environment_stack);
  }

  std::shared_ptr<TypeEnvironment> popFrame() {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    return old_frame;
  }

  void run(std::shared_ptr<Graph>& graph) {
    addControlFlowLoadStores(graph->block());
  }

  std::shared_ptr<TypeEnvironment> environment_stack = nullptr;
};

// Given a graph where 1) outputs have been added to control flow nodes and
// 2) loads and stores are represented in the graph, erase the Loads & Stores.
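// In the example at the top of this file, the Store records that `a`
// currently names %a.1; the later Load of `a` is then replaced with %a.1
// and both nodes are destroyed, leaving only prim::Print(%a.1).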
struct EraseLoadStores {
  void eraseBlockLoadStores(Block* block) {
    pushFrame(block);
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      auto n = *it;
      it++;

      switch (n->kind()) {
        case prim::Store: {
          environment_stack->setVar(n->s(attr::name), n->input());
          n->destroy();
        } break;
        case prim::Load: {
          auto name = n->s(attr::name);
          auto var = environment_stack->findInAnyFrame(name);
          TORCH_INTERNAL_ASSERT(
              var, "Typechecking should ensure the variable name is set");
          n->output()->replaceAllUsesWith(var);
          n->destroy();
        } break;
        case prim::ComprehensionScope: {
          // writes within a local variable scope do not leak into
          // the rest of the graph
          auto body = n->blocks().at(0);
          eraseBlockLoadStores(body);
          // inline the local variable scope into the graph
          for (auto it_cmpr = body->nodes().begin();
               it_cmpr != body->nodes().end();) {
            Node* body_node = *it_cmpr;
            it_cmpr++;
            body_node->moveBefore(n);
          }
          n->destroy();
        } break;
        default: {
          for (auto b : n->blocks()) {
            eraseBlockLoadStores(b);
          }
        } break;
      }
    }
    popFrame();
  }

  void pushFrame(Block* b) {
    environment_stack =
        std::make_shared<ValueEnvironment>(b, environment_stack);
  }

  std::shared_ptr<ValueEnvironment> popFrame() {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    return old_frame;
  }

  void run(std::shared_ptr<Graph>& graph) {
    eraseBlockLoadStores(graph->block());
  }

  std::shared_ptr<ValueEnvironment> environment_stack = nullptr;
};

// This pass transforms Breaks & Continues into LoopContinuations,
// of the form LoopContinuation(%loop_continue_condition, *loop_carried_vars).
// Break statements have the condition set to false, and Continue statements
// inline the loop condition as the first input.
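//
// Schematically, a Break inside a loop that carries a variable `i` becomes
// (value names are illustrative):
//   %i.cur : int = prim::Load[name="i"]()
//   prim::LoopContinuation(%false, %i.cur)
// while a Continue inlines the loop's condition block before it and passes
// the condition's output as the first input in place of the false constant.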
struct LoopContinuations {
 public:
  void run(std::shared_ptr<Graph>& graph) {
    run(graph->block());
  }

 private:
  void addLoopCarriedOutputs(Node* n) {
    auto g = n->owningGraph();
    WithInsertPoint insert(n);
    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
    auto continuation = curr_loop_->blocks().at(0)->return_node();
    for (auto out : continuation->inputs()) {
      auto load_node = out->node();
      TORCH_INTERNAL_ASSERT(load_node->kind() == prim::Load);
      auto new_load =
          g->insertNode(g->createClone(load_node, [](Value* v) { return v; }));
      n->addInput(new_load->output());
    }
  }

  void assignExitContinuations(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++;
      switch (n->kind()) {
        case prim::If: {
          assignExitContinuations(n->blocks().at(0));
          assignExitContinuations(n->blocks().at(1));
        } break;
        case prim::Closure: {
          LoopContinuations closure_block;
          closure_block.run(n->blocks().at(0));
        } break;
        case prim::Loop: {
          Node* prev_loop = curr_loop_;
          curr_loop_ = n;
          assignExitContinuations(n->blocks().at(0));
          curr_loop_ = prev_loop;
        } break;
        case prim::ContinueStmt: {
          auto loop_continuation =
              graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
          auto header_block = loop_continuation->addBlock();
          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
          auto pre_header = curr_loop_->blocks().at(1);
          header_block->cloneFrom(pre_header, [](Value* v) { return v; });
          InlineBlockBeforeNode(n, header_block);
          loop_continuation->addInput(header_block->outputs().at(0));
          loop_continuation->eraseBlock(0);
          addLoopCarriedOutputs(loop_continuation);
          n->destroy();
        } break;
        case prim::BreakStmt: {
          auto loop_exit =
              graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
          // first input is the loop continue condition - break sets false
          loop_exit->addInput(false_val_);
          addLoopCarriedOutputs(loop_exit);
          n->destroy();
        } break;
      }
    }
  }

  void run(Block* b) {
    {
      graph_ = b->owningGraph();
      WithInsertPoint guard(b->nodes().front());
      false_val_ = graph_->insertConstant(false);
    }
    assignExitContinuations(b);
  }

  Graph* graph_ = nullptr;
  Value* false_val_ = nullptr;
  Node* curr_loop_ = nullptr;
};

// Converting to SSA works in multiple parts. First, we add control flow
// loads and stores to the graph. Now that control flow outputs are set,
// we can rewrite Break & Continue statements as explicit continuations to the
// end of the block (LoopContinuation). Then we inline the loop condition into
// the graph. Then, we erase Loads & Stores. Finally, we remove
// LoopContinuations from the graph.
void ConvertToSSA(std::shared_ptr<Graph>& graph) {
  ControlFlowLoadStores ctrl;
  ctrl.run(graph);
  LoopContinuations exit_vars;
  exit_vars.run(graph);
  InlineLoopCondition(graph);
  EraseLoadStores erase_loads_stores;
  erase_loads_stores.run(graph);
  TransformExits(graph);
}

} // namespace torch::jit