#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/mini_environment.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit {

// At the beginning of the pass the Graph has already undergone type checking,
// and writes to or reads from a variable are emitted as Stores and Loads in
// the graph.
//     a = 1
//     print(a)
// is represented as:
//     %a.1 : int = prim::Constant[value=1]()
//     prim::Store[name="a"](%a.1)
//     %a : int = prim::Load[name="a"]()
//     prim::Print(%a)
//
// First, this pass recursively adds the Loads & Stores to control flow nodes.
// Then the graph is converted to SSA form.
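//
// After the full pass, the example above reduces to (an illustrative sketch):
//     %a.1 : int = prim::Constant[value=1]()
//     prim::Print(%a.1)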

using ValueEnvironment = MiniEnvironment<Value*>;
using TypeEnvironment = MiniEnvironment<TypePtr>;

// Adds Loads & Stores to Loops & Ifs
struct ControlFlowLoadStores {
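  // Adds a typed input to the block and a Store of it at the top of the
  // block, so that Loads of `name` inside the block see the block input.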
  static void addBlockInput(
      Block* b,
      const TypePtr& type,
      const std::string& name) {
    auto g = b->owningGraph();
    g->createStore(name, b->addInput(name)->setType(type))
        ->insertAfter(b->param_node());
  }

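  // Loads `name` at the end of the block and registers the loaded value as
  // a block output.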
  static void addBlockOutput(
      Block* exit_block,
      const TypePtr& type,
      const std::string& name) {
    WithInsertPoint insert(exit_block);
    auto g = exit_block->owningGraph();
    auto block_exit = g->insertNode(g->createLoad(name, type))->output();
    exit_block->registerOutput(block_exit);
  }

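  // Adds an output to the node and Stores it to `name`, so that subsequent
  // Loads of `name` see the node's output.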
  static void addNodeOutput(
      Node* n,
      const TypePtr& type,
      const std::string& name) {
    auto out = n->addOutput()->setType(type);
    if (meaningfulName(name)) {
      out->setDebugName(name);
    }
    auto g = n->owningGraph();
    g->createStore(name, out)->insertAfter(n);
  }

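  // Loads `name` just before the node and passes the loaded value in as an
  // input.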
  static void addNodeInput(
      Node* n,
      const TypePtr& type,
      const std::string& name) {
    auto g = n->owningGraph();
    auto inp = g->createLoad(name, type)->insertBefore(n)->output();
    n->addInput(inp);
  }

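  // Adds block and node outputs for every variable that is written in one
  // branch of the If and also defined in the other branch or an enclosing
  // scope.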
  void addIfLoadStores(Node* n) {
    auto true_block = n->blocks().at(0);
    auto false_block = n->blocks().at(1);

    auto true_vars = addControlFlowLoadStores(true_block);
    auto false_vars = addControlFlowLoadStores(false_block);
    std::set<std::string> mutated_variables;

    for (auto& v : true_vars->definedVariables()) {
      if (false_vars->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }
    for (auto& v : false_vars->definedVariables()) {
      if (true_vars->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }

    // Following the same logic as emitIfElseBlocks in ir_emitter.cpp,
    // we emit a node output if the variable is defined in each block
    // and the types of each block can be unified
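    //
    // For example, given (an illustrative sketch)
    //     if cond:
    //         x = 1
    //     else:
    //         x = 2
    // each block gains a Load of x registered as a block output, and the If
    // node gains an int output that is then Stored back to x.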
    for (const auto& x : mutated_variables) {
      auto true_type = true_vars->findInAnyFrame(x);
      auto false_type = false_vars->findInAnyFrame(x);
      auto unified =
          unifyTypes(true_type, false_type, /*default_to_union=*/true);

      addBlockOutput(true_block, true_type, x);
      addBlockOutput(false_block, false_type, x);
      addNodeOutput(n, *unified, x);
    }
  }

  // loop_carried_outputs* = Loop(max_trip_count, start_condition,
  //                              loop_carried_inputs*)
  //                    block0(loop_counter, loop_carried_block*) {
  //                       <body>
  //                       -> (continue_condition, loop_carried_block_outputs*)
  //                    }
  // all loop_carried_... lists are the same length and represent the values
  // of loop-carried variables whose definitions are updated as the loop
  // executes in a way that ensures static single assignment.
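  //
  // For example, given (an illustrative sketch)
  //     i = 0
  //     while cond:
  //         i = i + 1
  // i is loop-carried: the Loop node gains an input (the value of i before
  // the loop), the body block gains an input and an output for i, and the
  // Loop node gains an output that is Stored back to i after the loop.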
  void addLoopLoadStores(Node* n) {
    auto body_block = n->blocks().at(0);
    auto loop_vars = addControlFlowLoadStores(body_block);

    for (const auto& name : loop_vars->definedVariables()) {
      // if the variable is local to the loop body, then
      // we do not need a loop carried variable for it
      auto parent_type = environment_stack->findInAnyFrame(name);
      if (!parent_type) {
        continue;
      }

      // since the loop may execute zero or more times, the output types
      // of the loop and the loop-carried inputs are conservatively
      // the unification of the body's output type and the loop's input type
      auto block_type = loop_vars->findInThisFrame(name);
      auto unified_type = unifyTypes(parent_type, block_type).value();

      // Insert a store at the beginning of the loop block, so that all
      // loads of the variable will use the loop carried value
      addNodeInput(n, parent_type, name);
      addBlockInput(body_block, unified_type, name);
      addBlockOutput(body_block, block_type, name);
      addNodeOutput(n, unified_type, name);
    }
  }

  std::shared_ptr<TypeEnvironment> addControlFlowLoadStores(Block* block) {
    pushFrame(block);
    for (Node* n : block->nodes()) {
      switch (n->kind()) {
        case prim::If: {
          addIfLoadStores(n);
        } break;
        case prim::Loop: {
          addLoopLoadStores(n);
        } break;
        case prim::Closure: {
          for (auto b : n->blocks()) {
            addControlFlowLoadStores(b);
          }
        } break;
        case prim::Store: {
          environment_stack->setVar(n->s(attr::name), n->input()->type());
        } break;
        case prim::ComprehensionScope: {
          addControlFlowLoadStores(n->blocks().at(0));
        } break;
      }
    }
    return popFrame();
  }

  void pushFrame(Block* b) {
    environment_stack = std::make_shared<TypeEnvironment>(b, environment_stack);
  }

  std::shared_ptr<TypeEnvironment> popFrame() {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    return old_frame;
  }

  void run(std::shared_ptr<Graph>& graph) {
    addControlFlowLoadStores(graph->block());
  }

  std::shared_ptr<TypeEnvironment> environment_stack = nullptr;
};

// Given a graph where 1) outputs have been added to control flow nodes and
// 2) loads and stores are represented in the graph, erase the Loads & Stores.
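//
// For example (an illustrative sketch),
//     prim::Store[name="a"](%x)
//     %a : int = prim::Load[name="a"]()
//     prim::Print(%a)
// becomes
//     prim::Print(%x)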
struct EraseLoadStores {
  void eraseBlockLoadStores(Block* block) {
    pushFrame(block);
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      auto n = *it;
      it++;

      switch (n->kind()) {
        case prim::Store: {
          environment_stack->setVar(n->s(attr::name), n->input());
          n->destroy();
        } break;
        case prim::Load: {
          auto name = n->s(attr::name);
          auto var = environment_stack->findInAnyFrame(name);
          TORCH_INTERNAL_ASSERT(
              var, "Typechecking should ensure the variable name is set");
          n->output()->replaceAllUsesWith(var);
          n->destroy();
        } break;
        case prim::ComprehensionScope: {
          // writes within a local variable scope do not leak into
          // the rest of the graph
          auto body = n->blocks().at(0);
          eraseBlockLoadStores(body);
          // inline the local variable scope into the graph
          for (auto it_cmpr = body->nodes().begin();
               it_cmpr != body->nodes().end();) {
            Node* body_node = *it_cmpr;
            it_cmpr++;
            body_node->moveBefore(n);
          }
          n->destroy();
        } break;
        default: {
          for (auto b : n->blocks()) {
            eraseBlockLoadStores(b);
          }
        } break;
      }
    }
    popFrame();
  }

  void pushFrame(Block* b) {
    environment_stack =
        std::make_shared<ValueEnvironment>(b, environment_stack);
  }

  std::shared_ptr<ValueEnvironment> popFrame() {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    return old_frame;
  }

  void run(std::shared_ptr<Graph>& graph) {
    eraseBlockLoadStores(graph->block());
  }

  std::shared_ptr<ValueEnvironment> environment_stack = nullptr;
};

// This pass transforms Breaks & Continues into LoopContinuations, of the form
// LoopContinuation(%loop_continue_condition, *loop_carried_vars).
// Break statements have the condition set to false, and Continue statements
// inline the loop condition as the first input.
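//
// For example (an illustrative sketch; value names are assumptions), inside
// a loop whose continue condition is %cond, a continue becomes
//     prim::LoopContinuation(%cond, %carried_vars...)
// and a break becomes
//     prim::LoopContinuation(%false, %carried_vars...)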
struct LoopContinuations {
 public:
  void run(std::shared_ptr<Graph>& graph) {
    run(graph->block());
  }

 private:
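  // Appends the loop-carried values to a LoopContinuation node by cloning
  // the Loads that feed the enclosing loop body's return node.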
  void addLoopCarriedOutputs(Node* n) {
    auto g = n->owningGraph();
    WithInsertPoint insert(n);
    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
    auto continuation = curr_loop_->blocks().at(0)->return_node();
    for (auto out : continuation->inputs()) {
      auto load_node = out->node();
      TORCH_INTERNAL_ASSERT(load_node->kind() == prim::Load);
      auto new_load =
          g->insertNode(g->createClone(load_node, [](Value* v) { return v; }));
      n->addInput(new_load->output());
    }
  }

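  // Walks the block recursively, tracking the innermost enclosing loop, and
  // rewrites ContinueStmt and BreakStmt nodes into LoopContinuations.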
  void assignExitContinuations(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++;
      switch (n->kind()) {
        case prim::If: {
          assignExitContinuations(n->blocks().at(0));
          assignExitContinuations(n->blocks().at(1));
        } break;
        case prim::Closure: {
          LoopContinuations closure_block;
          closure_block.run(n->blocks().at(0));
        } break;
        case prim::Loop: {
          Node* prev_loop = curr_loop_;
          curr_loop_ = n;
          assignExitContinuations(n->blocks().at(0));
          curr_loop_ = prev_loop;
        } break;
        case prim::ContinueStmt: {
          auto loop_continuation =
              graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
          auto header_block = loop_continuation->addBlock();
          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
          auto pre_header = curr_loop_->blocks().at(1);
          header_block->cloneFrom(pre_header, [](Value* v) { return v; });
          InlineBlockBeforeNode(n, header_block);
          loop_continuation->addInput(header_block->outputs().at(0));
          loop_continuation->eraseBlock(0);
          addLoopCarriedOutputs(loop_continuation);
          n->destroy();
        } break;
        case prim::BreakStmt: {
          auto loop_exit =
              graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
          // the first input is the loop continue condition; break sets it to
          // false
          loop_exit->addInput(false_val_);
          addLoopCarriedOutputs(loop_exit);
          n->destroy();
        } break;
      }
    }
  }

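  // Inserts the shared false constant at the front of the block, then
  // rewrites the exit statements it contains.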
  void run(Block* b) {
    {
      graph_ = b->owningGraph();
      WithInsertPoint guard(b->nodes().front());
      false_val_ = graph_->insertConstant(false);
    }
    assignExitContinuations(b);
  }

  Graph* graph_ = nullptr;
  Value* false_val_ = nullptr;
  Node* curr_loop_ = nullptr;
};

// Converting to SSA works in multiple parts. First, we add control flow
// loads and stores to the graph. Now that control flow outputs are set, we
// can rewrite Break & Continue statements as continuations to the end of the
// block (LoopContinuation). Then we inline the loop condition into the
// graph. Then we erase Loads & Stores. Finally, we remove LoopContinuations
// from the graph.
void ConvertToSSA(std::shared_ptr<Graph>& graph) {
  ControlFlowLoadStores ctrl;
  ctrl.run(graph);
  LoopContinuations exit_vars;
  exit_vars.run(graph);
  InlineLoopCondition(graph);
  EraseLoadStores erase_loads_stores;
  erase_loads_stores.run(graph);
  TransformExits(graph);
}

} // namespace torch::jit