xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/utils/topological_sort.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/topological_sort.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <unordered_map>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/grappler/graph_topology_view.h"
25 #include "tensorflow/core/grappler/graph_view.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/status.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
33 namespace {
34 
MakeEphemeralEdges(const absl::Span<const TopologicalDependency> extra_dependencies)35 std::vector<GraphView::Edge> MakeEphemeralEdges(
36     const absl::Span<const TopologicalDependency> extra_dependencies) {
37   std::vector<GraphView::Edge> ephemeral_edges;
38   ephemeral_edges.reserve(extra_dependencies.size());
39   for (const auto& dep : extra_dependencies) {
40     ephemeral_edges.emplace_back(
41         GraphView::OutputPort(dep.from, Graph::kControlSlot),
42         GraphView::InputPort(dep.to, Graph::kControlSlot));
43   }
44   return ephemeral_edges;
45 }
46 
47 // Kahn's algorithm is implemented.
48 // For details, see https://en.wikipedia.org/wiki/Topological_sorting
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<int> * ready_nodes)49 Status ComputeTopologicalOrder(
50     const GraphDef& graph,
51     const absl::Span<const TopologicalDependency> extra_dependencies,
52     std::vector<int>* ready_nodes) {
53   GraphTopologyView graph_view;
54   TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
55       graph, MakeEphemeralEdges(extra_dependencies)));
56 
57   // Keep track of how many inputs are ready for the given node.
58   std::vector<int> num_ready_inputs(graph.node_size(), 0);
59 
60   // We'll push index of ready nodes to this output vector.
61   ready_nodes->reserve(graph.node_size());
62 
63   int front = 0;
64   int back = 0;
65 
66   for (int i = 0; i < graph.node_size(); i++) {
67     if (graph_view.GetFanin(i).empty()) {
68       ready_nodes->push_back(i);
69       back++;
70     }
71     if (IsMerge(graph.node(i))) {
72       for (int input : graph_view.GetFanin(i)) {
73         if (IsNextIteration(graph.node(input))) {
74           num_ready_inputs[i]++;
75         }
76       }
77     }
78   }
79 
80   while (front != back) {
81     int ready_node = (*ready_nodes)[front];
82     for (int fanout : graph_view.GetFanout(ready_node)) {
83       ++num_ready_inputs[fanout];
84       const int max_size = graph_view.GetFanin(fanout).size();
85       if (num_ready_inputs[fanout] == max_size) {
86         ready_nodes->push_back(fanout);
87         ++back;
88       }
89     }
90     ++front;
91   }
92 
93   if (back != graph_view.num_nodes()) {
94     if (VLOG_IS_ON(1)) {
95       VLOG(1) << "The graph couldn't be sorted in topological order. Stalled "
96                  "at node = "
97               << graph.node(back).DebugString();
98       for (int i = 0; i < graph_view.num_nodes(); ++i) {
99         const int max_size = graph_view.GetFanin(i).size();
100         if (num_ready_inputs[i] != max_size) {
101           VLOG(1) << "Node not ready: " << graph.node(i).DebugString();
102         }
103       }
104     }
105     return errors::InvalidArgument(
106         "The graph couldn't be sorted in topological order.");
107   }
108   return OkStatus();
109 }
110 
111 }  // namespace
112 
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<const NodeDef * > * topo_order)113 Status ComputeTopologicalOrder(
114     const GraphDef& graph,
115     const absl::Span<const TopologicalDependency> extra_dependencies,
116     std::vector<const NodeDef*>* topo_order) {
117   std::vector<int> ready_nodes;
118   TF_RETURN_IF_ERROR(
119       ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes));
120 
121   topo_order->reserve(ready_nodes.size());
122   for (int ready_node_idx : ready_nodes) {
123     topo_order->emplace_back(&graph.node(ready_node_idx));
124   }
125 
126   return OkStatus();
127 }
128 
ComputeTopologicalOrder(const GraphDef & graph,std::vector<const NodeDef * > * topo_order)129 Status ComputeTopologicalOrder(const GraphDef& graph,
130                                std::vector<const NodeDef*>* topo_order) {
131   return ComputeTopologicalOrder(graph, {}, topo_order);
132 }
133 
ReversedTopologicalSort(GraphDef * graph)134 Status ReversedTopologicalSort(GraphDef* graph) {
135   std::vector<int> ready_nodes;
136   TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
137   std::reverse(ready_nodes.begin(), ready_nodes.end());
138   PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
139   return OkStatus();
140 }
141 
TopologicalSort(GraphDef * graph)142 Status TopologicalSort(GraphDef* graph) {
143   std::vector<int> ready_nodes;
144   TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
145   PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
146   return OkStatus();
147 }
148 
149 }  // namespace grappler
150 }  // namespace tensorflow
151