#pragma once

#include <memory>
#include <unordered_map>

#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit::interpreter {
/*
This is an optimization that reduces the number of store/load/move nodes needed
by recognizing that parts of the graph are simple trees like a*x + b*y. When
this happens it is possible to work directly off of the stack by emitting the
tree in a depth-first, left-to-right manner:
  load a
  load x
  mul # stack now is a*x
  load b
  load y
  mul # stack now is a*x, b*y
  add

can_emit_inline_[node] == true means that this node participates as a non-root
member of one of these trees. The code emitter will not emit this node when it
is encountered in the node list. Instead, the node is emitted in a depth-first
traversal from the point where it is used in a tree.

To participate in a tree a node must have a single use (otherwise it is not
tree-like) and output a single value (for simplicity). If our IR were purely
functional, these would be the only constraints. However, many nodes have side
effects, so we must ensure that emitting the nodes in depth-first order from
the tree's root _does not reorder the emission of the nodes_. To ensure this,
we work backward from the root of a potential tree, visiting its inputs in
reverse depth-first order, while scanning the node list backward (tracked by
the block_point node). When these traversals line up, we know it is safe to
emit the tree in this way. We ignore constant nodes, which do not have side
effects.
*/
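// A minimal usage sketch (hypothetical emitter code, for illustration only;
// emitNodeAtBlockLevel is an assumed helper, not declared in this header).
// The emitter would build the table once per graph, skip nodes marked inline
// while walking the node list, and instead emit them depth-first from the
// root of the tree that consumes them:
//
//   CanEmitInline info(graph);
//   for (Node* node : graph.block()->nodes()) {
//     if (info.can_emit_inline_.count(node)) {
//       continue; // emitted later, from the single use that consumes it
//     }
//     emitNodeAtBlockLevel(node); // recursively emits inlinable inputs first
//   }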
struct CanEmitInline {
  explicit CanEmitInline(Graph& graph) {
    scanBlock(graph.block());
  }

  bool canInline(Value* v) {
    return v->node()->kind() != prim::Param &&
        // without this a BailOut may float downstream past some later
        // BailOut and receive a higher jf_index. Then a GUARD instruction
        // we generated for the floated BailOut will get popped off the
        // instruction stack by the later BailOut in createBailoutBlock and
        // its jf_index will become invalid.
        v->node()->kind() != prim::TensorExprGroup &&
        v->node()->kind() != prim::TensorExprDynamicGroup &&
        v->node()->kind() != prim::StaticSubgraph &&
        v->node()->kind() != prim::CudaFusionGroup &&
        v->node()->kind() != prim::FusionGroup &&
        v->node()->kind() != prim::BailOut && v->uses().size() == 1 &&
        v->node()->outputs().size() == 1;
  }

  Node* previousNonConstant(Node* n) {
    do {
      n = n->prev();
    } while (n->kind() == prim::Constant);
    return n;
  }

  Node* scanValue(Node* block_point, Value* v) {
    // this node is a candidate for inlining; if our reverse scan of the
    // node list lines up with the use of v, we know it will be emitted in
    // tree order, and we can inline it. The scan then continues over
    // further nodes.
    if (v->node() == block_point && canInline(v)) {
      // since we inlined this node, we may be able to recursively inline
      // its inputs, so we continue scanning it
      block_point = scanNode(v->node());
      can_emit_inline_[v->node()] = true;
    }
    // if it does not line up, we can't inline 'v', and will just generate
    // a load/move for it. However, other inputs may still appear in tree
    // order, so we continue the scan of the inputs.
    return block_point;
  }

  Node* scanNode(Node* n) {
    // don't bother to scan nodes we have already determined to be inline
    if (can_emit_inline_.count(n)) {
      return nullptr;
    }
    for (auto b : n->blocks()) {
      scanBlock(b);
    }
    Node* block_point = previousNonConstant(n);
    for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
         ++it) {
      block_point = scanValue(block_point, *it);
    }
    return block_point;
  }

  void scanBlock(Block* b) {
    scanNode(b->return_node());
    for (auto node : b->nodes().reverse()) {
      scanNode(node);
    }
  }

  std::unordered_map<Node*, bool> can_emit_inline_;
};

} // namespace torch::jit::interpreter