xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/dependency_optimizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 namespace grappler {
39 
40 namespace {
41 
RemoveControlInput(NodeDef * node,const string & control_input_to_remove,NodeMap * node_map)42 bool RemoveControlInput(NodeDef* node, const string& control_input_to_remove,
43                         NodeMap* node_map) {
44   for (int pos = node->input_size() - 1; pos >= 0; --pos) {
45     const string& input = node->input(pos);
46     if (input[0] != '^') break;
47     if (input == control_input_to_remove) {
48       node->mutable_input()->SwapElements(pos, node->input_size() - 1);
49       node->mutable_input()->RemoveLast();
50       node_map->RemoveOutput(NodeName(input), node->name());
51       return true;
52     }
53   }
54   return false;
55 }
56 
57 }  // namespace
58 
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60   if (!IsIdentity(node) && !IsIdentityN(node)) {
61     return true;
62   }
63 
64   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65     return false;
66   }
67   if (!fetch_nodes_known_) {
68     // The output values of this node may be needed.
69     return false;
70   }
71 
72   if (node.input_size() < 1) {
73     // Node lacks input, is invalid
74     return false;
75   }
76 
77   const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
78   if (input == nullptr) {
79     VLOG(1) << "node = " << node.name() << " input = " << node.input(0);
80     return false;
81   }
82   // Don't remove Identity nodes corresponding to Variable reads or following
83   // Recv.
84   if (IsVariable(*input) || IsRecv(*input)) {
85     return false;
86   }
87   for (const auto& consumer : node_map_->GetOutputs(node.name())) {
88     if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
89       return false;
90     }
91     if (IsSwitch(*input)) {
92       for (const string& consumer_input : consumer->input()) {
93         if (consumer_input == AsControlDependency(node.name())) {
94           return false;
95         }
96       }
97     }
98   }
99   return true;
100 }
101 
SafeToConvertToNoOp(const NodeDef & node) const102 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
103   if (HasRegularOutputs(node, *node_map_)) {
104     // The output values of this node may be needed.
105     VLOG(3) << "Not safe to convert '" << node.name()
106             << " to NoOp. Node has outputs.";
107     return false;
108   }
109   if (!fetch_nodes_known_) {
110     VLOG(3) << "Not safe to convert '" << node.name()
111             << " to NoOp. Fetches unknown.";
112     return false;
113   }
114   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
115     VLOG(3) << "Not safe to convert to NoOp: " << node.name()
116             << " is in preserve set.";
117     return false;
118   }
119   if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
120     VLOG(3) << "Not safe to convert '" << node.name()
121             << " to NoOp. Node modifies frame info.";
122     return false;
123   }
124   // Ops reading variables are marked as stateful, but are safe to remove if
125   // redundant.
126   static const absl::flat_hash_set<string>* gather_ops =
127       new absl::flat_hash_set<string>{"Gather", "GatherV2", "GatherNd",
128                                       "ResourceGather", "ResourceGatherNd"};
129   const bool is_variable_read =
130       IsReadVariableOp(node) || IsReadVariablesOp(node) ||
131       gather_ops->find(node.op()) != gather_ops->end();
132   if (!is_variable_read && !IsFreeOfSideEffect(node)) {
133     VLOG(3) << "Not safe to convert '" << node.name()
134             << " to NoOp. Node has side effect.";
135     return false;
136   }
137   if (node.op().rfind("Submodel", 0) == 0) {
138     return false;
139   }
140   const OpDef* op_def = nullptr;
141   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
142   if (!status.ok() || op_def->output_arg_size() == 0) {
143     return false;
144   }
145   const std::unordered_set<string> do_not_rewrite_ops{
146       "Assert",     "CheckNumerics",         "_Retval",
147       "_Arg",       "_ParallelConcatUpdate", "TPUExecute",
148       "TPUCompile", "ControlTrigger"};
149   if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
150     return false;
151   }
152   if (!SafeToRemoveIdentity(node)) {
153     return false;
154   }
155   return true;
156 }
157 
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const158 int DependencyOptimizer::NumEdgesIfBypassed(
159     const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
160   const bool is_multi_input_identity_n =
161       IsIdentityN(node) && !IsIdentityNSingleInput(node);
162   const int num_outputs = output_nodes.size();
163   const int num_inputs = node.input_size();
164 
165   if (is_multi_input_identity_n) {
166     // multi-input identity_n with input/output control dependencies will likely
167     // increase number of edges after optimization.
168     int num_edges_if_bypassed(0);
169     for (const string& input_node_name : node.input()) {
170       if (IsControlInput(input_node_name)) {
171         num_edges_if_bypassed += num_outputs;
172       } else {
173         ++num_edges_if_bypassed;
174       }
175     }
176 
177     for (auto consumer : output_nodes) {
178       for (int j = 0; j < consumer->input_size(); ++j) {
179         const TensorId consumer_input = ParseTensorName(consumer->input(j));
180         if (consumer_input.node() == node.name()) {
181           if (IsControlInput(consumer_input)) {
182             num_edges_if_bypassed += num_inputs;
183           } else {
184             ++num_edges_if_bypassed;
185           }
186         }
187       }
188     }
189     return num_edges_if_bypassed;
190   } else {
191     return num_inputs * num_outputs;
192   }
193 }
194 
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const195 bool DependencyOptimizer::BypassingNodeIsBeneficial(
196     const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
197     const std::vector<NodeDef*>& output_nodes) const {
198   const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
199   const bool is_multi_input_identity_n =
200       IsIdentityN(node) && !IsIdentityNSingleInput(node);
201   const int num_outputs = output_nodes.size();
202   const int num_inputs = node.input_size();
203 
204   if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
205     return false;
206   }
207 
208   // Make sure that we don't increase the number of edges that cross
209   // device boundaries.
210   if ((num_inputs == 1 && num_outputs > 1 &&
211        input_nodes[0]->device() != node.device()) ||
212       (num_inputs > 1 && num_outputs == 1 &&
213        output_nodes[0]->device() != node.device())) {
214     return false;
215   }
216 
217   // TODO(rmlarsen): Not all device crossings are equally expensive.
218   // Assign a cost to each based on device affinity and compute a
219   // cost before and after.
220   const string& node_dev = node.device();
221   int num_cross_in = 0;
222   for (NodeDef* input_node : input_nodes) {
223     num_cross_in += static_cast<int>(input_node->device() != node_dev);
224   }
225   int num_cross_out = 0;
226   for (NodeDef* output_node : output_nodes) {
227     num_cross_out += static_cast<int>(output_node->device() != node_dev);
228   }
229 
230   // Make sure we do not increase the number of device crossings.
231   const int num_cross_before = num_cross_in + num_cross_out;
232   int num_cross_after = 0;
233   for (NodeDef* input_node : input_nodes) {
234     for (NodeDef* output_node : output_nodes) {
235       num_cross_after +=
236           static_cast<int>(input_node->device() != output_node->device());
237     }
238   }
239   if (num_cross_after > num_cross_before) {
240     return false;
241   }
242 
243   if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
244       num_cross_out > 0 && num_cross_after > 0) {
245     // This identity node follows a device crossing, so it might be
246     // following a _Recv node after partitioning. Do not remove such nodes,
247     // unless they only have consumers on the same device as themselves.
248     return false;
249   }
250 
251   return true;
252 }
253 
OptimizeNode(int node_idx,SetVector<int> * nodes_to_simplify,std::set<int> * nodes_to_delete)254 void DependencyOptimizer::OptimizeNode(int node_idx,
255                                        SetVector<int>* nodes_to_simplify,
256                                        std::set<int>* nodes_to_delete) {
257   NodeDef* node = optimized_graph_->mutable_node(node_idx);
258   const bool is_noop = IsNoOp(*node);
259   const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
260   const bool is_multi_input_identity =
261       IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
262   const string node_name = node->name();
263   // Constant nodes with no input control dependency are always executed early,
264   // so we can prune all their output control dependencies.
265   if (IsConstant(*node) && node->input_size() == 0) {
266     const auto output_nodes = node_map_->GetOutputs(node_name);
267     for (NodeDef* fanout : output_nodes) {
268       bool optimize_fanout = false;
269       bool data_connection = false;
270       for (int i = fanout->input_size() - 1; i >= 0; --i) {
271         const TensorId input_tensor = ParseTensorName(fanout->input(i));
272         if (input_tensor.node() == node_name) {
273           if (input_tensor.index() < 0) {
274             fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
275             fanout->mutable_input()->RemoveLast();
276             optimize_fanout = true;
277           } else {
278             data_connection = true;
279           }
280         }
281       }
282       if (optimize_fanout) {
283         nodes_to_simplify->PushBack(node_to_idx_[fanout]);
284         if (!data_connection) {
285           node_map_->RemoveOutput(node_name, fanout->name());
286         }
287       }
288     }
289     if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
290         nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
291       // Mark the node for deletion.
292       nodes_to_delete->insert(node_to_idx_[node]);
293     }
294     return;
295   }
296 
297   // Change ops that only have control dependencies as outputs to NoOps.
298   if (!is_noop && SafeToConvertToNoOp(*node)) {
299     VLOG(2) << "***** Replacing  " << node_name << " (" << node->op()
300             << ") with NoOp.";
301     // The outputs of this node are not consumed. Replace its inputs with
302     // control dependencies and replace the op itself with the NoOp op.
303     std::unordered_set<string> ctrl_inputs;
304     int pos = 0;
305     while (pos < node->input_size()) {
306       const string old_input = node->input(pos);
307       if (IsControlInput(old_input)) {
308         if (!ctrl_inputs.insert(old_input).second) {
309           // We found a duplicate control input. Remove it.
310           node->mutable_input()->SwapElements(pos, node->input_size() - 1);
311           node->mutable_input()->RemoveLast();
312         } else {
313           ++pos;
314         }
315         continue;
316       }
317       // Replace a normal input with a control input.
318       const string ctrl_input = ConstantFolding::AddControlDependency(
319           old_input, optimized_graph_, node_map_.get());
320       ctrl_inputs.insert(ctrl_input);
321       node->set_input(pos, ctrl_input);
322       node_map_->UpdateInput(node_name, old_input, ctrl_input);
323       const NodeDef* old_input_node = node_map_->GetNode(old_input);
324       nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
325       ++pos;
326     }
327     node->set_op("NoOp");
328     EraseRegularNodeAttributes(node);
329     DedupControlInputs(node);
330     nodes_to_simplify->PushBack(node_to_idx_[node]);
331     return;
332   }
333 
334   // Remove NoOp nodes if the product of their fan-in and fan-out is less than
335   // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
336   // take the following form:
337   //
338   // Case a)
339   //    x --^> +------+                x --^> +---+
340   //    y --^> | NoOp | --^> a   ==>   y --^> | a |
341   //    ...    |      |                  ...  |   |
342   //    z --^> +------+                z --^> +---+
343   //
344   // Case b)
345   //           +------+ --^> a         +---+ --^> a
346   //    x --^> | NoOp | --^> b  ==>    | x | --^> b
347   //           |      | ...            |   | ...
348   //           +------+ --^> c         +---+ --^> c
349   // Case c)
350   //           +------+                x ---^> a
351   //    x --^> | NoOp | --^> a  ==>      \/
352   //    y --^> |      | --^> b           /\
353   //           +------+                y ---^> b
354   //
355   // We only apply this optimization if we don't increase the number of control
356   // edges across device boundaries, e.g. in cases a) and b) if NoOp and
357   // a and x, respectively, are on the same device. Control edges across device
358   // boundaries require inter-device communication (Send/Recv pairs to be
359   // inserted in the graph), which is very costly.
360   //
361   // We also remove identity nodes, subject to the same constraints on number of
362   // resulting control edges and device boundary crossings:
363   //
364   // Case a)
365   //          +----------+ ---> a       +---+ ---> a
366   //    x --> | Identity | --^> b  ==>  | x | --^> b
367   //          |          | ...          |   | ...
368   //          +----------+ --^> c       +---+ --^> c
369   //
370   // Case b)
371   //    x ---> +----------+ ---> a      x ---> +---+
372   //    y --^> | Identity |        ==>  y --^> | a |
373   //    ...    |          |               ...  |   |
374   //    z --^> +----------+             z --^> +---+
375   //
376   // Case c)
377   //           +----------+             x ---> +---+
378   //    x ---> | Identity | ---> a ==>   \--^> | a |
379   //    y --^> |          | --^> b       /\    +---+
380   //           +----------+             y --^> b
381 
382   if (is_noop || ((is_identity || is_multi_input_identity) &&
383                   SafeToRemoveIdentity(*node))) {
384     const int num_inputs = node->input_size();
385     std::vector<NodeDef*> input_nodes;
386     for (int i = 0; i < num_inputs; ++i) {
387       NodeDef* input_node = node_map_->GetNode(node->input(i));
388       if (input_node == nullptr) {
389         LOG(ERROR) << "Invalid input " << node->input(i);
390         return;
391       }
392       input_nodes.push_back(input_node);
393     }
394     const auto& output_node_set = node_map_->GetOutputs(node_name);
395     const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
396                                              output_node_set.end());
397 
398     if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
399       return;
400     }
401 
402     VLOG(2) << "***** Rerouting input around\n" << node->DebugString();
403     // Now remove the node and re-wire its inputs to its outputs.
404     for (auto consumer : output_nodes) {
405       bool updated_consumer = false;
406       VLOG(2) << "consumer before:\n" << consumer->DebugString();
407       // Remove dependency on node from consumer.
408       for (int i = 0; i < num_inputs; ++i) {
409         const NodeDef* input = input_nodes[i];
410         // Forward dependency from input to consumer if it doesn't already
411         // depend on it.
412         if ((is_identity && i == 0) ||
413             (is_multi_input_identity && !IsControlInput(node->input(i)))) {
414           // Replace regular input from Identity node.
415           string new_input;
416           const string& input_to_forward = node->input(i);
417           CHECK(!IsControlInput(input_to_forward));
418           for (int j = 0; j < consumer->input_size(); ++j) {
419             const TensorId old_input = ParseTensorName(consumer->input(j));
420             if (old_input.node() == node_name) {
421               if (old_input.index() == i) {
422                 // Regular input
423                 new_input = input_to_forward;
424                 node_map_->UpdateInput(consumer->name(),
425                                        string(old_input.node()), new_input);
426                 consumer->set_input(j, new_input);
427               } else if (old_input.index() == -1) {
428                 // Control dependency
429                 new_input = AsControlDependency(NodeName(input_to_forward));
430                 node_map_->UpdateInput(consumer->name(),
431                                        string(old_input.node()), new_input);
432                 consumer->set_input(j, new_input);
433               }
434             }
435           }
436           updated_consumer = true;
437         } else {
438           // Forward dependency from input to consumer if it doesn't already
439           // depend on it.
440           if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
441             consumer->add_input(AsControlDependency(input->name()));
442             node_map_->AddOutput(input->name(), consumer->name());
443             nodes_to_simplify->PushBack(node_to_idx_[input]);
444             updated_consumer = true;
445           }
446         }
447       }
448       updated_consumer |= RemoveControlInput(
449           consumer, AsControlDependency(node_name), node_map_.get());
450       if (updated_consumer) {
451         nodes_to_simplify->PushBack(node_to_idx_[consumer]);
452       }
453       VLOG(2) << "consumer after:\n" << consumer->DebugString();
454     }
455     node_map_->RemoveOutputs(node_name);
456     if (fetch_nodes_known_ &&
457         nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
458       // Mark the node for deletion.
459       nodes_to_delete->insert(node_idx);
460 
461       // Disconnect the node from its inputs to enable further optimizations.
462       node_map_->RemoveInputs(node_name);
463       node->clear_input();
464     }
465   }
466 }
467 
CleanControlInputs()468 void DependencyOptimizer::CleanControlInputs() {
469   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
470     DedupControlInputs(optimized_graph_->mutable_node(i));
471   }
472 }
473 
OptimizeDependencies()474 Status DependencyOptimizer::OptimizeDependencies() {
475   SetVector<int> nodes_to_simplify;
476   std::set<int> nodes_to_delete;
477   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
478     const NodeDef& node = optimized_graph_->node(i);
479     if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
480         IsConstant(node) || SafeToConvertToNoOp(node)) {
481       nodes_to_simplify.PushBack(i);
482     }
483   }
484   while (!nodes_to_simplify.Empty()) {
485     int node_to_simplify = nodes_to_simplify.PopBack();
486     // Discard nodes that were marked for deletion already.
487     while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
488       node_to_simplify = nodes_to_simplify.PopBack();
489     }
490     OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
491   }
492 
493   if (fetch_nodes_known_) {
494     VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
495             << optimized_graph_->node_size() << " nodes.";
496     EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
497     node_map_.reset(new NodeMap(optimized_graph_));
498     BuildNodeToIdx();
499   }
500   return OkStatus();
501 }
502 
503 namespace {
504 
505 enum DistanceFromSource : uint8 { ZERO = 0, ONE = 1, TWO_OR_GREATER = 2 };
506 
LongestPathsLowerBounds(int source,const std::pair<int,int> & target_range,const std::vector<std::vector<int>> & outputs,std::vector<DistanceFromSource> * longest_distance)507 void LongestPathsLowerBounds(
508     int source, const std::pair<int, int>& target_range,
509     const std::vector<std::vector<int>>& outputs,
510     std::vector<DistanceFromSource>* longest_distance) {
511   std::deque<int> queue;
512   queue.emplace_front(source);
513   while (!queue.empty()) {
514     int node = queue.front();
515     queue.pop_front();
516     for (int fanout : outputs[node]) {
517       // 1) Only nodes in the target range can be on paths from source to one of
518       //    its control outputs.
519       // 2) Since we only need a lower bound on the longest distance, we can
520       //    skip nodes for which we have already proven have a path of
521       //    length > 1 from the source.
522       if (fanout >= target_range.first && fanout <= target_range.second &&
523           (*longest_distance)[fanout] != TWO_OR_GREATER) {
524         (*longest_distance)[fanout] =
525             (*longest_distance)[fanout] == ZERO ? ONE : TWO_OR_GREATER;
526         queue.emplace_front(fanout);
527       }
528     }
529   }
530 }
531 
532 }  // namespace
533 
TransitiveReduction()534 Status DependencyOptimizer::TransitiveReduction() {
535   // PRECONDITION: optimized_graph_ must be sorted topologically.
536   const int num_nodes = optimized_graph_->node_size();
537   // Set up a compressed version of the graph to save a constant factor in the
538   // expensive algorithm below. Also cache the set of control outputs and the
539   // highest index of a target of any control output from each node.
540   int num_controls = 0;
541   std::vector<std::vector<int>> outputs(num_nodes);
542   std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
543       num_nodes);
544   // target_range[i] contains the range of node indices for which to compute
545   // longest paths starting from node i.
546   std::vector<std::pair<int, int>> target_range(num_nodes, {num_nodes, -1});
547   for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
548     const NodeDef& node = optimized_graph_->node(node_idx);
549     if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
550       // Ignore function nodes and nodes that modify frame info.
551       continue;
552     }
553     for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
554       const string& input = node.input(input_slot);
555       const NodeDef* input_node = node_map_->GetNode(input);
556       if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
557         // Ignore edges from nodes that modify frame info and from Merge nodes,
558         // because we cannot know which of it's input paths executes.
559         continue;
560       }
561       const int input_node_idx = node_to_idx_[input_node];
562       outputs[input_node_idx].push_back(node_idx);
563       target_range[input_node_idx].first =
564           std::min(target_range[input_node_idx].first, node_idx);
565       if (IsControlInput(input)) {
566         ++num_controls;
567         control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
568         target_range[input_node_idx].second =
569             std::max(target_range[input_node_idx].second, node_idx);
570       }
571     }
572   }
573 
574   // Run the longest path in DAG algorithm for each source node that has control
575   // outputs. If, for any target node of a control output, there exists a path
576   // of length > 1, we can drop that control dependency.
577   int num_controls_removed = 0;
578   std::vector<DistanceFromSource> longest_distance(num_nodes);
579   // Map from target_index -> set of (input_slot, source_index), representing
580   // the control edges to remove. We sort them in reverse order by input slot,
581   // such that when we swap them out so we don't clobber the
582   // node(target).input() repeated field.
583   typedef std::pair<int, int> InputSlotAndSource;
584   absl::flat_hash_map<
585       int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
586       control_edges_to_remove;
587   for (int source = 0; source < num_nodes; ++source) {
588     if (target_range[source].first >= target_range[source].second ||
589         target_range[source].second <= source) {
590       continue;
591     }
592     // Compute the set of nodes in the transitive fanout of source with
593     // topological sort index in [target_range.first : target_range.second]]
594     // to which there exists a path of length 2 or more from source.
595     std::fill(longest_distance.begin() + target_range[source].first,
596               longest_distance.begin() + target_range[source].second + 1, ZERO);
597     LongestPathsLowerBounds(source, target_range[source], outputs,
598                             &longest_distance);
599 
600     // If the longest path from source to target of a control dependency is
601     // longer than 1, there exists an alternate path, and we can eliminate the
602     // redundant direct control dependency.
603     for (const auto& control_output : control_outputs[source]) {
604       const int target = control_output.first;
605       if (longest_distance[target] == TWO_OR_GREATER) {
606         const int input_slot = control_output.second;
607         control_edges_to_remove[target].emplace(input_slot, source);
608       }
609     }
610   }
611   for (const auto& it : control_edges_to_remove) {
612     const int target = it.first;
613     NodeDef* target_node = optimized_graph_->mutable_node(target);
614     for (const InputSlotAndSource& slot_and_source : it.second) {
615       const int input_slot = slot_and_source.first;
616       const int source = slot_and_source.second;
617       const NodeDef& source_node = optimized_graph_->node(source);
618       CHECK_LT(input_slot, target_node->input_size());
619       target_node->mutable_input()->SwapElements(input_slot,
620                                                  target_node->input_size() - 1);
621       node_map_->RemoveOutput(source_node.name(), target_node->name());
622       target_node->mutable_input()->RemoveLast();
623       ++num_controls_removed;
624     }
625   }
626   VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
627           << " control dependencies";
628   return OkStatus();
629 }
630 
BuildNodeToIdx()631 void DependencyOptimizer::BuildNodeToIdx() {
632   // Set up &node -> index map.
633   node_to_idx_.clear();
634   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
635     const NodeDef& node = optimized_graph_->node(i);
636     node_to_idx_[&node] = i;
637   }
638 }
639 
640 // Suppose there are cross-device control inputs to node C from multiple nodes
641 // that are located on another device, e.g., we have control edges:
642 // A->C, B->C
643 // where A and B are on device X and C is on device Y.
644 // We can reduce cross-device communication by introducing an intermediate
645 // NoOp node C' on device X and rewriting the control edges to:
646 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges(bool host_granularity)647 void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
648   VLOG(1)
649       << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
650       << host_granularity;
651   const int num_nodes = optimized_graph_->node_size();
652   for (int i = 0; i < num_nodes; ++i) {
653     NodeDef* node = optimized_graph_->mutable_node(i);
654     if (node->device().empty()) continue;
655     string rest, node_device = node->device();
656     if (host_granularity) {
657       DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
658     }
659 
660     // Creates new noop nodes for devices on which multiple control inputs are
661     // located.
662 
663     // Map keyed by device name to the newly introduced Noop node for that
664     // device. A nullptr value means that we have only seen a single node on
665     // that device.
666     std::map<string, NodeDef*> noops;
667     int num_noops = 0;
668     for (int j = 0; j < node->input_size(); ++j) {
669       if (IsControlInput(node->input(j))) {
670         const NodeDef* input = node_map_->GetNode(node->input(j));
671         if (input == nullptr || input->device().empty()) continue;
672         string input_device = input->device();
673         if (host_granularity) {
674           DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
675                                            &rest);
676         }
677         if (input_device != node_device) {
678           VLOG(2) << "Cross-device " << node->name() << " " << input->device()
679                   << " -> " << node->device();
680           auto emplace_result = noops.emplace(input_device, nullptr);
681           if (!emplace_result.second &&
682               emplace_result.first->second == nullptr) {
683             VLOG(2) << "Duplicate input device from " << node->name();
684             // This is the second cross-device control input from the same
685             // device. Creates an intermediate noop node on that device.
686             string group_name;
687             NodeDef* noop;
688             // Creates a fresh node name; there may be conflicting names from
689             // a previous iteration of the optimizer.
690             do {
691               group_name = AddPrefixToNodeName(
692                   node->name(),
693                   strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
694               noop = node_map_->GetNode(group_name);
695               ++num_noops;
696             } while (noop != nullptr);
697             noop = optimized_graph_->add_node();
698             noop->set_name(group_name);
699             noop->set_device(input->device());
700             noop->set_op("NoOp");
701             node_map_->AddNode(noop->name(), noop);
702             emplace_result.first->second = noop;
703             VLOG(1) << "GroupCrossDeviceControlEdges: Added "
704                     << SummarizeNodeDef(*noop);
705           }
706         }
707       }
708     }
709 
710     // Reroute existing control edges to go via the newly introduced NoOp nodes.
711     int pos = 0;
712     while (pos < node->input_size()) {
713       const string& input_name = node->input(pos);
714       if (IsControlInput(input_name)) {
715         NodeDef* input = node_map_->GetNode(input_name);
716         if (input == nullptr) {
717           ++pos;
718         } else {
719           string input_device = input->device();
720           if (host_granularity) {
721             DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
722                                              &rest);
723           }
724           auto it = noops.find(input_device);
725           if (it == noops.end() || it->second == nullptr) {
726             ++pos;
727           } else {
728             VLOG(2) << "Rewriting input from " << input_name;
729             node->mutable_input()->SwapElements(pos, node->input_size() - 1);
730             node->mutable_input()->RemoveLast();
731             it->second->add_input(AsControlDependency(*input));
732             node_map_->UpdateOutput(input_name, node->name(),
733                                     it->second->name());
734           }
735         }
736       } else {
737         ++pos;
738       }
739     }
740     for (const auto& entry : noops) {
741       if (entry.second) {
742         node->add_input(AsControlDependency(*entry.second));
743         node_map_->AddOutput(entry.second->name(), node->name());
744       }
745     }
746   }
747 }
748 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)749 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
750                                      GraphDef* optimized_graph) {
751   optimized_graph_ = optimized_graph;
752   *optimized_graph_ = item.graph;
753   nodes_to_preserve_ = item.NodesToPreserve();
754   fetch_nodes_known_ = !item.fetch.empty();
755   CleanControlInputs();
756 
757   const int num_iterations = 2;
758   for (int iteration = 0; iteration < num_iterations; ++iteration) {
759     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
760     Status topo_sort_status;
761     // Perform topological sort to prepare the graph for transitive reduction.
762     topo_sort_status = TopologicalSort(optimized_graph_);
763     // Set up index-based graph datastructures to speed up analysis steps below.
764     node_map_.reset(new NodeMap(optimized_graph_));
765     BuildNodeToIdx();
766 
767     if (topo_sort_status.ok()) {
768       // Remove redundant control dependencies.
769       TF_RETURN_IF_ERROR(TransitiveReduction());
770     } else {
771       LOG(ERROR) << "Iteration = " << iteration
772                  << ", topological sort failed with message: "
773                  << topo_sort_status.error_message();
774     }
775     // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
776     // nodes.
777     TF_RETURN_IF_ERROR(OptimizeDependencies());
778 
779     // Dedup control inputs.
780     CleanControlInputs();
781 
782     // Merge multiple control edges from the same device.
783     GroupCrossDeviceControlEdges(/*host_granularity=*/false);
784 
785     // Merge control edges from the same host to reduce RPC traffic.
786     GroupCrossDeviceControlEdges(/*host_granularity=*/true);
787   }
788 
789   return OkStatus();
790 }
791 
792 }  // end namespace grappler
793 }  // end namespace tensorflow
794