xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/graph_iterator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/ir.h>
2 
3 namespace torch::jit {
4 
5 // This class facilitates depth-first iteration over all nodes in a graph.
6 class DepthFirstGraphNodeIterator {
7   Node* current_;
8 
9  public:
10   // Constructor.
DepthFirstGraphNodeIterator(std::shared_ptr<Graph> & graph)11   explicit DepthFirstGraphNodeIterator(std::shared_ptr<Graph>& graph)
12       : current_(*(graph->block()->nodes().begin())) {}
13 
14   // Moves up and to the next node (may move up recursively).
move_up()15   void move_up() {
16     if (current_ == nullptr) {
17       return;
18     }
19     // Basically we start from the child block (which is current_)
20     // and we try to find the block that owns it. Now we need to check
21     // if that block is the graph root block, or if it is an If/Loop/etc
22     // block.
23     //
24     // If it's the graph root block we can stop because there is no "up"
25     // but if it is a node (e.g. If/Loop/etc) we need to apply logic
26     // based on where we are coming from to move to the next block.
27     // This might mean that we need to traverse up again (e.g. if we've
28     // reached the end of the else clause in an if block we need to go)
29     // up to the parent block that contains the if.
30     //
31     // Similarly if we've reached the end of the parent block containing
32     // the else clause we might need to go up again so this is a recursive
33     // function.
34     //
35     //              BlockNode (if/loop/with)
36     //                       |
37     //            [Block1]  ... [Block2]
38     //                |
39     //   [ Node1, Node2, Node3, FromNode]
40     //
41     auto parent_block = current_->owningBlock();
42     TORCH_INTERNAL_ASSERT(parent_block, "Every node must be owned by a block");
43 
44     // Get the node that owns the parent block. This node has to be an if,
45     // loop, or with.
46     auto parent_node = parent_block->owningNode();
47     if (parent_node == nullptr) {
48       // If there's no node that owns this current block then we're at the
49       // top of the graph and since we're trying to move up we have reached
50       // the end of the traversal.
51       current_ = nullptr;
52       return;
53     }
54 
55     // Check the type of node this root is.
56     if (parent_node->kind() == prim::If) {
57       // Need to check if we came from the `then` branch or the `else` branch.
58       auto* then_block = parent_node->blocks().at(0);
59       auto* else_block = parent_node->blocks().at(1);
60 
61       if (parent_block == else_block) {
62         // If else block then we move to the next node in the parent block.
63         current_ = parent_node->next();
64         if (current_->kind() == prim::Return) {
65           move_up();
66         }
67       } else {
68         // If then block then move to the else block if it is not empty.
69         TORCH_INTERNAL_ASSERT(parent_block == then_block);
70         bool else_block_empty =
71             else_block->nodes().begin() == else_block->nodes().end();
72 
73         if (!else_block_empty) {
74           current_ = *(else_block->nodes().begin());
75         } else {
76           // Since it's empty we move to the next node.
77           current_ = parent_node->next();
78           if (current_->kind() == prim::Return) {
79             move_up();
80           }
81         }
82       }
83     } else if (
84         parent_node->kind() == prim::Loop ||
85         parent_node->kind() == prim::With) {
86       current_ = parent_node->next();
87       if (current_->kind() == prim::Return) {
88         move_up();
89       }
90     } else {
91       TORCH_INTERNAL_ASSERT(
92           false, "Only if/loop/with nodes should have child blocks");
93     }
94   }
95 
96   // Moves to the next adjacent node or up in to the parent if that is not
97   // possible.
move_next()98   void move_next() {
99     if (current_ == nullptr) {
100       return;
101     }
102 
103     // Increment to the next node in the current block.
104     current_ = current_->next();
105 
106     // Check if we're at the end of the block. If so we need
107     // to move upwards (if it makes sense to).
108     if (current_->kind() == prim::Return) {
109       move_up();
110     }
111   }
112 
113   // Moves to the next node in the graph into children if it can.
move_into()114   void move_into() {
115     if (current_ == nullptr) {
116       return;
117     }
118 
119     // Check if we're currently on a node that contains sub-nodes.
120     if (current_->kind() == prim::If || current_->kind() == prim::Loop ||
121         current_->kind() == prim::With) {
122       auto* first_block = current_->blocks().at(0);
123       current_ = first_block->param_node();
124       // Move next will move up and out of the current node if the block is
125       // empty. `move_up` which is called by `move_next` will handle the
126       // difference between If, Loop, and With blocks appropriately.
127       move_next();
128     } else {
129       move_next();
130     }
131   }
132 
133   // Get the next Node in the graph. \returns nullptr if there are no nodes
134   // left.
next()135   Node* next() {
136     auto result = current_;
137 
138     // Try move into the existing node to set the next node to be returned.
139     // This will move to the next node if not possible, or move upwards and
140     // to the next.
141     move_into();
142 
143     return result;
144   }
145 };
146 
147 } // namespace torch::jit
148