1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/topological_sort.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <unordered_map>
21
22 #include "absl/types/span.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/grappler/graph_topology_view.h"
25 #include "tensorflow/core/grappler/graph_view.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/status.h"
29
30 namespace tensorflow {
31 namespace grappler {
32
33 namespace {
34
MakeEphemeralEdges(const absl::Span<const TopologicalDependency> extra_dependencies)35 std::vector<GraphView::Edge> MakeEphemeralEdges(
36 const absl::Span<const TopologicalDependency> extra_dependencies) {
37 std::vector<GraphView::Edge> ephemeral_edges;
38 ephemeral_edges.reserve(extra_dependencies.size());
39 for (const auto& dep : extra_dependencies) {
40 ephemeral_edges.emplace_back(
41 GraphView::OutputPort(dep.from, Graph::kControlSlot),
42 GraphView::InputPort(dep.to, Graph::kControlSlot));
43 }
44 return ephemeral_edges;
45 }
46
47 // Kahn's algorithm is implemented.
48 // For details, see https://en.wikipedia.org/wiki/Topological_sorting
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<int> * ready_nodes)49 Status ComputeTopologicalOrder(
50 const GraphDef& graph,
51 const absl::Span<const TopologicalDependency> extra_dependencies,
52 std::vector<int>* ready_nodes) {
53 GraphTopologyView graph_view;
54 TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
55 graph, MakeEphemeralEdges(extra_dependencies)));
56
57 // Keep track of how many inputs are ready for the given node.
58 std::vector<int> num_ready_inputs(graph.node_size(), 0);
59
60 // We'll push index of ready nodes to this output vector.
61 ready_nodes->reserve(graph.node_size());
62
63 int front = 0;
64 int back = 0;
65
66 for (int i = 0; i < graph.node_size(); i++) {
67 if (graph_view.GetFanin(i).empty()) {
68 ready_nodes->push_back(i);
69 back++;
70 }
71 if (IsMerge(graph.node(i))) {
72 for (int input : graph_view.GetFanin(i)) {
73 if (IsNextIteration(graph.node(input))) {
74 num_ready_inputs[i]++;
75 }
76 }
77 }
78 }
79
80 while (front != back) {
81 int ready_node = (*ready_nodes)[front];
82 for (int fanout : graph_view.GetFanout(ready_node)) {
83 ++num_ready_inputs[fanout];
84 const int max_size = graph_view.GetFanin(fanout).size();
85 if (num_ready_inputs[fanout] == max_size) {
86 ready_nodes->push_back(fanout);
87 ++back;
88 }
89 }
90 ++front;
91 }
92
93 if (back != graph_view.num_nodes()) {
94 if (VLOG_IS_ON(1)) {
95 VLOG(1) << "The graph couldn't be sorted in topological order. Stalled "
96 "at node = "
97 << graph.node(back).DebugString();
98 for (int i = 0; i < graph_view.num_nodes(); ++i) {
99 const int max_size = graph_view.GetFanin(i).size();
100 if (num_ready_inputs[i] != max_size) {
101 VLOG(1) << "Node not ready: " << graph.node(i).DebugString();
102 }
103 }
104 }
105 return errors::InvalidArgument(
106 "The graph couldn't be sorted in topological order.");
107 }
108 return OkStatus();
109 }
110
111 } // namespace
112
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<const NodeDef * > * topo_order)113 Status ComputeTopologicalOrder(
114 const GraphDef& graph,
115 const absl::Span<const TopologicalDependency> extra_dependencies,
116 std::vector<const NodeDef*>* topo_order) {
117 std::vector<int> ready_nodes;
118 TF_RETURN_IF_ERROR(
119 ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes));
120
121 topo_order->reserve(ready_nodes.size());
122 for (int ready_node_idx : ready_nodes) {
123 topo_order->emplace_back(&graph.node(ready_node_idx));
124 }
125
126 return OkStatus();
127 }
128
ComputeTopologicalOrder(const GraphDef & graph,std::vector<const NodeDef * > * topo_order)129 Status ComputeTopologicalOrder(const GraphDef& graph,
130 std::vector<const NodeDef*>* topo_order) {
131 return ComputeTopologicalOrder(graph, {}, topo_order);
132 }
133
ReversedTopologicalSort(GraphDef * graph)134 Status ReversedTopologicalSort(GraphDef* graph) {
135 std::vector<int> ready_nodes;
136 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
137 std::reverse(ready_nodes.begin(), ready_nodes.end());
138 PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
139 return OkStatus();
140 }
141
TopologicalSort(GraphDef * graph)142 Status TopologicalSort(GraphDef* graph) {
143 std::vector<int> ready_nodes;
144 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
145 PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
146 return OkStatus();
147 }
148
149 } // namespace grappler
150 } // namespace tensorflow
151