1 #include <torch/csrc/jit/ir/ir.h> 2 3 namespace torch::jit { 4 5 // This class facilitates depth-first iteration over all nodes in a graph. 6 class DepthFirstGraphNodeIterator { 7 Node* current_; 8 9 public: 10 // Constructor. DepthFirstGraphNodeIterator(std::shared_ptr<Graph> & graph)11 explicit DepthFirstGraphNodeIterator(std::shared_ptr<Graph>& graph) 12 : current_(*(graph->block()->nodes().begin())) {} 13 14 // Moves up and to the next node (may move up recursively). move_up()15 void move_up() { 16 if (current_ == nullptr) { 17 return; 18 } 19 // Basically we start from the child block (which is current_) 20 // and we try to find the block that owns it. Now we need to check 21 // if that block is the graph root block, or if it is an If/Loop/etc 22 // block. 23 // 24 // If it's the graph root block we can stop because there is no "up" 25 // but if it is a node (e.g. If/Loop/etc) we need to apply logic 26 // based on where we are coming from to move to the next block. 27 // This might mean that we need to traverse up again (e.g. if we've 28 // reached the end of the else clause in an if block we need to go) 29 // up to the parent block that contains the if. 30 // 31 // Similarly if we've reached the end of the parent block containing 32 // the else clause we might need to go up again so this is a recursive 33 // function. 34 // 35 // BlockNode (if/loop/with) 36 // | 37 // [Block1] ... [Block2] 38 // | 39 // [ Node1, Node2, Node3, FromNode] 40 // 41 auto parent_block = current_->owningBlock(); 42 TORCH_INTERNAL_ASSERT(parent_block, "Every node must be owned by a block"); 43 44 // Get the node that owns the parent block. This node has to be an if, 45 // loop, or with. 46 auto parent_node = parent_block->owningNode(); 47 if (parent_node == nullptr) { 48 // If there's no node that owns this current block then we're at the 49 // top of the graph and since we're trying to move up we have reached 50 // the end of the traversal. 51 current_ = nullptr; 52 return; 53 } 54 55 // Check the type of node this root is. 56 if (parent_node->kind() == prim::If) { 57 // Need to check if we came from the `then` branch or the `else` branch. 58 auto* then_block = parent_node->blocks().at(0); 59 auto* else_block = parent_node->blocks().at(1); 60 61 if (parent_block == else_block) { 62 // If else block then we move to the next node in the parent block. 63 current_ = parent_node->next(); 64 if (current_->kind() == prim::Return) { 65 move_up(); 66 } 67 } else { 68 // If then block then move to the else block if it is not empty. 69 TORCH_INTERNAL_ASSERT(parent_block == then_block); 70 bool else_block_empty = 71 else_block->nodes().begin() == else_block->nodes().end(); 72 73 if (!else_block_empty) { 74 current_ = *(else_block->nodes().begin()); 75 } else { 76 // Since it's empty we move to the next node. 77 current_ = parent_node->next(); 78 if (current_->kind() == prim::Return) { 79 move_up(); 80 } 81 } 82 } 83 } else if ( 84 parent_node->kind() == prim::Loop || 85 parent_node->kind() == prim::With) { 86 current_ = parent_node->next(); 87 if (current_->kind() == prim::Return) { 88 move_up(); 89 } 90 } else { 91 TORCH_INTERNAL_ASSERT( 92 false, "Only if/loop/with nodes should have child blocks"); 93 } 94 } 95 96 // Moves to the next adjacent node or up in to the parent if that is not 97 // possible. move_next()98 void move_next() { 99 if (current_ == nullptr) { 100 return; 101 } 102 103 // Increment to the next node in the current block. 104 current_ = current_->next(); 105 106 // Check if we're at the end of the block. If so we need 107 // to move upwards (if it makes sense to). 108 if (current_->kind() == prim::Return) { 109 move_up(); 110 } 111 } 112 113 // Moves to the next node in the graph into children if it can. move_into()114 void move_into() { 115 if (current_ == nullptr) { 116 return; 117 } 118 119 // Check if we're currently on a node that contains sub-nodes. 120 if (current_->kind() == prim::If || current_->kind() == prim::Loop || 121 current_->kind() == prim::With) { 122 auto* first_block = current_->blocks().at(0); 123 current_ = first_block->param_node(); 124 // Move next will move up and out of the current node if the block is 125 // empty. `move_up` which is called by `move_next` will handle the 126 // difference between If, Loop, and With blocks appropriately. 127 move_next(); 128 } else { 129 move_next(); 130 } 131 } 132 133 // Get the next Node in the graph. \returns nullptr if there are no nodes 134 // left. next()135 Node* next() { 136 auto result = current_; 137 138 // Try move into the existing node to set the next node to be returned. 139 // This will move to the next node if not possible, or move upwards and 140 // to the next. 141 move_into(); 142 143 return result; 144 } 145 }; 146 147 } // namespace torch::jit 148