1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17
18 #include <unordered_set>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
40 namespace {
41
RemoveControlInput(NodeDef * node,const string & control_input_to_remove,NodeMap * node_map)42 bool RemoveControlInput(NodeDef* node, const string& control_input_to_remove,
43 NodeMap* node_map) {
44 for (int pos = node->input_size() - 1; pos >= 0; --pos) {
45 const string& input = node->input(pos);
46 if (input[0] != '^') break;
47 if (input == control_input_to_remove) {
48 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
49 node->mutable_input()->RemoveLast();
50 node_map->RemoveOutput(NodeName(input), node->name());
51 return true;
52 }
53 }
54 return false;
55 }
56
57 } // namespace
58
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60 if (!IsIdentity(node) && !IsIdentityN(node)) {
61 return true;
62 }
63
64 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65 return false;
66 }
67 if (!fetch_nodes_known_) {
68 // The output values of this node may be needed.
69 return false;
70 }
71
72 if (node.input_size() < 1) {
73 // Node lacks input, is invalid
74 return false;
75 }
76
77 const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
78 if (input == nullptr) {
79 VLOG(1) << "node = " << node.name() << " input = " << node.input(0);
80 return false;
81 }
82 // Don't remove Identity nodes corresponding to Variable reads or following
83 // Recv.
84 if (IsVariable(*input) || IsRecv(*input)) {
85 return false;
86 }
87 for (const auto& consumer : node_map_->GetOutputs(node.name())) {
88 if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
89 return false;
90 }
91 if (IsSwitch(*input)) {
92 for (const string& consumer_input : consumer->input()) {
93 if (consumer_input == AsControlDependency(node.name())) {
94 return false;
95 }
96 }
97 }
98 }
99 return true;
100 }
101
SafeToConvertToNoOp(const NodeDef & node) const102 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
103 if (HasRegularOutputs(node, *node_map_)) {
104 // The output values of this node may be needed.
105 VLOG(3) << "Not safe to convert '" << node.name()
106 << " to NoOp. Node has outputs.";
107 return false;
108 }
109 if (!fetch_nodes_known_) {
110 VLOG(3) << "Not safe to convert '" << node.name()
111 << " to NoOp. Fetches unknown.";
112 return false;
113 }
114 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
115 VLOG(3) << "Not safe to convert to NoOp: " << node.name()
116 << " is in preserve set.";
117 return false;
118 }
119 if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
120 VLOG(3) << "Not safe to convert '" << node.name()
121 << " to NoOp. Node modifies frame info.";
122 return false;
123 }
124 // Ops reading variables are marked as stateful, but are safe to remove if
125 // redundant.
126 static const absl::flat_hash_set<string>* gather_ops =
127 new absl::flat_hash_set<string>{"Gather", "GatherV2", "GatherNd",
128 "ResourceGather", "ResourceGatherNd"};
129 const bool is_variable_read =
130 IsReadVariableOp(node) || IsReadVariablesOp(node) ||
131 gather_ops->find(node.op()) != gather_ops->end();
132 if (!is_variable_read && !IsFreeOfSideEffect(node)) {
133 VLOG(3) << "Not safe to convert '" << node.name()
134 << " to NoOp. Node has side effect.";
135 return false;
136 }
137 if (node.op().rfind("Submodel", 0) == 0) {
138 return false;
139 }
140 const OpDef* op_def = nullptr;
141 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
142 if (!status.ok() || op_def->output_arg_size() == 0) {
143 return false;
144 }
145 const std::unordered_set<string> do_not_rewrite_ops{
146 "Assert", "CheckNumerics", "_Retval",
147 "_Arg", "_ParallelConcatUpdate", "TPUExecute",
148 "TPUCompile", "ControlTrigger"};
149 if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
150 return false;
151 }
152 if (!SafeToRemoveIdentity(node)) {
153 return false;
154 }
155 return true;
156 }
157
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const158 int DependencyOptimizer::NumEdgesIfBypassed(
159 const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
160 const bool is_multi_input_identity_n =
161 IsIdentityN(node) && !IsIdentityNSingleInput(node);
162 const int num_outputs = output_nodes.size();
163 const int num_inputs = node.input_size();
164
165 if (is_multi_input_identity_n) {
166 // multi-input identity_n with input/output control dependencies will likely
167 // increase number of edges after optimization.
168 int num_edges_if_bypassed(0);
169 for (const string& input_node_name : node.input()) {
170 if (IsControlInput(input_node_name)) {
171 num_edges_if_bypassed += num_outputs;
172 } else {
173 ++num_edges_if_bypassed;
174 }
175 }
176
177 for (auto consumer : output_nodes) {
178 for (int j = 0; j < consumer->input_size(); ++j) {
179 const TensorId consumer_input = ParseTensorName(consumer->input(j));
180 if (consumer_input.node() == node.name()) {
181 if (IsControlInput(consumer_input)) {
182 num_edges_if_bypassed += num_inputs;
183 } else {
184 ++num_edges_if_bypassed;
185 }
186 }
187 }
188 }
189 return num_edges_if_bypassed;
190 } else {
191 return num_inputs * num_outputs;
192 }
193 }
194
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const195 bool DependencyOptimizer::BypassingNodeIsBeneficial(
196 const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
197 const std::vector<NodeDef*>& output_nodes) const {
198 const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
199 const bool is_multi_input_identity_n =
200 IsIdentityN(node) && !IsIdentityNSingleInput(node);
201 const int num_outputs = output_nodes.size();
202 const int num_inputs = node.input_size();
203
204 if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
205 return false;
206 }
207
208 // Make sure that we don't increase the number of edges that cross
209 // device boundaries.
210 if ((num_inputs == 1 && num_outputs > 1 &&
211 input_nodes[0]->device() != node.device()) ||
212 (num_inputs > 1 && num_outputs == 1 &&
213 output_nodes[0]->device() != node.device())) {
214 return false;
215 }
216
217 // TODO(rmlarsen): Not all device crossings are equally expensive.
218 // Assign a cost to each based on device affinity and compute a
219 // cost before and after.
220 const string& node_dev = node.device();
221 int num_cross_in = 0;
222 for (NodeDef* input_node : input_nodes) {
223 num_cross_in += static_cast<int>(input_node->device() != node_dev);
224 }
225 int num_cross_out = 0;
226 for (NodeDef* output_node : output_nodes) {
227 num_cross_out += static_cast<int>(output_node->device() != node_dev);
228 }
229
230 // Make sure we do not increase the number of device crossings.
231 const int num_cross_before = num_cross_in + num_cross_out;
232 int num_cross_after = 0;
233 for (NodeDef* input_node : input_nodes) {
234 for (NodeDef* output_node : output_nodes) {
235 num_cross_after +=
236 static_cast<int>(input_node->device() != output_node->device());
237 }
238 }
239 if (num_cross_after > num_cross_before) {
240 return false;
241 }
242
243 if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
244 num_cross_out > 0 && num_cross_after > 0) {
245 // This identity node follows a device crossing, so it might be
246 // following a _Recv node after partitioning. Do not remove such nodes,
247 // unless they only have consumers on the same device as themselves.
248 return false;
249 }
250
251 return true;
252 }
253
OptimizeNode(int node_idx,SetVector<int> * nodes_to_simplify,std::set<int> * nodes_to_delete)254 void DependencyOptimizer::OptimizeNode(int node_idx,
255 SetVector<int>* nodes_to_simplify,
256 std::set<int>* nodes_to_delete) {
257 NodeDef* node = optimized_graph_->mutable_node(node_idx);
258 const bool is_noop = IsNoOp(*node);
259 const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
260 const bool is_multi_input_identity =
261 IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
262 const string node_name = node->name();
263 // Constant nodes with no input control dependency are always executed early,
264 // so we can prune all their output control dependencies.
265 if (IsConstant(*node) && node->input_size() == 0) {
266 const auto output_nodes = node_map_->GetOutputs(node_name);
267 for (NodeDef* fanout : output_nodes) {
268 bool optimize_fanout = false;
269 bool data_connection = false;
270 for (int i = fanout->input_size() - 1; i >= 0; --i) {
271 const TensorId input_tensor = ParseTensorName(fanout->input(i));
272 if (input_tensor.node() == node_name) {
273 if (input_tensor.index() < 0) {
274 fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
275 fanout->mutable_input()->RemoveLast();
276 optimize_fanout = true;
277 } else {
278 data_connection = true;
279 }
280 }
281 }
282 if (optimize_fanout) {
283 nodes_to_simplify->PushBack(node_to_idx_[fanout]);
284 if (!data_connection) {
285 node_map_->RemoveOutput(node_name, fanout->name());
286 }
287 }
288 }
289 if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
290 nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
291 // Mark the node for deletion.
292 nodes_to_delete->insert(node_to_idx_[node]);
293 }
294 return;
295 }
296
297 // Change ops that only have control dependencies as outputs to NoOps.
298 if (!is_noop && SafeToConvertToNoOp(*node)) {
299 VLOG(2) << "***** Replacing " << node_name << " (" << node->op()
300 << ") with NoOp.";
301 // The outputs of this node are not consumed. Replace its inputs with
302 // control dependencies and replace the op itself with the NoOp op.
303 std::unordered_set<string> ctrl_inputs;
304 int pos = 0;
305 while (pos < node->input_size()) {
306 const string old_input = node->input(pos);
307 if (IsControlInput(old_input)) {
308 if (!ctrl_inputs.insert(old_input).second) {
309 // We found a duplicate control input. Remove it.
310 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
311 node->mutable_input()->RemoveLast();
312 } else {
313 ++pos;
314 }
315 continue;
316 }
317 // Replace a normal input with a control input.
318 const string ctrl_input = ConstantFolding::AddControlDependency(
319 old_input, optimized_graph_, node_map_.get());
320 ctrl_inputs.insert(ctrl_input);
321 node->set_input(pos, ctrl_input);
322 node_map_->UpdateInput(node_name, old_input, ctrl_input);
323 const NodeDef* old_input_node = node_map_->GetNode(old_input);
324 nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
325 ++pos;
326 }
327 node->set_op("NoOp");
328 EraseRegularNodeAttributes(node);
329 DedupControlInputs(node);
330 nodes_to_simplify->PushBack(node_to_idx_[node]);
331 return;
332 }
333
334 // Remove NoOp nodes if the product of their fan-in and fan-out is less than
335 // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
336 // take the following form:
337 //
338 // Case a)
339 // x --^> +------+ x --^> +---+
340 // y --^> | NoOp | --^> a ==> y --^> | a |
341 // ... | | ... | |
342 // z --^> +------+ z --^> +---+
343 //
344 // Case b)
345 // +------+ --^> a +---+ --^> a
346 // x --^> | NoOp | --^> b ==> | x | --^> b
347 // | | ... | | ...
348 // +------+ --^> c +---+ --^> c
349 // Case c)
350 // +------+ x ---^> a
351 // x --^> | NoOp | --^> a ==> \/
352 // y --^> | | --^> b /\
353 // +------+ y ---^> b
354 //
355 // We only apply this optimization if we don't increase the number of control
356 // edges across device boundaries, e.g. in cases a) and b) if NoOp and
357 // a and x, respectively, are on the same device. Control edges across device
358 // boundaries require inter-device communication (Send/Recv pairs to be
359 // inserted in the graph), which is very costly.
360 //
361 // We also remove identity nodes, subject to the same constraints on number of
362 // resulting control edges and device boundary crossings:
363 //
364 // Case a)
365 // +----------+ ---> a +---+ ---> a
366 // x --> | Identity | --^> b ==> | x | --^> b
367 // | | ... | | ...
368 // +----------+ --^> c +---+ --^> c
369 //
370 // Case b)
371 // x ---> +----------+ ---> a x ---> +---+
372 // y --^> | Identity | ==> y --^> | a |
373 // ... | | ... | |
374 // z --^> +----------+ z --^> +---+
375 //
376 // Case c)
377 // +----------+ x ---> +---+
378 // x ---> | Identity | ---> a ==> \--^> | a |
379 // y --^> | | --^> b /\ +---+
380 // +----------+ y --^> b
381
382 if (is_noop || ((is_identity || is_multi_input_identity) &&
383 SafeToRemoveIdentity(*node))) {
384 const int num_inputs = node->input_size();
385 std::vector<NodeDef*> input_nodes;
386 for (int i = 0; i < num_inputs; ++i) {
387 NodeDef* input_node = node_map_->GetNode(node->input(i));
388 if (input_node == nullptr) {
389 LOG(ERROR) << "Invalid input " << node->input(i);
390 return;
391 }
392 input_nodes.push_back(input_node);
393 }
394 const auto& output_node_set = node_map_->GetOutputs(node_name);
395 const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
396 output_node_set.end());
397
398 if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
399 return;
400 }
401
402 VLOG(2) << "***** Rerouting input around\n" << node->DebugString();
403 // Now remove the node and re-wire its inputs to its outputs.
404 for (auto consumer : output_nodes) {
405 bool updated_consumer = false;
406 VLOG(2) << "consumer before:\n" << consumer->DebugString();
407 // Remove dependency on node from consumer.
408 for (int i = 0; i < num_inputs; ++i) {
409 const NodeDef* input = input_nodes[i];
410 // Forward dependency from input to consumer if it doesn't already
411 // depend on it.
412 if ((is_identity && i == 0) ||
413 (is_multi_input_identity && !IsControlInput(node->input(i)))) {
414 // Replace regular input from Identity node.
415 string new_input;
416 const string& input_to_forward = node->input(i);
417 CHECK(!IsControlInput(input_to_forward));
418 for (int j = 0; j < consumer->input_size(); ++j) {
419 const TensorId old_input = ParseTensorName(consumer->input(j));
420 if (old_input.node() == node_name) {
421 if (old_input.index() == i) {
422 // Regular input
423 new_input = input_to_forward;
424 node_map_->UpdateInput(consumer->name(),
425 string(old_input.node()), new_input);
426 consumer->set_input(j, new_input);
427 } else if (old_input.index() == -1) {
428 // Control dependency
429 new_input = AsControlDependency(NodeName(input_to_forward));
430 node_map_->UpdateInput(consumer->name(),
431 string(old_input.node()), new_input);
432 consumer->set_input(j, new_input);
433 }
434 }
435 }
436 updated_consumer = true;
437 } else {
438 // Forward dependency from input to consumer if it doesn't already
439 // depend on it.
440 if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
441 consumer->add_input(AsControlDependency(input->name()));
442 node_map_->AddOutput(input->name(), consumer->name());
443 nodes_to_simplify->PushBack(node_to_idx_[input]);
444 updated_consumer = true;
445 }
446 }
447 }
448 updated_consumer |= RemoveControlInput(
449 consumer, AsControlDependency(node_name), node_map_.get());
450 if (updated_consumer) {
451 nodes_to_simplify->PushBack(node_to_idx_[consumer]);
452 }
453 VLOG(2) << "consumer after:\n" << consumer->DebugString();
454 }
455 node_map_->RemoveOutputs(node_name);
456 if (fetch_nodes_known_ &&
457 nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
458 // Mark the node for deletion.
459 nodes_to_delete->insert(node_idx);
460
461 // Disconnect the node from its inputs to enable further optimizations.
462 node_map_->RemoveInputs(node_name);
463 node->clear_input();
464 }
465 }
466 }
467
CleanControlInputs()468 void DependencyOptimizer::CleanControlInputs() {
469 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
470 DedupControlInputs(optimized_graph_->mutable_node(i));
471 }
472 }
473
OptimizeDependencies()474 Status DependencyOptimizer::OptimizeDependencies() {
475 SetVector<int> nodes_to_simplify;
476 std::set<int> nodes_to_delete;
477 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
478 const NodeDef& node = optimized_graph_->node(i);
479 if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
480 IsConstant(node) || SafeToConvertToNoOp(node)) {
481 nodes_to_simplify.PushBack(i);
482 }
483 }
484 while (!nodes_to_simplify.Empty()) {
485 int node_to_simplify = nodes_to_simplify.PopBack();
486 // Discard nodes that were marked for deletion already.
487 while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
488 node_to_simplify = nodes_to_simplify.PopBack();
489 }
490 OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
491 }
492
493 if (fetch_nodes_known_) {
494 VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
495 << optimized_graph_->node_size() << " nodes.";
496 EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
497 node_map_.reset(new NodeMap(optimized_graph_));
498 BuildNodeToIdx();
499 }
500 return OkStatus();
501 }
502
503 namespace {
504
// Coarse per-node marker maintained by LongestPathsLowerBounds(): ZERO until
// the node is reached, ONE after it has been reached once, TWO_OR_GREATER
// once it has been reached again.
enum DistanceFromSource : uint8_t { ZERO = 0, ONE = 1, TWO_OR_GREATER = 2 };

// Walks the fanout graph `outputs` (outputs[i] lists the fanout indices of
// node i) starting at `source` and updates `longest_distance` for every
// reached node whose index lies in the closed interval
// [target_range.first, target_range.second].
//
// Each time such a node is reached its marker is bumped ZERO -> ONE ->
// TWO_OR_GREATER and the node is scheduled for expansion; nodes already at
// TWO_OR_GREATER are neither updated nor expanded again. Entries of
// `longest_distance` outside the target range are never written. The caller
// is responsible for resetting the entries it cares about beforehand.
void LongestPathsLowerBounds(
    int source, const std::pair<int, int>& target_range,
    const std::vector<std::vector<int>>& outputs,
    std::vector<DistanceFromSource>* longest_distance) {
  // Nodes are expanded in last-in-first-out order, so a plain vector used as
  // a stack is all that is needed.
  std::vector<int> pending;
  pending.push_back(source);
  while (!pending.empty()) {
    const int node = pending.back();
    pending.pop_back();
    for (const int fanout : outputs[node]) {
      // 1) Only nodes in the target range can be on paths from source to one
      //    of its control outputs.
      if (fanout < target_range.first || fanout > target_range.second) {
        continue;
      }
      // 2) Since we only need a lower bound on the longest distance, we can
      //    skip nodes for which we have already proven there is a path of
      //    length > 1 from the source.
      DistanceFromSource& distance = (*longest_distance)[fanout];
      if (distance == TWO_OR_GREATER) {
        continue;
      }
      distance = (distance == ZERO) ? ONE : TWO_OR_GREATER;
      pending.push_back(fanout);
    }
  }
}
531
532 } // namespace
533
TransitiveReduction()534 Status DependencyOptimizer::TransitiveReduction() {
535 // PRECONDITION: optimized_graph_ must be sorted topologically.
536 const int num_nodes = optimized_graph_->node_size();
537 // Set up a compressed version of the graph to save a constant factor in the
538 // expensive algorithm below. Also cache the set of control outputs and the
539 // highest index of a target of any control output from each node.
540 int num_controls = 0;
541 std::vector<std::vector<int>> outputs(num_nodes);
542 std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
543 num_nodes);
544 // target_range[i] contains the range of node indices for which to compute
545 // longest paths starting from node i.
546 std::vector<std::pair<int, int>> target_range(num_nodes, {num_nodes, -1});
547 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
548 const NodeDef& node = optimized_graph_->node(node_idx);
549 if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
550 // Ignore function nodes and nodes that modify frame info.
551 continue;
552 }
553 for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
554 const string& input = node.input(input_slot);
555 const NodeDef* input_node = node_map_->GetNode(input);
556 if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
557 // Ignore edges from nodes that modify frame info and from Merge nodes,
558 // because we cannot know which of it's input paths executes.
559 continue;
560 }
561 const int input_node_idx = node_to_idx_[input_node];
562 outputs[input_node_idx].push_back(node_idx);
563 target_range[input_node_idx].first =
564 std::min(target_range[input_node_idx].first, node_idx);
565 if (IsControlInput(input)) {
566 ++num_controls;
567 control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
568 target_range[input_node_idx].second =
569 std::max(target_range[input_node_idx].second, node_idx);
570 }
571 }
572 }
573
574 // Run the longest path in DAG algorithm for each source node that has control
575 // outputs. If, for any target node of a control output, there exists a path
576 // of length > 1, we can drop that control dependency.
577 int num_controls_removed = 0;
578 std::vector<DistanceFromSource> longest_distance(num_nodes);
579 // Map from target_index -> set of (input_slot, source_index), representing
580 // the control edges to remove. We sort them in reverse order by input slot,
581 // such that when we swap them out so we don't clobber the
582 // node(target).input() repeated field.
583 typedef std::pair<int, int> InputSlotAndSource;
584 absl::flat_hash_map<
585 int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
586 control_edges_to_remove;
587 for (int source = 0; source < num_nodes; ++source) {
588 if (target_range[source].first >= target_range[source].second ||
589 target_range[source].second <= source) {
590 continue;
591 }
592 // Compute the set of nodes in the transitive fanout of source with
593 // topological sort index in [target_range.first : target_range.second]]
594 // to which there exists a path of length 2 or more from source.
595 std::fill(longest_distance.begin() + target_range[source].first,
596 longest_distance.begin() + target_range[source].second + 1, ZERO);
597 LongestPathsLowerBounds(source, target_range[source], outputs,
598 &longest_distance);
599
600 // If the longest path from source to target of a control dependency is
601 // longer than 1, there exists an alternate path, and we can eliminate the
602 // redundant direct control dependency.
603 for (const auto& control_output : control_outputs[source]) {
604 const int target = control_output.first;
605 if (longest_distance[target] == TWO_OR_GREATER) {
606 const int input_slot = control_output.second;
607 control_edges_to_remove[target].emplace(input_slot, source);
608 }
609 }
610 }
611 for (const auto& it : control_edges_to_remove) {
612 const int target = it.first;
613 NodeDef* target_node = optimized_graph_->mutable_node(target);
614 for (const InputSlotAndSource& slot_and_source : it.second) {
615 const int input_slot = slot_and_source.first;
616 const int source = slot_and_source.second;
617 const NodeDef& source_node = optimized_graph_->node(source);
618 CHECK_LT(input_slot, target_node->input_size());
619 target_node->mutable_input()->SwapElements(input_slot,
620 target_node->input_size() - 1);
621 node_map_->RemoveOutput(source_node.name(), target_node->name());
622 target_node->mutable_input()->RemoveLast();
623 ++num_controls_removed;
624 }
625 }
626 VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
627 << " control dependencies";
628 return OkStatus();
629 }
630
BuildNodeToIdx()631 void DependencyOptimizer::BuildNodeToIdx() {
632 // Set up &node -> index map.
633 node_to_idx_.clear();
634 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
635 const NodeDef& node = optimized_graph_->node(i);
636 node_to_idx_[&node] = i;
637 }
638 }
639
640 // Suppose there are cross-device control inputs to node C from multiple nodes
641 // that are located on another device, e.g., we have control edges:
642 // A->C, B->C
643 // where A and B are on device X and C is on device Y.
644 // We can reduce cross-device communication by introducing an intermediate
645 // NoOp node C' on device X and rewriting the control edges to:
646 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges(bool host_granularity)647 void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
648 VLOG(1)
649 << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
650 << host_granularity;
651 const int num_nodes = optimized_graph_->node_size();
652 for (int i = 0; i < num_nodes; ++i) {
653 NodeDef* node = optimized_graph_->mutable_node(i);
654 if (node->device().empty()) continue;
655 string rest, node_device = node->device();
656 if (host_granularity) {
657 DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
658 }
659
660 // Creates new noop nodes for devices on which multiple control inputs are
661 // located.
662
663 // Map keyed by device name to the newly introduced Noop node for that
664 // device. A nullptr value means that we have only seen a single node on
665 // that device.
666 std::map<string, NodeDef*> noops;
667 int num_noops = 0;
668 for (int j = 0; j < node->input_size(); ++j) {
669 if (IsControlInput(node->input(j))) {
670 const NodeDef* input = node_map_->GetNode(node->input(j));
671 if (input == nullptr || input->device().empty()) continue;
672 string input_device = input->device();
673 if (host_granularity) {
674 DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
675 &rest);
676 }
677 if (input_device != node_device) {
678 VLOG(2) << "Cross-device " << node->name() << " " << input->device()
679 << " -> " << node->device();
680 auto emplace_result = noops.emplace(input_device, nullptr);
681 if (!emplace_result.second &&
682 emplace_result.first->second == nullptr) {
683 VLOG(2) << "Duplicate input device from " << node->name();
684 // This is the second cross-device control input from the same
685 // device. Creates an intermediate noop node on that device.
686 string group_name;
687 NodeDef* noop;
688 // Creates a fresh node name; there may be conflicting names from
689 // a previous iteration of the optimizer.
690 do {
691 group_name = AddPrefixToNodeName(
692 node->name(),
693 strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
694 noop = node_map_->GetNode(group_name);
695 ++num_noops;
696 } while (noop != nullptr);
697 noop = optimized_graph_->add_node();
698 noop->set_name(group_name);
699 noop->set_device(input->device());
700 noop->set_op("NoOp");
701 node_map_->AddNode(noop->name(), noop);
702 emplace_result.first->second = noop;
703 VLOG(1) << "GroupCrossDeviceControlEdges: Added "
704 << SummarizeNodeDef(*noop);
705 }
706 }
707 }
708 }
709
710 // Reroute existing control edges to go via the newly introduced NoOp nodes.
711 int pos = 0;
712 while (pos < node->input_size()) {
713 const string& input_name = node->input(pos);
714 if (IsControlInput(input_name)) {
715 NodeDef* input = node_map_->GetNode(input_name);
716 if (input == nullptr) {
717 ++pos;
718 } else {
719 string input_device = input->device();
720 if (host_granularity) {
721 DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
722 &rest);
723 }
724 auto it = noops.find(input_device);
725 if (it == noops.end() || it->second == nullptr) {
726 ++pos;
727 } else {
728 VLOG(2) << "Rewriting input from " << input_name;
729 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
730 node->mutable_input()->RemoveLast();
731 it->second->add_input(AsControlDependency(*input));
732 node_map_->UpdateOutput(input_name, node->name(),
733 it->second->name());
734 }
735 }
736 } else {
737 ++pos;
738 }
739 }
740 for (const auto& entry : noops) {
741 if (entry.second) {
742 node->add_input(AsControlDependency(*entry.second));
743 node_map_->AddOutput(entry.second->name(), node->name());
744 }
745 }
746 }
747 }
748
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)749 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
750 GraphDef* optimized_graph) {
751 optimized_graph_ = optimized_graph;
752 *optimized_graph_ = item.graph;
753 nodes_to_preserve_ = item.NodesToPreserve();
754 fetch_nodes_known_ = !item.fetch.empty();
755 CleanControlInputs();
756
757 const int num_iterations = 2;
758 for (int iteration = 0; iteration < num_iterations; ++iteration) {
759 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
760 Status topo_sort_status;
761 // Perform topological sort to prepare the graph for transitive reduction.
762 topo_sort_status = TopologicalSort(optimized_graph_);
763 // Set up index-based graph datastructures to speed up analysis steps below.
764 node_map_.reset(new NodeMap(optimized_graph_));
765 BuildNodeToIdx();
766
767 if (topo_sort_status.ok()) {
768 // Remove redundant control dependencies.
769 TF_RETURN_IF_ERROR(TransitiveReduction());
770 } else {
771 LOG(ERROR) << "Iteration = " << iteration
772 << ", topological sort failed with message: "
773 << topo_sort_status.error_message();
774 }
775 // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
776 // nodes.
777 TF_RETURN_IF_ERROR(OptimizeDependencies());
778
779 // Dedup control inputs.
780 CleanControlInputs();
781
782 // Merge multiple control edges from the same device.
783 GroupCrossDeviceControlEdges(/*host_granularity=*/false);
784
785 // Merge control edges from the same host to reduce RPC traffic.
786 GroupCrossDeviceControlEdges(/*host_granularity=*/true);
787 }
788
789 return OkStatus();
790 }
791
792 } // end namespace grappler
793 } // end namespace tensorflow
794