xref: /aosp_15_r20/external/tensorflow/tensorflow/core/graph/control_flow.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/control_flow.h"
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/graph/node_builder.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 
26 namespace tensorflow {
27 namespace {
28 // Information about a loop frame structure.
29 struct Frame {
30   string name;
31 
32   // Pointer to the parent frame. The root frame has a pointer to itself.
33   Frame* parent = nullptr;
34 
35   // The loop condition of the loop. There should be exactly one loop condition
36   // in every loop.
37   const Node* loop_cond = nullptr;
38 };
39 
40 // Verify that the ControlFlowInfo of the graph has valid loop structure.
ValidateControlFlowInfo(const Graph * graph,const std::vector<ControlFlowInfo> & cf_info)41 Status ValidateControlFlowInfo(const Graph* graph,
42                                const std::vector<ControlFlowInfo>& cf_info) {
43   std::unordered_map<string, Frame> frames;
44   for (const Node* node : graph->op_nodes()) {
45     const ControlFlowInfo& cf = cf_info[node->id()];
46     if (!cf.frame || !cf.parent_frame) {
47       // Skip nodes unreachable from the source node. They might be pruned
48       // later.
49       continue;
50     }
51 
52     Frame& frame = frames[cf.frame_name];
53     Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
54     if (frame.parent == nullptr) {
55       frame.parent = parent;
56       frame.name = cf.frame_name;
57     } else if (frame.parent != parent) {
58       return errors::Internal(
59           "Invalid loop structure: Mismatched parent frames for \"",
60           cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name,
61           "\". The node giving this error: ", FormatNodeForError(*node),
62           ". This is an internal bug, please file a bug report with "
63           "instructions on how to reproduce the error.");
64     }
65     if (IsLoopCond(node)) {
66       // ForwardLoopCounter runs in the same frame as the forward loop and
67       // BackPropLoopCounter runs in the same frame as the backprop loop. They
68       // are the only cases that multiple loops share the same frame.
69       if (frame.loop_cond &&
70           !absl::StrContains(frame.loop_cond->name(), "LoopCounter") &&
71           !absl::StrContains(node->name(), "LoopCounter")) {
72         return errors::InvalidArgument(
73             "Invalid loop structure: Loop \"", cf.frame_name,
74             "\" has more than one LoopCond node: ", FormatNodeForError(*node),
75             " and ", FormatNodeForError(*frame.loop_cond),
76             ". This is an internal bug, please file a bug report with "
77             "instructions on how to reproduce the error.");
78       }
79       frame.loop_cond = node;
80     }
81   }
82   return OkStatus();
83 }
84 }  // namespace
85 
BuildControlFlowInfo(const Graph * g,std::vector<ControlFlowInfo> * info,std::vector<string> * unreachable_nodes)86 Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
87                             std::vector<string>* unreachable_nodes) {
88   info->clear();
89   info->resize(g->num_node_ids());
90 
91   std::vector<const Node*> parent_nodes;
92   parent_nodes.resize(g->num_node_ids());
93 
94   const Node* src_node = g->source_node();
95   ControlFlowInfo& src_info = (*info)[src_node->id()];
96   src_info.frame = src_node;
97   src_info.parent_frame = src_node;
98 
99   string frame_name;
100   std::deque<const Node*> ready;
101   ready.push_back(src_node);
102   while (!ready.empty()) {
103     const Node* curr_node = ready.front();
104     ready.pop_front();
105     const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
106     const Node* frame = curr_info.frame;
107     const Node* parent = curr_info.parent_frame;
108     frame_name = curr_info.frame_name;
109 
110     if (IsExit(curr_node)) {
111       // Exit to the parent frame.
112       const ControlFlowInfo& parent_info = (*info)[parent->id()];
113       frame = parent_info.frame;
114       parent = parent_info.parent_frame;
115       frame_name = parent_info.frame_name;
116     }
117 
118     for (const Edge* out_edge : curr_node->out_edges()) {
119       const Node* out = out_edge->dst();
120       int out_id = out->id();
121       ControlFlowInfo* out_info = &(*info)[out_id];
122       const Node* out_parent = out_info->parent_frame;
123       bool is_visited = (parent_nodes[out_id] != nullptr);
124 
125       // Skip Sink/Source nodes.
126       if (!out->IsOp()) continue;
127 
128       // Add to ready queue if not seen.
129       if (!is_visited) {
130         parent_nodes[out->id()] = curr_node;
131         ready.push_back(out);
132       }
133 
134       // Process the node 'out'.
135       if (IsEnter(out)) {
136         if (is_visited) {
137           const string& parent_frame = (*info)[out_parent->id()].frame_name;
138           if (parent_frame != frame_name) {
139             return errors::InvalidArgument(
140                 FormatNodeForError(*out),
141                 " has inputs from different frames. The input ",
142                 FormatNodeForError(*curr_node), " is in frame '", frame_name,
143                 "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
144                 " is in frame '", parent_frame, "'.");
145           }
146         } else {
147           out_info->frame = out;
148           out_info->parent_frame = frame;
149           TF_RETURN_IF_ERROR(
150               GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
151           if (out_info->frame_name.empty()) {
152             return errors::InvalidArgument("The Enter ",
153                                            FormatNodeForError(*out),
154                                            " must have a frame name.");
155           }
156         }
157       } else {
158         if (is_visited) {
159           if (out_info->frame_name != frame_name) {
160             return errors::InvalidArgument(
161                 FormatNodeForError(*out),
162                 " has inputs from different frames. The input ",
163                 FormatNodeForError(*curr_node), " is in frame '", frame_name,
164                 "'. The input ", FormatNodeForError(*parent_nodes[out->id()]),
165                 " is in frame '", out_info->frame_name, "'.");
166           }
167         } else {
168           out_info->frame = frame;
169           out_info->parent_frame = parent;
170           out_info->frame_name = frame_name;
171         }
172       }
173     }
174   }
175   if (unreachable_nodes) {
176     for (const Node* node : g->op_nodes()) {
177       if (!parent_nodes[node->id()]) {
178         unreachable_nodes->push_back(node->name());
179       }
180     }
181   }
182   TF_RETURN_IF_ERROR(ValidateControlFlowInfo(g, *info));
183   return OkStatus();
184 }
185 
186 }  // namespace tensorflow
187