xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/memory_optimizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <set>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
31 #include "tensorflow/core/grappler/costs/graph_memory.h"
32 #include "tensorflow/core/grappler/costs/graph_properties.h"
33 #include "tensorflow/core/grappler/costs/utils.h"
34 #include "tensorflow/core/grappler/graph_topology_view.h"
35 #include "tensorflow/core/grappler/grappler_item.h"
36 #include "tensorflow/core/grappler/mutable_graph_view.h"
37 #include "tensorflow/core/grappler/op_types.h"
38 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
39 #include "tensorflow/core/grappler/utils.h"
40 #include "tensorflow/core/grappler/utils/topological_sort.h"
41 #include "tensorflow/core/grappler/utils/traversal.h"
42 #include "tensorflow/core/lib/math/math_util.h"
43 #include "tensorflow/core/lib/strings/str_util.h"
44 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
45 #include "tensorflow/core/util/device_name_utils.h"
46 
47 namespace tensorflow {
48 namespace grappler {
49 
50 namespace {
51 
52 // Prefixes added to the names of recomputed nodes and their trigger nodes.
53 const char* kRecomputedNodePrefix = "Recomputed";
54 const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
55 // Attribute which may be added to nodes to manually allow them to be
56 // recomputed.
57 const char* kRecomputeHint = "_recompute_hint";
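// Illustrative example: a node can opt in to manual recomputation by adding
// this attribute to its NodeDef, e.g.
//   (*node->mutable_attr())[kRecomputeHint].set_i(1);
// Only the presence of the attribute is checked below; its value is ignored.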
58 
59 // Ops which we wouldn't mind recomputing to save memory.
60 // TODO(allenl): Replace this list with a cost model.
61 std::unordered_set<string> GetCheapToRecomputeOps() {
62   std::unordered_set<string> cheap_ops = {"Add",
63                                           "AddN",
64                                           "BiasAdd",
65                                           "Cast",
66                                           "Fill",
67                                           "FloorDiv",
68                                           "FloorMod",
69                                           "FusedBatchNorm",
70                                           "LeakyRelu",
71                                           "Mul",
72                                           "Neg",
73                                           "RealDiv",
74                                           "Reciprocal",
75                                           "Relu",
76                                           "Relu6",
77                                           "Reshape",
78                                           "Rsqrt",
79                                           "Sigmoid",
80                                           "Sqrt",
81                                           "Square",
82                                           "SquaredDifference",
83                                           "Sub",
84                                           "Tile",
85                                           "Transpose"};
86   return cheap_ops;
87 }
88 
89 // Find recomputable ops which feed into target nodes.
90 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
91     const NodeMap& node_map, const GraphDef* graph,
92     const std::function<bool(const NodeDef&)>& is_candidate,
93     const std::function<bool(const NodeDef&)>& is_target) {
94   std::unordered_set<const NodeDef*> candidate_recompute_nodes;
95   for (const auto& node : graph->node()) {
96     if (!is_candidate(node)) {
97       continue;
98     }
99     bool has_target_output = false;
100     for (const NodeDef* output : node_map.GetOutputs(node.name())) {
101       // It only makes sense to recompute this if it feeds into a target
102       // node. We expand this to dependencies in GetOpGroupsToRecompute.
103       if (is_target(*output)) {
104         has_target_output = true;
105         break;
106       }
107     }
108     if (!has_target_output) {
109       continue;
110     }
111     bool has_target_input = false;
112     for (const string& input_name : node.input()) {
113       // Don't recompute nodes which depend on target nodes.
114       const NodeDef* input_node = node_map.GetNode(input_name);
115       if (is_target(*input_node)) {
116         has_target_input = true;
117         break;
118       }
119     }
120     if (has_target_input) {
121       continue;
122     }
123     candidate_recompute_nodes.insert(&node);
124   }
125   return candidate_recompute_nodes;
126 }
127 
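// Breadth-first expansion helper: starting from the nodes already in
// `expanded_nodes`, repeatedly adds fanin nodes (if `collect_inputs`) and
// fanout nodes (if `collect_outputs`) that satisfy `is_candidate`, and
// rewrites `expanded_nodes` to contain the resulting connected subgraph.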
128 void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
129                         bool collect_outputs,
130                         const std::function<bool(const NodeDef&)>& is_candidate,
131                         std::unordered_set<const NodeDef*>* expanded_nodes) {
132   std::queue<const NodeDef*> to_visit;
133   for (const NodeDef* starting_node : *expanded_nodes) {
134     to_visit.push(starting_node);
135   }
136   expanded_nodes->clear();
137   while (!to_visit.empty()) {
138     const NodeDef* current_node = to_visit.front();
139     to_visit.pop();
140     if (!expanded_nodes->insert(current_node).second) {
141       // We already visited this node
142       continue;
143     }
144     if (collect_inputs) {
145       // Add inputs and outputs to this subgraph if they are candidates
146       for (const string& input_name_raw : current_node->input()) {
147         const NodeDef* input_node = node_map.GetNode(input_name_raw);
148         if (expanded_nodes->count(input_node) == 0 &&
149             is_candidate(*input_node)) {
150           to_visit.push(input_node);
151         }
152       }
153     }
154     if (collect_outputs) {
155       for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
156         if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
157           to_visit.push(output);
158         }
159       }
160     }
161   }
162 }
163 
164 struct RecomputedSubGraph {
165   std::unordered_set<const NodeDef*> recomputed_source_nodes;
166   std::unordered_set<NodeDef*> target_nodes;
167 };
168 
169 // Find groups of ops to recompute together based on `should_recompute`.
170 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
171     const GraphDef* graph, const NodeMap& node_map,
172     const std::function<bool(const NodeDef&)>& should_recompute,
173     const std::function<bool(const NodeDef&)>& is_target) {
174   std::unordered_set<const NodeDef*> visited_nodes;
175   std::vector<RecomputedSubGraph> subgraphs_to_recompute;
176   std::unordered_set<const NodeDef*> candidate_recompute_nodes =
177       FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
178   for (const NodeDef* recompute_node : candidate_recompute_nodes) {
179     if (visited_nodes.count(recompute_node) > 0) {
180       continue;
181     }
182     RecomputedSubGraph current_recomputation;
183     // Build out recomputation groups by expanding to inexpensive-to-recompute
184     // nodes which do not feed target nodes. The goal is to capture some
185     // intermediate activations within this graph.
186     std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
187     unpruned_recompute_nodes.insert(recompute_node);
188     connected_subgraph(node_map,
189                        true,  // Collect inputs
190                        true,  // Collect outputs
191                        should_recompute, &unpruned_recompute_nodes);
192     visited_nodes.insert(unpruned_recompute_nodes.begin(),
193                          unpruned_recompute_nodes.end());
194     for (const NodeDef* unpruned_recompute_node : unpruned_recompute_nodes) {
195       bool inserted_feed = false;
196       for (NodeDef* output :
197            node_map.GetOutputs(unpruned_recompute_node->name())) {
198         if (is_target(*output)) {
199           current_recomputation.target_nodes.insert(output);
200           if (!inserted_feed) {
201             // Keep track of nodes which feed directly into a target node. These
202             // and nodes which feed into them will define the recomputed
203             // subgraph.
204             current_recomputation.recomputed_source_nodes.insert(
205                 unpruned_recompute_node);
206             inserted_feed = true;
207           }
208         }
209       }
210     }
211     // Recompute only nodes which eventually feed into a target node.
212     connected_subgraph(
213         node_map,
214         true,   // Collect inputs
215         false,  // Collect outputs
216         [&unpruned_recompute_nodes](const NodeDef& node) {
217           return unpruned_recompute_nodes.count(&node) != 0;
218         },
219         &current_recomputation.recomputed_source_nodes);
220     if (current_recomputation.target_nodes.empty()) {
221       continue;
222     }
223     subgraphs_to_recompute.push_back(current_recomputation);
224   }
225   return subgraphs_to_recompute;
226 }
227 
228 // Computes the maximum topological numbers of (1) target node components
229 // (gradient nodes being fed by the recomputation), and (2) child recompute node
230 // components for each recomputed node. We will not attach a control dependency
231 // to a recomputation unless the dependency's component number is greater than
232 // this value (to prevent cycles).
233 std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
234     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
235     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
236     const std::unordered_map<const NodeDef*, int>& components) {
237   std::unordered_map<const NodeDef*, int> recomputed_node_components;
238   // Start by setting component numbers to the maximum among target nodes.
239   for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
240     int max_target_component = -1;
241     for (NodeDef* output :
242          node_map.GetOutputs(original_recompute_node->name())) {
243       if (target_nodes.count(output) != 0) {
244         int current_target_component = components.find(output)->second;
245         if (current_target_component > max_target_component) {
246           max_target_component = current_target_component;
247         }
248       }
249     }
250     if (max_target_component > -1) {
251       recomputed_node_components[original_recompute_node] =
252           max_target_component;
253     }
254   }
255   // Sort recomputed nodes topologically (based on the original graph) so we can
256   // efficiently assign to each node the maximum of its recomputed child
257   // components and its own targets.
258   std::vector<const NodeDef*> recomputed_source_nodes_topological(
259       recomputed_source_nodes.begin(), recomputed_source_nodes.end());
260   std::sort(recomputed_source_nodes_topological.begin(),
261             recomputed_source_nodes_topological.end(),
262             [&components](const NodeDef* first, const NodeDef* second) {
263               return components.find(first)->second <
264                      components.find(second)->second;
265             });
266   for (const NodeDef* original_recompute_node :
267        recomputed_source_nodes_topological) {
268     int max_component;
269     auto recomputed_component_iterator =
270         recomputed_node_components.find(original_recompute_node);
271     if (recomputed_component_iterator != recomputed_node_components.end()) {
272       max_component = recomputed_component_iterator->second;
273     } else {
274       max_component = -1;
275     }
276     for (NodeDef* output :
277          node_map.GetOutputs(original_recompute_node->name())) {
278       if (recomputed_source_nodes.count(output) == 0) {
279         continue;
280       }
281       auto child_component_iterator = recomputed_node_components.find(output);
282       CHECK(child_component_iterator != recomputed_node_components.end());
283       int child_component = child_component_iterator->second;
284       if (child_component > max_component) {
285         max_component = child_component;
286       }
287     }
288     CHECK_GE(max_component, 0);
289     recomputed_node_components[original_recompute_node] = max_component;
290   }
291   return recomputed_node_components;
292 }
293 
294 // Modifies `graph`, adding trigger nodes and returning a mapping from
295 // `recomputed_source_nodes` to trigger nodes which will not create loops in the
296 // graph (using the component numberings in `components` and
297 // `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
298 // recomputed_source_nodes, which are the originals) eventually get these
299 // control dependencies.
300 std::unordered_map<const NodeDef*, const NodeDef*>
301 AddRecomputeControlDependencyNodes(
302     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
303     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
304     const std::unordered_map<const NodeDef*, int>& components,
305     const std::unordered_map<const NodeDef*, int>&
306         recomputed_node_max_feed_components,
307     GraphDef* graph) {
308   // Sort recomputed nodes based on max downstream components.
309   std::vector<const NodeDef*> recomputed_source_nodes_topological(
310       recomputed_source_nodes.begin(), recomputed_source_nodes.end());
311   std::sort(recomputed_source_nodes_topological.begin(),
312             recomputed_source_nodes_topological.end(),
313             [&recomputed_node_max_feed_components](const NodeDef* first,
314                                                    const NodeDef* second) {
315               int first_component =
316                   recomputed_node_max_feed_components.find(first)->second;
317               int second_component =
318                   recomputed_node_max_feed_components.find(second)->second;
319               return first_component > second_component
320                      // Ensure a consistent ordering. This is necessary because
321                      // we're working not with node component numbers (which are
322                      // unique) but with the maximum across nodes they feed into
323                      // (very much not unique).
324                      || (first_component == second_component &&
325                          first->name() > second->name());
326             });
327   // Create merged control dependency nodes by sorting target inputs
328   // topologically and zipper merging with the sorted recomputed nodes.
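  // Illustrative sketch: suppose the recomputed nodes, ordered by decreasing
  // max-feed component, are [R2(4), R1(2), R0(0)] and the target inputs,
  // ordered by decreasing component, are [T(5), T(3), T(1)]. Then trigger(R2)
  // picks up a control dependency on T(5), trigger(R1) on T(3), trigger(R0)
  // on T(1), and each trigger also depends on the previous one, forming a
  // chain.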
329   std::vector<const NodeDef*> target_inputs_topological;
330   for (const NodeDef* target_node : target_nodes) {
331     for (const string& target_input_name_raw : target_node->input()) {
332       const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
333       // If this node has already had one of its inputs recomputed during this
334       // rewriting pass, we ignore that recomputed node here (it will not be in
335       // the NodeMap).
336       if (target_input == nullptr ||
337           recomputed_source_nodes.count(target_input) != 0 ||
338           components.find(target_node)->second ==
339               components.find(target_input)->second) {
340         continue;
341       }
342       target_inputs_topological.push_back(target_input);
343     }
344   }
345   std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
346             [&components](const NodeDef* first, const NodeDef* second) {
347               return components.find(first)->second >
348                      components.find(second)->second;
349             });
350   auto target_input_iterator = target_inputs_topological.begin();
351   NodeDef* current_trigger_node = nullptr;
352   std::unordered_map<const NodeDef*, const NodeDef*> triggers;
353   for (const NodeDef* original_recomputed_node :
354        recomputed_source_nodes_topological) {
355     NodeDef* new_trigger_node = graph->add_node();
356     new_trigger_node->set_name(AddPrefixToNodeName(
357         original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
358     new_trigger_node->set_op("NoOp");
359     new_trigger_node->set_device(original_recomputed_node->device());
360     if (current_trigger_node != nullptr) {
361       *new_trigger_node->add_input() =
362           strings::StrCat("^", current_trigger_node->name());
363     }
364     current_trigger_node = new_trigger_node;
365     triggers[original_recomputed_node] = current_trigger_node;
366     for (;
367          target_input_iterator != target_inputs_topological.end() &&
368          components.find(*target_input_iterator)->second >
369              recomputed_node_max_feed_components.find(original_recomputed_node)
370                  ->second;
371          ++target_input_iterator) {
372       *current_trigger_node->add_input() =
373           strings::StrCat("^", (*target_input_iterator)->name());
374       VLOG(2) << "  Recomputation trigger " << current_trigger_node->name()
375               << " depends on " << (*target_input_iterator)->name();
376     }
377   }
378   return triggers;
379 }
380 
381 string RecomputedOrOriginalNodeName(
382     const std::unordered_set<string>& recomputed_node_names,
383     const string& original_node_name) {
384   if (recomputed_node_names.find(original_node_name) ==
385       recomputed_node_names.end()) {
386     return original_node_name;
387   } else {
388     return AddPrefixToNodeName(original_node_name, kRecomputedNodePrefix);
389   }
390 }
391 
392 // Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
393 // from recomputed_source_nodes to target_nodes are changed to start from the
394 // recomputed nodes.
395 void RecomputeSubgraph(
396     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
397     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
398     const std::unordered_map<const NodeDef*, int>& components,
399     GraphDef* graph) {
400   std::unordered_set<string> recomputed_node_names;
401   VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
402           << " node subgraph";
403   std::unordered_map<const NodeDef*, int> recomputed_node_components =
404       GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
405                                  node_map, components);
406   for (const NodeDef* original_node : recomputed_source_nodes) {
407     VLOG(2) << "  " << original_node->name();
408     recomputed_node_names.insert(original_node->name());
409   }
410   std::unordered_map<const NodeDef*, const NodeDef*> triggers =
411       AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
412                                          node_map, components,
413                                          recomputed_node_components, graph);
414   // Create the recomputed sub-graph
415   for (const NodeDef* original_node : recomputed_source_nodes) {
416     NodeDef* copied_node = graph->add_node();
417     copied_node->set_name(
418         AddPrefixToNodeName(original_node->name(), kRecomputedNodePrefix));
419     copied_node->set_op(original_node->op());
420     *copied_node->mutable_attr() = original_node->attr();
421     copied_node->set_device(original_node->device());
422     for (const string& original_input_name : original_node->input()) {
423       // Set inputs which are internal to the copied subgraph to their copied
424       // versions.
425       *copied_node->add_input() = RecomputedOrOriginalNodeName(
426           recomputed_node_names, original_input_name);
427     }
428     // Each recomputed node gets a control dependency to prevent it from being
429     // recomputed immediately.
430     *copied_node->add_input() =
431         strings::StrCat("^", triggers[original_node]->name());
432   }
433   // Set the inputs of nodes in the target subgraph to the recomputed nodes
434   // where applicable.
435   for (NodeDef* target_node : target_nodes) {
436     for (string& target_input_name : *target_node->mutable_input()) {
437       target_input_name = RecomputedOrOriginalNodeName(recomputed_node_names,
438                                                        target_input_name);
439     }
440   }
441 }
442 
443 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
444                                 const string& recomputation_targets_name_scope,
445                                 GraphDef* graph, const GrapplerItem& item) {
446   // The topological numberings and NodeMap will be stale as soon as we start
447   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
448   // looks up nodes which were in the original graph, and preserves the graph
449   // topology it's interested in.
450   // We don't use the results of this topological sort until later, but this
451   // call invalidates all NodeDef pointers, so it needs to be done before we
452   // start collecting those.
453   TF_CHECK_OK(TopologicalSort(graph));
454   NodeMap node_map(graph);
455   std::vector<RecomputedSubGraph> recomputed_subgraphs;
456   // Do not recompute nodes which are fed, since the recomputed node would not
457   // take on the fed value (i.e. gradients would be incorrect).
458   std::unordered_set<string> feeds;
459   for (const auto& feed : item.feed) {
460     feeds.insert(NodeName(feed.first));
461   }
462   std::function<bool(const NodeDef&)> is_target =
463       [&recomputation_targets_name_scope](const NodeDef& node) {
464         // Nodes whose inputs we may want to recompute. This matches node names
465         // that contain recomputation_targets_name_scope as a name scope,
466         // meaning it either begins with or contains the name scope.
467         // Defaults to "gradients/", which matches any node name that begins
468         // with "gradients/" or contains "/gradients/".
469         return absl::StartsWith(node.name(),
470                                 recomputation_targets_name_scope) ||
471                static_cast<int>(node.name().find(
472                    "/" + recomputation_targets_name_scope)) != -1;
473       };
474 
475   if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
476       optimization_level == RewriterConfig::HEURISTICS) {
477     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
478     // the cheap forward ops get grouped into a single subgraph which must
479     // execute before gradients start executing (unless layers are manually
480     // separated by identity ops).
481     std::unordered_set<string> cheap_to_recompute_ops =
482         GetCheapToRecomputeOps();
483     recomputed_subgraphs = GetOpGroupsToRecompute(
484         graph, node_map,
485         [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
486           return !is_target(node) && feeds.count(node.name()) == 0 &&
487                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
488                   node.attr().count(kRecomputeHint) > 0);
489         },
490         is_target);
491   } else if (optimization_level == RewriterConfig::MANUAL) {
492     recomputed_subgraphs = GetOpGroupsToRecompute(
493         graph, node_map,
494         [&feeds, &is_target](const NodeDef& node) {
495           return !is_target(node) && feeds.count(node.name()) == 0 &&
496                  node.attr().count(kRecomputeHint) > 0;
497         },
498         is_target);
499   }
500   if (!recomputed_subgraphs.empty()) {
501     std::unordered_map<const NodeDef*, int> topological_numbering;
502     for (int node_number = 0; node_number < graph->node().size();
503          ++node_number) {
504       topological_numbering[graph->mutable_node(node_number)] =
505           graph->node().size() - node_number - 1;
506     }
507     // Duplicate the indicated sub-graphs and set up control dependencies
508     for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
509       RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
510                         node_map, topological_numbering, graph);
511     }
512   }
513 }
514 
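// Rewrites large AddN / AccumulateNV2 nodes on devices whose estimated peak
// memory usage exceeds roughly 80% of capacity so that their inputs are
// accumulated one at a time into a temporary variable instead of all being
// kept alive simultaneously. Returns true if the graph was modified.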
515 bool SchedulingPass(Cluster* cluster, std::unique_ptr<GraphMemory>* memory_ptr,
516                     GrapplerItem* item) {
517   // Look for AddN nodes (and equivalent) and record input names.
518   MutableGraphView view(&item->graph);
519 
520   std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
521   for (NodeDef& node : *item->graph.mutable_node()) {
522     if (!IsAddN(node) && node.op() != "AccumulateNV2") {
523       continue;
524     }
525     // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
526     if (view.NumFanins(node, false) <= 2) {
527       continue;
528     }
529     for (const auto& input : view.GetFanins(node, false)) {
530       if (input.node->device() == node.device()) {
531         string tensor_name =
532             strings::StrCat(input.node->name(), ":", input.port_id);
533         addn_list[tensor_name].insert(&node);
534       }
535     }
536   }
537 
538   if (addn_list.empty()) {
539     return false;
540   }
541 
542   if ((*memory_ptr) == nullptr) {
543     memory_ptr->reset(new GraphMemory(*item));
544     Status s = (*memory_ptr)->InferStatically(cluster->GetDevices());
545     if (!s.ok()) {
546       memory_ptr->reset();
547       VLOG(1) << "Failed to infer memory usage: " << s.error_message();
548       return false;
549     }
550   }
551   const GraphMemory& memory = **memory_ptr;
552 
553   std::unordered_set<NodeDef*> addn_to_rewrite;
554   for (const auto& device : cluster->GetDevices()) {
555     const string& name = device.first;
556     const DeviceProperties& prop = device.second;
557     if (prop.memory_size() <= 0) {
558       VLOG(1) << "Available memory unknown for device " << name;
559       continue;
560     }
561     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
562 
563     if (mem_usage.used_memory <= prop.memory_size() * 0.8) {
564       continue;
565     }
566 
567     for (const auto& live : mem_usage.live_tensors) {
568       string tensor_name = strings::StrCat(live.node, ":", live.output_id);
569       auto it = addn_list.find(tensor_name);
570       if (it != addn_list.end()) {
571         addn_to_rewrite.insert(it->second.begin(), it->second.end());
572       }
573     }
574   }
575 
576   if (addn_to_rewrite.empty()) {
577     return false;
578   }
579   GraphProperties properties(*item);
580   Status s = properties.InferStatically(/*assume_valid_feeds=*/false,
581                                         /*aggressive_shape_inference=*/false,
582                                         /*include_tensor_values=*/false);
583   if (!s.ok()) {
584     VLOG(1) << "Failed to infer shapes: " << s.error_message();
585     return false;
586   }
587 
588   // It's ok to use immutable GraphTopologyView here, because we do not destroy
589   // any of the nodes in the underlying graph; we only add new nodes.
590   GraphTopologyView graph_topology;
591   Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
592   if (!initialized_topology.ok()) {
593     VLOG(1) << "Failed to initialize graph topology view: "
594             << initialized_topology.error_message();
595     return false;
596   }
597 
598   bool updated_graph = false;
599   // Rewrite the AddN.
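  // Sketch of the rewrite applied below: AddN(x_0, ..., x_n) is replaced by
  //   tmp_var = TemporaryVariable(shape, dtype)
  //   zeros   = ZerosLike(x_first)       (x_first: topologically earliest input)
  //   init    = Assign(tmp_var, zeros)
  //   acc_i   = AssignAdd(init, x_i)     (one per regular input)
  //   AddN    -> DestroyTemporaryVariable(init), with control deps on all acc_i
  // so the partial sums are accumulated in place instead of keeping every
  // input buffer alive at once.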
600   for (NodeDef* node : addn_to_rewrite) {
601     if (!properties.HasOutputProperties(node->name())) {
602       VLOG(1) << "Missing properties for " << node->name();
603       continue;
604     }
605     const TensorShapeProto& shape =
606         properties.GetOutputProperties(node->name())[0].shape();
607     PartialTensorShape shp(shape);
608     if (!shp.IsFullyDefined()) {
609       VLOG(1) << "Shape not fully known for " << node->name();
610       continue;
611     }
612     DataType dtype = node->attr().at("T").type();
613     if (dtype != DT_HALF && dtype != DT_FLOAT && dtype != DT_DOUBLE &&
614         dtype != DT_INT64) {  // Only GPU-supported TemporaryVariable types.
615       VLOG(1) << "Unsupported dtype for " << node->name();
616       continue;
617     }
618 
619     // Compute a topological ordering for the node fanin.
620     std::unordered_map<const NodeDef*, int> topo_order;
621     DfsTraversal(graph_topology, {node}, TraversalDirection::kFollowInputs,
622                  DfsCallbacks::PostOrder([&topo_order](const NodeDef* n) {
623                    int topo_index = static_cast<int>(topo_order.size());
624                    topo_order[n] = topo_index;
625                  }));
626 
627     std::vector<int> input_topo_index;
628 
629     for (int i = 0; i < node->input_size(); ++i) {
630       const string& input = node->input(i);
631       const string node_name = NodeName(input);
632       const NodeDef* input_node = view.GetNode(node_name);
633       input_topo_index.push_back(topo_order.at(input_node));
634     }
635     int min_input_topo_index = INT_MAX;
636     int min_input_id = -1;
637     for (int i = 0; i < node->input_size(); ++i) {
638       if (IsControlInput(node->input(i))) {
639         // control inputs are always last.
640         break;
641       }
642       const int current = input_topo_index[i];
643       if (current < min_input_topo_index) {
644         min_input_topo_index = current;
645         min_input_id = i;
646       }
647     }
648     CHECK_LE(0, min_input_id);
649     std::vector<string> pre_ctrl_deps;
650     std::vector<string> post_ctrl_deps;
651     for (int i = node->input_size() - 1; i >= 0; --i) {
652       if (!IsControlInput(node->input(i))) {
653         // control inputs are always last.
654         break;
655       }
656       if (input_topo_index[i] < min_input_topo_index) {
657         // These control dependencies can be executed before the node.
658         pre_ctrl_deps.push_back(node->input(i));
659       } else {
660         // These control dependencies should be executed after the node.
661         post_ctrl_deps.push_back(node->input(i));
662       }
663     }
664 
665     const string& device = node->device();
666     const string tmp_var_name = strings::StrCat(node->name(), "/tmp_var");
667     if (view.GetNode(tmp_var_name) != nullptr) {
668       VLOG(1) << "Temporary variable already exists " << tmp_var_name;
669       return false;
670     }
671 
672     // Create the temporary variable that will hold intermediate results
673     NodeDef* tmp_var = item->graph.add_node();
674     tmp_var->set_name(tmp_var_name);
675     tmp_var->set_op("TemporaryVariable");
676     tmp_var->set_device(device);
677     (*tmp_var->mutable_attr())["dtype"].set_type(dtype);
678     *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
679     (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());
680 
681     for (const string& ctrl_dep : pre_ctrl_deps) {
682       *tmp_var->add_input() = ctrl_dep;
683     }
684     *tmp_var->add_input() =
685         AsControlDependency(NodeName(node->input(min_input_id)));
686 
687     // Initialize it to zero
688     NodeDef* zeros = item->graph.add_node();
689     zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
690     zeros->set_op("ZerosLike");
691     zeros->set_device(device);
692     (*zeros->mutable_attr())["T"].set_type(dtype);
693     *zeros->add_input() = node->input(min_input_id);
694 
695     NodeDef* initialize = item->graph.add_node();
696     initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
697     initialize->set_op("Assign");
698     initialize->set_device(device);
699     (*initialize->mutable_attr())["T"].set_type(dtype);
700     (*initialize->mutable_attr())["use_locking"].set_b(false);
701     (*initialize->mutable_attr())["validate_shape"].set_b(false);
702     *initialize->add_input() = tmp_var->name();
703     *initialize->add_input() = zeros->name();
704 
705     // Add the assignadd nodes
706     std::vector<NodeDef*> accumulates;
707     for (int i = 0; i < node->input_size(); ++i) {
708       const string& input = node->input(i);
709       if (!IsControlInput(input)) {
710         NodeDef* accumulate = item->graph.add_node();
711         accumulate->set_name(
712             strings::StrCat(node->name(), "/tmp_var_accum_", i));
713         accumulate->set_op("AssignAdd");
714         accumulate->set_device(device);
715         (*accumulate->mutable_attr())["T"].set_type(dtype);
716         (*accumulate->mutable_attr())["use_locking"].set_b(true);
717         *accumulate->add_input() = initialize->name();
718         *accumulate->add_input() = input;
719         accumulates.push_back(accumulate);
720       }
721     }
722 
723     // Rewrite the AddN node as a DestroyTemporaryVariable op.
724     node->set_op("DestroyTemporaryVariable");
725     node->clear_input();
726     EraseRegularNodeAttributes(node);
727     (*node->mutable_attr())["T"].set_type(dtype);
728     (*node->mutable_attr())["var_name"].set_s(tmp_var->name());
729     *node->add_input() = initialize->name();
730     for (const NodeDef* accum : accumulates) {
731       *node->add_input() = AsControlDependency(accum->name());
732     }
733     for (const string& ctrl_dep : post_ctrl_deps) {
734       *node->add_input() = ctrl_dep;
735     }
736 
737     updated_graph = true;
738   }
739 
740   return updated_graph;
741 }
742 
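// Builds a swap-out/swap-in node pair (_CopyFromGpuToHost / _CopyFromHostToGpu)
// for input `input_to_swap` of `node`, colocated with `node`. Fails if the node
// is not placed on a GPU, the input is a reference type, or the pair already
// exists.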
743 Status BuildSwapPair(NodeDef* node, int input_to_swap,
744                      const std::unordered_map<string, const NodeDef*>& name_map,
745                      GraphDef* graph,
746                      std::pair<NodeDef*, NodeDef*>* swap_pair) {
747   string task, device;
748   if (!DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) ||
749       !absl::StrContains(device, DEVICE_GPU)) {
750     return errors::InvalidArgument("Can't swap input ", input_to_swap,
751                                    " of node ", node->name(),
752                                    " since it is not on GPU");
753   }
754   const OpDef* op_def;
755   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
756   DataType input_type;
757   TF_RETURN_IF_ERROR(
758       InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
759   if (IsRefType(input_type)) {
760     return errors::InvalidArgument("Can't swap input ", input_to_swap,
761                                    " of node ", node->name(),
762                                    " since it expects a reference");
763   }
764 
765   string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
766   string swap_out_name = strings::StrCat("swap_out_", tensor_to_swap);
767   string swap_in_name = strings::StrCat("swap_in_", tensor_to_swap);
768   if (name_map.find(swap_out_name) != name_map.end() ||
769       name_map.find(swap_in_name) != name_map.end()) {
770     return errors::InvalidArgument("Input ", input_to_swap, " of node ",
771                                    node->name(), " is already swapped");
772   }
773 
774   // Force the tensor to be copied to cpu.
775   NodeDef* swap_out_node = graph->add_node();
776   swap_out_node->set_name(swap_out_name);
777   swap_out_node->set_op("_CopyFromGpuToHost");
778 
779   // Force the tensor to be restored to the device.
780   NodeDef* swap_in_node = graph->add_node();
781   swap_in_node->set_name(swap_in_name);
782   swap_in_node->set_op("_CopyFromHostToGpu");
783   *swap_in_node->add_input() = swap_out_node->name();
784 
785   // Colocate the swap_out_ and swap_in_ nodes with the node itself.
786   swap_out_node->set_device(node->device());
787   swap_in_node->set_device(node->device());
788   string coloc_group = strings::StrCat("loc@", tensor_to_swap);
789   (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
790   (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
791   (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
792 
793   (*swap_in_node->mutable_attr())["T"].set_type(input_type);
794   (*swap_out_node->mutable_attr())["T"].set_type(input_type);
795   *swap_pair = std::make_pair(swap_out_node, swap_in_node);
796 
797   return OkStatus();
798 }
799 
800 struct SwapInfo {
801   std::vector<int> inputs_to_swap;
802   Costs::NanoSeconds time_to_swap = 0;
803 };
804 
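// Looks for an upstream node whose completion time is late enough that a
// control dependency on it delays the swap-in until shortly before `node`
// needs the swapped data, but still early enough to hide the copy latency.
// Returns nullptr if no suitable trigger is found.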
805 static const NodeDef* FindSwapInTrigger(
806     const NodeDef* node, const SwapInfo& swap_info,
807     const std::unordered_map<string, const NodeDef*>& name_map,
808     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
809         execution_times) {
810   // max_trigger_time stores the time before which the swap operation needs to
811   // be started in order to load the data back onto the accelerator without
812   // delaying the downstream computation.
813   Costs::NanoSeconds max_trigger_time(0);
814   std::set<string> possible_inputs;
815   for (int i = 0; i < node->input_size(); ++i) {
816     const string input_node_name = NodeName(node->input(i));
817     auto it1 = name_map.find(input_node_name);
818     if (it1 == name_map.end()) {
819       return nullptr;
820     }
821     const NodeDef* input_node = it1->second;
822 
823     auto it2 = execution_times.find(input_node);
824     if (it2 == execution_times.end()) {
825       return nullptr;
826     }
827     max_trigger_time = std::max(max_trigger_time, it2->second);
828     possible_inputs.insert(input_node_name);
829   }
830 
831   for (const int i : swap_info.inputs_to_swap) {
832     const string input_node_name = NodeName(node->input(i));
833     possible_inputs.erase(input_node_name);
834   }
835   if (possible_inputs.empty()) {
836     return nullptr;
837   }
838 
839   max_trigger_time -= swap_info.time_to_swap;
840 
841   std::map<Costs::NanoSeconds, const NodeDef*> candidates;
842   std::set<string> already_processed;
843 
844   while (!possible_inputs.empty()) {
845     const string input_node_name = *possible_inputs.begin();
846     possible_inputs.erase(possible_inputs.begin());
847     already_processed.insert(input_node_name);
848     auto it1 = name_map.find(input_node_name);
849     if (it1 == name_map.end()) {
850       return nullptr;
851     }
852     const NodeDef* input_node = it1->second;
853     // Don't jump over frames, since adding a control dependency from one frame
854     // to the next isn't supported. Don't go through branches, since we don't
855     // know whether they'll be executed or not.
856     if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
857         IsMerge(*input_node)) {
858       continue;
859     }
860     auto it2 = execution_times.find(input_node);
861     if (it2 == execution_times.end()) {
862       return nullptr;
863     }
864     if (it2->second < max_trigger_time) {
865       candidates[it2->second] = input_node;
866     } else {
867       for (const string& fanin : input_node->input()) {
868         string name = NodeName(fanin);
869         if (already_processed.find(name) == already_processed.end()) {
870           possible_inputs.insert(name);
871         }
872       }
873     }
874   }
875 
876   // Select the candidate that will execute last, since we want to swap the data
877   // back in as late as possible while still leaving enough time for it to be
878   // swapped back before the downstream nodes need it.
879   if (!candidates.empty()) {
880     return candidates.rbegin()->second;
881   }
882   return nullptr;
883 }
884 
885 static bool IsSwappable(const MutableGraphView& graph,
886                         MutableGraphView::OutputPort output) {
887   const NodeDef& node = *output.node;
888   // There is no point in swapping out persistent tensors, since the tensor will
889   // continue to use memory.
890   if (IsPersistent(node)) {
891     return false;
892   }
893 
894   const OpDef* op_def;
895   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
896     return false;
897   }
898   DataType dtype;
899   if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
900     return false;
901   }
902   // References can only refer to persistent memory: therefore the node isn't
903   // swappable.
904   if (IsRefType(dtype)) {
905     return false;
906   }
907 
908   if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
909     // If placed on the same device, these nodes are just forwarding references
910     // to their input. Therefore they are swappable iff their fanin is swappable
911     // or it resides on a different device.
912     MutableGraphView::InputPort input;
913     input.node = output.node;
914     input.port_id = 0;
915     MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
916     if (fanin.node->device() == node.device()) {
917       return IsSwappable(graph, fanin);
918     }
919   }
920   return true;
921 }
922 
923 static NodeDef* FindSwapOutTrigger(
924     const NodeDef* node, int input_id, const MutableGraphView& view,
925     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
926         execution_times) {
927   // Find the output port that generated the tensor to swap.
928   MutableGraphView::InputPort swap;
929   swap.node = const_cast<NodeDef*>(node);
930   swap.port_id = input_id;
931   MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
932   if (!generator.node) {
933     return nullptr;
934   }
935 
936   const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
937       view.GetFanout(generator);
938   NodeDef* trigger = nullptr;
939   Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
940 
941   for (const auto& port : fanout) {
942     if (port.node == node) {
943       continue;
944     }
945     auto it = execution_times.find(port.node);
946     if (it != execution_times.end() && it->second < earliest_fanout) {
947       earliest_fanout = it->second;
948       trigger = port.node;
949     }
950   }
951 
952   return trigger;
953 }
954 
955 static bool IsSwappable(MutableGraphView::InputPort input) {
956   const NodeDef& node = *input.node;
957 
958   const OpDef* op_def;
959   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
960     return false;
961   }
962 
963   DataType dtype;
964   if (!InputTypeForNode(node, *op_def, input.port_id, &dtype).ok()) {
965     return false;
966   }
967 
968   return !IsRefType(dtype);
969 }
970 
971 struct MemInfo {
972   MutableGraphView::OutputPort port;
973   int64_t memory_used;
974   std::vector<MutableGraphView::InputPort> uses_left;
975   double fitness;
976 
977   bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
978 };
979 
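// For each GPU whose estimated peak memory usage exceeds its capacity, picks
// tensors that are large, long-lived, and not consumed until well after the
// peak, and records their post-peak uses in `nodes_to_swap` until the
// projected savings cover the shortfall. Returns true if any candidates were
// found.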
980 static bool IdentifySwappingCandidates(
981     Cluster* cluster, GrapplerItem* item,
982     std::unique_ptr<GraphMemory>* memory_ptr,
983     std::unordered_set<string>* skip_list,
984     std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
985   if ((*memory_ptr) == nullptr) {
986     memory_ptr->reset(new GraphMemory(*item));
987     Status s = (*memory_ptr)->InferStatically(cluster->GetDevices());
988     if (!s.ok()) {
989       memory_ptr->reset();
990       VLOG(1) << "Failed to infer memory usage: " << s.error_message();
991       return false;
992     }
993   }
994   const GraphMemory& memory = **memory_ptr;
995 
996   bool updated_graph = false;
997   for (const auto& device : cluster->GetDevices()) {
998     const string& name = device.first;
999     const DeviceProperties& prop = device.second;
1000     if (prop.type() != "GPU") {
1001       continue;
1002     }
1003     if (prop.memory_size() <= 0) {
1004       VLOG(1) << "Peak memory usage unknown for device " << name;
1005       continue;
1006     }
1007     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
1008 
1009     if (mem_usage.used_memory <= prop.memory_size()) {
1010       continue;
1011     }
1012     int64_t required_savings = mem_usage.used_memory - prop.memory_size();
1013 
1014     std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
1015     {
1016       VirtualCluster vcluster(cluster->GetDevices());
1017       if (!vcluster.Provision().ok()) {
1018         return false;
1019       }
1020       if (!vcluster.Initialize(*item).ok()) {
1021         return false;
1022       }
1023       RunMetadata metadata;
1024       Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
1025       if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
1026         return false;
1027       }
1028 
1029       for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
1030         for (const auto& node_stats : dev_stats.node_stats()) {
1031           Costs::NanoSeconds exec_time =
1032               Costs::NanoSeconds(1) +
1033               Costs::MicroSeconds(node_stats.all_start_micros() +
1034                                   node_stats.op_end_rel_micros());
1035           op_completion_times.emplace(node_stats.node_name(), exec_time);
1036         }
1037       }
1038     }
1039 
1040     Costs::Duration peak_time = -1;
1041     for (const auto& live_tensor : mem_usage.live_tensors) {
1042       if (live_tensor.allocation_time > peak_time) {
1043         peak_time = live_tensor.allocation_time;
1044       }
1045     }
1046 
1047     std::vector<MemInfo> mem_state;
1048 
1049     MutableGraphView graph(&item->graph);
1050     for (const auto& live_tensor : mem_usage.live_tensors) {
1051       if (live_tensor.memory_used <= 1024) {
1052         // Don't bother with small tensors.
1053         continue;
1054       }
1055       if (live_tensor.deallocation_time - live_tensor.allocation_time <=
1056           Costs::Duration(1e6)) {
1057         // Not enough time to swap.
1058         VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
1059         continue;
1060       }
1061 
1062       if (skip_list->find(live_tensor.node) != skip_list->end()) {
1063         continue;
1064       }
1065       MutableGraphView::OutputPort port =
1066           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
1067       if (!IsSwappable(graph, port)) {
1068         continue;
1069       }
1070       MemInfo mem_info;
1071       mem_info.port = port;
1072       mem_info.memory_used = live_tensor.memory_used;
1073       Costs::Duration allocation_time = live_tensor.allocation_time;
1074       Costs::Duration earliest_use(Costs::Duration::infinity());
1075       bool valid = true;
1076       for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
1077         // Get execution time.
1078         auto it = op_completion_times.find(input.node->name());
1079         if (it == op_completion_times.end()) {
1080           valid = false;
1081           break;
1082         }
1083         if (it->second <= peak_time) {
1084           continue;
1085         }
1086 
1087         if (skip_list->find(input.node->name()) != skip_list->end()) {
1088           valid = false;
1089           break;
1090         }
1091         string input_name =
1092             strings::StrCat(input.node->name(), ":", input.port_id);
1093         if (skip_list->find(input_name) != skip_list->end()) {
1094           valid = false;
1095           break;
1096         }
1097         if (!IsSwappable(input)) {
1098           valid = false;
1099           break;
1100         }
1101 
1102         // Set earliest use time that's after peak.
1103         mem_info.uses_left.emplace_back(input);
1104         earliest_use = std::min(earliest_use, it->second);
1105       }
1106       if (valid && !mem_info.uses_left.empty()) {
1107         // Compute the fitness: we need the tensor to be generated well before
1108         // the time of peak memory usage (to ensure there is enough time to swap
1109         // it out). We also need it to be used well after the peak time, to
1110         // ensure that swapping the tensor back in won't recreate the memory
1111         // bottleneck. Last but not least, we want the tensor to have as few
1112         // remaining uses as possible.
1113         //
1114         // Note that we must perform the arithmetic inexactly as "double", since
1115         // the values do not fit into any integral type.
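        // Concretely, with t_use = earliest post-peak use, t_alloc = allocation
        // time, t_peak = peak time and u = number of remaining uses, the score
        // computed below is
        //   fitness = -( (t_use - t_peak)^2 / u^2 + (t_alloc - t_peak)^2 )
        // so that sorting in increasing order puts the best candidates first.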
1116         mem_info.fitness =
1117             MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
1118                 MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
1119             MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
1120         mem_info.fitness = -mem_info.fitness;
1121         mem_state.push_back(mem_info);
1122       }
1123     }
1124 
1125     // Sort by fitness
1126     std::sort(mem_state.begin(), mem_state.end());
1127 
1128     for (const MemInfo& mem_info : mem_state) {
1129       for (const MutableGraphView::InputPort fanout_to_swap :
1130            mem_info.uses_left) {
1131         VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
1132                 << fanout_to_swap.port_id << " of tensor "
1133                 << mem_info.port.node->name() << ":" << mem_info.port.port_id
1134                 << " of size " << mem_info.memory_used;
1135 
1136         (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
1137             fanout_to_swap.port_id);
1138       }
1139       required_savings -= mem_info.memory_used;
1140       updated_graph = true;
1141       if (required_savings < 0) {
1142         break;
1143       }
1144     }
1145   }
1146   return updated_graph;
1147 }
1148 
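// Inserts swap-out / swap-in node pairs for tensors selected by the heuristics
// above or annotated with the "_swap_to_host" attribute, adding control
// dependencies so that data is copied to the host soon after it is produced
// and copied back just before it is needed again.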
1149 bool SwappingPass(RewriterConfig::MemOptType optimization_level,
1150                   Cluster* cluster, std::unique_ptr<GraphMemory>* memory,
1151                   GrapplerItem* item, std::unordered_set<string>* skip_list) {
1152   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
1153   if (optimization_level == RewriterConfig::DEFAULT_MEM_OPT ||
1154       optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
1155       optimization_level == RewriterConfig::HEURISTICS) {
1156     // Use heuristics to figure out what needs to be swapped.
1157     IdentifySwappingCandidates(cluster, item, memory, skip_list,
1158                                &nodes_to_swap);
1159   }
1160   // Look for manual annotations in the graph.
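  // Illustrative example: an annotation such as
  //   (*node->mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
  // requests that input 0 of the node be swapped out to host memory; a plain
  // integer value (val.i()) selects a single input instead of a list.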
1161   for (auto& node : *item->graph.mutable_node()) {
1162     if (node.attr().count("_swap_to_host") != 0) {
1163       SwapInfo& swap_info = nodes_to_swap[&node];
1164       const AttrValue& val = node.attr().at("_swap_to_host");
1165       if (val.has_list()) {
1166         for (int64_t input_id : val.list().i()) {
1167           swap_info.inputs_to_swap.push_back(input_id);
1168         }
1169       } else {
1170         int64_t input_id = val.i();
1171         swap_info.inputs_to_swap.push_back(input_id);
1172       }
1173     }
1174   }
1175   if (nodes_to_swap.empty()) {
1176     // Nothing to do.
1177     return false;
1178   }
1179 
1180   // Estimate the size of the data to swap for each node.
1181   GraphProperties properties(*item);
1182   if (!properties
1183            .InferStatically(/*assume_valid_feeds=*/true,
1184                             /*aggressive_shape_inference=*/false,
1185                             /*include_tensor_values=*/false)
1186            .ok()) {
1187     return false;
1188   }
1189   for (auto& swap : nodes_to_swap) {
1190     const NodeDef* node = swap.first;
1191     const std::vector<OpInfo::TensorProperties>& props =
1192         properties.GetInputProperties(node->name());
1193     SwapInfo& swap_info = swap.second;
1194     int64_t bytes_to_swap = 0;
1195     for (int64_t input_id : swap_info.inputs_to_swap) {
1196       const OpInfo::TensorProperties& t = props[input_id];
1197       bytes_to_swap += CalculateTensorSize(t);
1198     }
1199     // Let's assume we're going to swap over PCIe running at 16 GBps.
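    // 16 GB/s is 16 bytes per nanosecond, so dividing bytes by 16 yields the
    // transfer time in NanoSeconds.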
1200     swap_info.time_to_swap = bytes_to_swap / 16;
1201   }
1202 
1203   std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
1204   if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
1205     return false;
1206   }
1207 
1208   std::unordered_map<string, const NodeDef*> name_map;
1209   for (const auto& node : item->graph.node()) {
1210     name_map[node.name()] = &node;
1211   }
1212   MutableGraphView view(&item->graph);
1213 
1214   bool updated_graph = false;
1215 
1216   for (auto& swap : nodes_to_swap) {
1217     NodeDef* node = swap.first;
1218     const SwapInfo& swap_info = swap.second;
1219     if (skip_list->find(node->name()) != skip_list->end()) {
1220       continue;
1221     }
1222 
1223     // Make sure the tensor isn't swapped back in right away: look for a node that
1224     // will execute just before we need to swap the data back, and add a control
1225     // dependency from that node to the swap node.
1226     const NodeDef* in_trigger =
1227         FindSwapInTrigger(node, swap_info, name_map, execution_times);
1228     // If we failed, don't attempt to reprocess this node in a subsequent pass.
1229     if (!in_trigger) {
1230       skip_list->insert(node->name());
1231       continue;
1232     }
1233 
1234     // Swap all the tensors that are marked with the '_swap_to_host' attribute.
1235     for (int input_id : swap_info.inputs_to_swap) {
1236       string input_name = strings::StrCat(node->name(), ":", input_id);
1237       if (skip_list->find(input_name) != skip_list->end()) {
1238         continue;
1239       } else {
1240         // Don't attempt to reprocess this input in a subsequent pass.
1241         skip_list->insert(input_name);
1242       }
1243 
1244       // Make sure the tensor is swapped out quickly: look for a node that
1245       // will execute just after the tensor is generated and add a control
1246       // dependency from the swap-out node to that node.
1247       NodeDef* out_trigger =
1248           FindSwapOutTrigger(node, input_id, view, execution_times);
1249       if (!out_trigger) {
1250         continue;
1251       }
1252 
1253       std::pair<NodeDef*, NodeDef*> swap_nodes;
1254       if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
1255                .ok()) {
1256         continue;
1257       }
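           // Rewire the graph: the swap-out node (swap_nodes.first) reads the
           // original input tensor, and the consumer now reads that tensor
           // back from the swap-in node (swap_nodes.second).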
1258       *swap_nodes.first->add_input() = node->input(input_id);
1259       *node->mutable_input(input_id) = swap_nodes.second->name();
1260 
1261       // Add the control dependencies: force the swap-out to complete before
           // the out trigger runs, and delay the swap-in until the in trigger has
           // executed.
1262       out_trigger->add_input(strings::StrCat("^", swap_nodes.first->name()));
1263       swap_nodes.second->add_input(strings::StrCat("^", in_trigger->name()));
1264 
1265       // Make sure we won't try to swap the swap nodes in subsequent passes.
1266       skip_list->insert(swap_nodes.first->name());
1267       skip_list->insert(swap_nodes.second->name());
1268     }
1269   }
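       // Note: updated_graph is initialized to false above and never set to
       // true in this function, so SwappingPass() currently always reports
       // that it left the graph unchanged, even after inserting swap nodes.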
1270   return updated_graph;
1271 }
1272 
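     // Returns true if the two nodes are placed on different tasks, or if one
     // of them runs on a CPU device and the other on a GPU device.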
1273 bool CrossesTaskOrCpuGpuBoundary(const NodeDef& node1, const NodeDef& node2) {
1274   string task1;
1275   string device1;
1276   DeviceNameUtils::SplitDeviceName(node1.device(), &task1, &device1);
1277   string task2;
1278   string device2;
1279   DeviceNameUtils::SplitDeviceName(node2.device(), &task2, &device2);
1280   return task1 != task2 ||
1281          (absl::StrContains(device1, DEVICE_CPU) &&
1282           absl::StrContains(device2, DEVICE_GPU)) ||
1283          (absl::StrContains(device1, DEVICE_GPU) &&
1284           absl::StrContains(device2, DEVICE_CPU));
1285 }
1286 
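     // Sets the "_grappler_relax_allocator_constraints" attribute on each node
     // in nodes_to_relax, given as indices into optimized_graph.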
1287 void RelaxAssignNodes(const std::set<int>& nodes_to_relax,
1288                       GraphDef* optimized_graph) {
1289   for (int idx : nodes_to_relax) {
1290     // Set an attribute telling AssignOp to ignore allocator constraints.
1291     NodeDef* assign_node = optimized_graph->mutable_node(idx);
1292     (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
1293         .set_b(true);
1294   }
1295 }
1296 
1297 // TODO(rmlarsen): Add distributed TF test.
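     // Collects into nodes_to_relax the indices of Assign nodes whose
     // transitive fanout never reaches a Send node and never crosses a task or
     // CPU/GPU boundary, so their allocator constraints can be relaxed.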
1298 Status FindAssignNodesToRelax(const GraphDef& graph,
1299                               std::set<int>* nodes_to_relax) {
1300   std::unordered_set<string> devices;
1301   std::vector<int> assign_nodes;
1302   bool found_send = false;
1303   for (int i = 0; i < graph.node_size(); ++i) {
1304     const NodeDef& node = graph.node(i);
1305     devices.insert(node.device());
1306     if (IsAssign(node)) {
1307       assign_nodes.push_back(i);
1308     }
1309     if (IsSend(node)) {
1310       found_send = true;
1311       break;
1312     }
1313   }
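       // Fast path: if the graph contains no Send nodes and runs on a single
       // device, every Assign node can be relaxed without further analysis.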
1314   if (!found_send && devices.size() == 1) {
1315     nodes_to_relax->insert(assign_nodes.begin(), assign_nodes.end());
1316     return OkStatus();
1317   }
1318 
1319   GraphTopologyView graph_view;
1320   TF_RETURN_IF_ERROR(
1321       graph_view.InitializeFromGraph(graph, /*ignore_control_edges=*/true));
1322   std::unordered_set<const NodeDef*> optimized_nodes;
1323 
1324   for (int i : assign_nodes) {
1325     const NodeDef& assign_node = graph.node(i);
1326 
1327     if (optimized_nodes.find(&assign_node) == optimized_nodes.end()) {
1328       std::vector<const NodeDef*> assign_nodes_in_fanout;
1329       optimized_nodes.insert(&assign_node);
1330       assign_nodes_in_fanout.push_back(&assign_node);
1331 
1332       std::vector<const NodeDef*> transitive_fanout;
1333       // Find the nodes in the transitive fanout. If a node is known to never
1334       // forward its inputs, we can skip its fanout.
1335       DfsTraversal(graph_view, {graph_view.GetNode(i)},
1336                    TraversalDirection::kFollowOutputs,
1337                    DfsPredicates::Advance([&](const NodeDef* node) {
1338                      return !NeverForwardsInputs(*node);
1339                    }),
1340                    DfsCallbacks::PreOrder([&](const NodeDef* node) {
1341                      transitive_fanout.push_back(node);
1342                    }));
1343 
1344       bool relax_constraint = true;
1345       // If all nodes in the transitive fanout are on the same device as the
1346       // assign node, there is no need to allocate the output in pinned memory.
1347       for (const NodeDef* fanout_node : transitive_fanout) {
1348         if (relax_constraint &&
1349             (IsSend(*fanout_node) ||
1350              CrossesTaskOrCpuGpuBoundary(*fanout_node, assign_node))) {
1351           relax_constraint = false;
1352           break;
1353         }
1354         if (optimized_nodes.find(fanout_node) == optimized_nodes.end() &&
1355             IsAssign(*fanout_node)) {
1356           assign_nodes_in_fanout.push_back(fanout_node);
1357         }
1358       }
1359 
1360       if (relax_constraint) {
1361         for (const NodeDef* assign_node_in_fanout : assign_nodes_in_fanout) {
1362           // If all devices in the fanout of node(i) match then, by
1363           // transitivity, they must also match in the fanout of any other
1364           // assign node within that fanout, so we can process those assign
1365           // nodes here and avoid recomputing their transitive fanout later.
1366           optimized_nodes.insert(assign_node_in_fanout);
1367 
1368           // Record this node; RelaxAssignNodes() will set the attribute that
               // tells AssignOp to ignore allocator constraints.
1369           const absl::optional<int> assign_node_idx =
1370               graph_view.GetNodeIndex(*assign_node_in_fanout);
1371           nodes_to_relax->insert(assign_node_idx.value());
1372         }
1373       }
1374     }
1375   }
1376   return OkStatus();
1377 }
1378 
1379 }  // namespace
1380 
1381 Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
1382                                  GraphDef* optimized_graph) {
1383   std::set<int> nodes_to_relax;
1384   TF_RETURN_IF_ERROR(FindAssignNodesToRelax(item.graph, &nodes_to_relax));
1385 
1386   bool run_recomputation_pass =
1387       (optimization_level_ == RewriterConfig::RECOMPUTATION_HEURISTICS ||
1388        optimization_level_ == RewriterConfig::HEURISTICS ||
1389        optimization_level_ == RewriterConfig::MANUAL);
1390   if (!run_recomputation_pass && nodes_to_relax.empty() && item.fetch.empty()) {
1391     return errors::Aborted("Nothing to do.");
1392   }
1393 
1394   GrapplerItem optimized_item(item);
1395   RelaxAssignNodes(nodes_to_relax, &optimized_item.graph);
1396 
1397   if (run_recomputation_pass) {
1398     RecomputationRewritingPass(optimization_level_,
1399                                recomputation_targets_name_scope_,
1400                                &optimized_item.graph, item);
1401   }
1402 
1403   std::unordered_set<string> skip_list;
1404   // Bound the number of rewrite passes to avoid long processing times on graphs
1405   // that simply won't fit in memory.
1406   // SchedulingPass() and SwappingPass() rely on defined fetches in order to
1407   // infer the memory usage, so skip optimization if there are no fetches.
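       // `memory` caches the inferred per-device memory usage; it is reset
       // whenever a pass mutates the graph so that the next pass re-infers it.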
1408   std::unique_ptr<GraphMemory> memory;
1409   if (!item.fetch.empty() && cluster != nullptr) {
1410     bool updated_graph = true;
1411     for (int i = 0; i < 25 && updated_graph; ++i) {
1412       GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1413       updated_graph = false;
1414       if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
1415            optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
1416            optimization_level_ == RewriterConfig::HEURISTICS) &&
1417           cluster != nullptr) {
1418         if (SchedulingPass(cluster, &memory, &optimized_item)) {
1419           // Reset the inferred memory usage since the graph changed.
1420           memory.reset();
1421           updated_graph = true;
1422         }
1423       }
1424 
1425       GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1426       if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
1427            optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
1428            optimization_level_ == RewriterConfig::HEURISTICS ||
1429            optimization_level_ == RewriterConfig::MANUAL) &&
1430           cluster != nullptr) {
1431         if (SwappingPass(optimization_level_, cluster, &memory, &optimized_item,
1432                          &skip_list)) {
1433           // Reset the inferred memory usage since the graph changed.
1434           memory.reset();
1435           updated_graph = true;
1436         }
1437       }
1438     }
1439   }
1440 
1441   optimized_graph->Swap(&optimized_item.graph);
1442   return OkStatus();
1443 }
1444 
1445 }  // end namespace grappler
1446 }  // end namespace tensorflow
1447