// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/interpreter/can_emit_inline.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <memory>
#include <unordered_map>

#include <torch/csrc/jit/ir/ir.h>
namespace torch::jit::interpreter {
/*
This is an optimization that reduces the number of store/load/move nodes needed
by recognizing that parts of the graph are simple trees like a*x + b*y. When
this happens it is possible to work directly off of the stack by emitting the
tree in a depth-first left-to-right manner:
  load a
  load x
  mul # stack now is a*x
  load b
  load y
  mul # stack now is a*x, b*y
  add

can_emit_inline_[node] == true means that this node participates as a non-root
member of one of these trees. The code emitter will not emit this node when
it is encountered in the node list. Instead the node is emitted in a depth
first traversal from where it is used in a tree.

To participate in a tree a node must have a single use (otherwise it is not
tree-like) and output a single value (for simplicity.) If our IR was functional,
these would be the only constraints. However, many nodes have side effects, so
we must ensure that emitting the nodes in depth first order from the tree's root
_does not reorder the emission of the nodes_. To ensure this, we work backward
from the root of a potential tree, visiting its inputs in reverse depth first
order, while scanning the node list backward (with the block_point node). When
these traversals line up we know it is safe to emit the tree in this way. We
ignore constant nodes, which do not have side effects.
*/
36 struct CanEmitInline {
CanEmitInlineCanEmitInline37   explicit CanEmitInline(Graph& graph) {
38     scanBlock(graph.block());
39   }
canInlineCanEmitInline40   bool canInline(Value* v) {
41     return v->node()->kind() != prim::Param &&
42         // without this a BailOut may float downstream past some later
43         // BailOut
44         // and receive a higher jf_index. Then a GUARD instruction
45         // we generated for the floated BailOut will get popped up from the
46         // instruction stack
47         // by the later BailOut in createBailoutBlock and its jf_index
48         // will become invalid.
49         v->node()->kind() != prim::TensorExprGroup &&
50         v->node()->kind() != prim::TensorExprDynamicGroup &&
51         v->node()->kind() != prim::StaticSubgraph &&
52         v->node()->kind() != prim::CudaFusionGroup &&
53         v->node()->kind() != prim::FusionGroup &&
54         v->node()->kind() != prim::BailOut && v->uses().size() == 1 &&
55         v->node()->outputs().size() == 1;
56   }
57 
previousNonConstantCanEmitInline58   Node* previousNonConstant(Node* n) {
59     do {
60       n = n->prev();
61     } while (n->kind() == prim::Constant);
62     return n;
63   }
64 
scanValueCanEmitInline65   Node* scanValue(Node* block_point, Value* v) {
66     // this node is a candidate for inline, if our reverse scan of the
67     // node list lines up with the use of v, we know it will be emitted in
68     // tree order, and we can inlining. Scan continues for further nodes.
69     if (v->node() == block_point && canInline(v)) {
70       // since we inlined this node, we may be able to recursively inline
71       // its inputs, so we continue scanning it
72       block_point = scanNode(v->node());
73       can_emit_inline_[v->node()] = true;
74     }
75     // if it does not line up, we can't inline 'v', and will just generate
76     // a load/move for it. However, other inputs may still appear in tree
77     // order so we continue the scan of the inputs.
78     return block_point;
79   }
80 
scanNodeCanEmitInline81   Node* scanNode(Node* n) {
82     // don't bother to scan nodes we have already determined to be inline
83     if (can_emit_inline_.count(n)) {
84       return nullptr;
85     }
86     for (auto b : n->blocks()) {
87       scanBlock(b);
88     }
89     Node* block_point = previousNonConstant(n);
90     for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
91          ++it) {
92       block_point = scanValue(block_point, *it);
93     }
94     return block_point;
95   }
96 
scanBlockCanEmitInline97   void scanBlock(Block* b) {
98     scanNode(b->return_node());
99     for (auto node : b->nodes().reverse()) {
100       scanNode(node);
101     }
102   }
103   std::unordered_map<Node*, bool> can_emit_inline_;
104 };

} // namespace torch::jit::interpreter