/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"

#include <algorithm>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

// Prefix added to nodes which are recomputed.
const char* kRecomputedNodePrefix = "Recomputed";
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
// Attribute which may be added to nodes to manually allow them to be
// recomputed.
const char* kRecomputeHint = "_recompute_hint";

// Ops which we wouldn't mind recomputing to save memory.
// TODO(allenl): Replace this list with a cost model.
std::unordered_set<string> GetCheapToRecomputeOps() {
  std::unordered_set<string> cheap_ops = {"Add",
                                          "AddN",
                                          "BiasAdd",
                                          "Cast",
                                          "Fill",
                                          "FloorDiv",
                                          "FloorMod",
                                          "FusedBatchNorm",
                                          "LeakyRelu",
                                          "Mul",
                                          "Neg",
                                          "RealDiv",
                                          "Reciprocal",
                                          "Relu",
                                          "Relu6",
                                          "Reshape",
                                          "Rsqrt",
                                          "Sigmoid",
                                          "Sqrt",
                                          "Square",
                                          "SquaredDifference",
                                          "Sub",
                                          "Tile",
                                          "Transpose"};
  return cheap_ops;
}

// Find recomputable ops which feed into target nodes.
std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
    const NodeMap& node_map, const GraphDef* graph,
    const std::function<bool(const NodeDef&)>& is_candidate,
    const std::function<bool(const NodeDef&)>& is_target) {
  std::unordered_set<const NodeDef*> candidate_recompute_nodes;
  for (const auto& node : graph->node()) {
    if (!is_candidate(node)) {
      continue;
    }
    bool has_target_output = false;
    for (const NodeDef* output : node_map.GetOutputs(node.name())) {
      // It only makes sense to recompute this if it feeds into a target
      // node. We expand this to dependencies in GetOpGroupsToRecompute.
      if (is_target(*output)) {
        has_target_output = true;
        break;
      }
    }
    if (!has_target_output) {
      continue;
    }
    bool has_target_input = false;
    for (const string& input_name : node.input()) {
      // Don't recompute nodes which depend on target nodes.
      const NodeDef* input_node = node_map.GetNode(input_name);
      if (is_target(*input_node)) {
        has_target_input = true;
        break;
      }
    }
    if (has_target_input) {
      continue;
    }
    candidate_recompute_nodes.insert(&node);
  }
  return candidate_recompute_nodes;
}

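// Breadth-first expansion of the set in `expanded_nodes`: starting from the
// nodes already in the set, repeatedly adds fanin (if `collect_inputs`) and/or
// fanout (if `collect_outputs`) nodes satisfying `is_candidate`, leaving the
// connected candidate subgraph in `expanded_nodes`.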
void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
                        bool collect_outputs,
                        const std::function<bool(const NodeDef&)>& is_candidate,
                        std::unordered_set<const NodeDef*>* expanded_nodes) {
  std::queue<const NodeDef*> to_visit;
  for (const NodeDef* starting_node : *expanded_nodes) {
    to_visit.push(starting_node);
  }
  expanded_nodes->clear();
  while (!to_visit.empty()) {
    const NodeDef* current_node = to_visit.front();
    to_visit.pop();
    if (!expanded_nodes->insert(current_node).second) {
      // We already visited this node
      continue;
    }
    if (collect_inputs) {
      // Add inputs and outputs to this subgraph if they are candidates
      for (const string& input_name_raw : current_node->input()) {
        const NodeDef* input_node = node_map.GetNode(input_name_raw);
        if (expanded_nodes->count(input_node) == 0 &&
            is_candidate(*input_node)) {
          to_visit.push(input_node);
        }
      }
    }
    if (collect_outputs) {
      for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
        if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
          to_visit.push(output);
        }
      }
    }
  }
}

struct RecomputedSubGraph {
  std::unordered_set<const NodeDef*> recomputed_source_nodes;
  std::unordered_set<NodeDef*> target_nodes;
};

// Find groups of ops to recompute together based on `should_recompute`.
std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
    const GraphDef* graph, const NodeMap& node_map,
    const std::function<bool(const NodeDef&)>& should_recompute,
    const std::function<bool(const NodeDef&)>& is_target) {
  std::unordered_set<const NodeDef*> visited_nodes;
  std::vector<RecomputedSubGraph> subgraphs_to_recompute;
  std::unordered_set<const NodeDef*> candidate_recompute_nodes =
      FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
  for (const NodeDef* recompute_node : candidate_recompute_nodes) {
    if (visited_nodes.count(recompute_node) > 0) {
      continue;
    }
    RecomputedSubGraph current_recomputation;
    // Build out recomputation groups by expanding to inexpensive-to-recompute
    // nodes which do not feed target nodes. The goal is to capture some
    // intermediate activations within this graph.
    std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
    unpruned_recompute_nodes.insert(recompute_node);
    connected_subgraph(node_map,
                       true,  // Collect inputs
                       true,  // Collect outputs
                       should_recompute, &unpruned_recompute_nodes);
    visited_nodes.insert(unpruned_recompute_nodes.begin(),
                         unpruned_recompute_nodes.end());
    for (const NodeDef* unpruned_recompute_node : unpruned_recompute_nodes) {
      bool inserted_feed = false;
      for (NodeDef* output :
           node_map.GetOutputs(unpruned_recompute_node->name())) {
        if (is_target(*output)) {
          current_recomputation.target_nodes.insert(output);
          if (!inserted_feed) {
            // Keep track of nodes which feed directly into a target node. These
            // and nodes which feed into them will define the recomputed
            // subgraph.
            current_recomputation.recomputed_source_nodes.insert(
                unpruned_recompute_node);
            inserted_feed = true;
          }
        }
      }
    }
    // Recompute only nodes which eventually feed into a target node.
    connected_subgraph(
        node_map,
        true,   // Collect inputs
        false,  // Collect outputs
        [&unpruned_recompute_nodes](const NodeDef& node) {
          return unpruned_recompute_nodes.count(&node) != 0;
        },
        &current_recomputation.recomputed_source_nodes);
    if (current_recomputation.target_nodes.empty()) {
      continue;
    }
    subgraphs_to_recompute.push_back(current_recomputation);
  }
  return subgraphs_to_recompute;
}


// Computes the maximum topological numbers of (1) target node components
// (gradient nodes being fed by the recomputation), and (2) child recompute node
// components for each recomputed node. We will not attach any control
// dependencies to a recomputation unless they have component numbers greater
// than this value (to prevent cycles).
std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components) {
  std::unordered_map<const NodeDef*, int> recomputed_node_components;
  // Start by setting component numbers to the maximum among target nodes.
  for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
    int max_target_component = -1;
    for (NodeDef* output :
         node_map.GetOutputs(original_recompute_node->name())) {
      if (target_nodes.count(output) != 0) {
        int current_target_component = components.find(output)->second;
        if (current_target_component > max_target_component) {
          max_target_component = current_target_component;
        }
      }
    }
    if (max_target_component > -1) {
      recomputed_node_components[original_recompute_node] =
          max_target_component;
    }
  }
  // Sort recomputed nodes topologically (based on the original graph) so we can
  // efficiently assign to each node the maximum of its recomputed child
  // components and its own targets.
  std::vector<const NodeDef*> recomputed_source_nodes_topological(
      recomputed_source_nodes.begin(), recomputed_source_nodes.end());
  std::sort(recomputed_source_nodes_topological.begin(),
            recomputed_source_nodes_topological.end(),
            [&components](const NodeDef* first, const NodeDef* second) {
              return components.find(first)->second <
                     components.find(second)->second;
            });
  for (const NodeDef* original_recompute_node :
       recomputed_source_nodes_topological) {
    int max_component;
    auto recomputed_component_iterator =
        recomputed_node_components.find(original_recompute_node);
    if (recomputed_component_iterator != recomputed_node_components.end()) {
      max_component = recomputed_component_iterator->second;
    } else {
      max_component = -1;
    }
    for (NodeDef* output :
         node_map.GetOutputs(original_recompute_node->name())) {
      if (recomputed_source_nodes.count(output) == 0) {
        continue;
      }
      auto child_component_iterator = recomputed_node_components.find(output);
      CHECK(child_component_iterator != recomputed_node_components.end());
      int child_component = child_component_iterator->second;
      if (child_component > max_component) {
        max_component = child_component;
      }
    }
    CHECK_GE(max_component, 0);
    recomputed_node_components[original_recompute_node] = max_component;
  }
  return recomputed_node_components;
}

// Modifies `graph`, adding trigger nodes and returning a mapping from
// `recomputed_source_nodes` to trigger nodes which will not create loops in the
// graph (using the component numberings in `components` and
// `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
// recomputed_source_nodes, which are the originals) eventually get these
// control dependencies.
std::unordered_map<const NodeDef*, const NodeDef*>
AddRecomputeControlDependencyNodes(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components,
    const std::unordered_map<const NodeDef*, int>&
        recomputed_node_max_feed_components,
    GraphDef* graph) {
  // Sort recomputed nodes based on max downstream components.
  std::vector<const NodeDef*> recomputed_source_nodes_topological(
      recomputed_source_nodes.begin(), recomputed_source_nodes.end());
  std::sort(recomputed_source_nodes_topological.begin(),
            recomputed_source_nodes_topological.end(),
            [&recomputed_node_max_feed_components](const NodeDef* first,
                                                   const NodeDef* second) {
              int first_component =
                  recomputed_node_max_feed_components.find(first)->second;
              int second_component =
                  recomputed_node_max_feed_components.find(second)->second;
              return first_component > second_component
                     // Ensure a consistent ordering. This is necessary because
                     // we're working not with node component numbers (which are
                     // unique) but with the maximum across nodes they feed into
                     // (very much not unique).
                     || (first_component == second_component &&
                         first->name() > second->name());
            });
  // Create merged control dependency nodes by sorting target inputs
  // topologically and zipper merging with the sorted recomputed nodes.
  std::vector<const NodeDef*> target_inputs_topological;
  for (const NodeDef* target_node : target_nodes) {
    for (const string& target_input_name_raw : target_node->input()) {
      const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
      // If this node has already had one of its inputs recomputed during this
      // rewriting pass, we ignore that recomputed node here (it will not be in
      // the NodeMap).
      if (target_input == nullptr ||
          recomputed_source_nodes.count(target_input) != 0 ||
          components.find(target_node)->second ==
              components.find(target_input)->second) {
        continue;
      }
      target_inputs_topological.push_back(target_input);
    }
  }
  std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
            [&components](const NodeDef* first, const NodeDef* second) {
              return components.find(first)->second >
                     components.find(second)->second;
            });
  auto target_input_iterator = target_inputs_topological.begin();
  NodeDef* current_trigger_node = nullptr;
  std::unordered_map<const NodeDef*, const NodeDef*> triggers;
  for (const NodeDef* original_recomputed_node :
       recomputed_source_nodes_topological) {
    NodeDef* new_trigger_node = graph->add_node();
    new_trigger_node->set_name(AddPrefixToNodeName(
        original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
    new_trigger_node->set_op("NoOp");
    new_trigger_node->set_device(original_recomputed_node->device());
    if (current_trigger_node != nullptr) {
      *new_trigger_node->add_input() =
          strings::StrCat("^", current_trigger_node->name());
    }
    current_trigger_node = new_trigger_node;
    triggers[original_recomputed_node] = current_trigger_node;
    for (;
         target_input_iterator != target_inputs_topological.end() &&
         components.find(*target_input_iterator)->second >
             recomputed_node_max_feed_components.find(original_recomputed_node)
                 ->second;
         ++target_input_iterator) {
      *current_trigger_node->add_input() =
          strings::StrCat("^", (*target_input_iterator)->name());
      VLOG(2) << " Recomputation trigger " << current_trigger_node->name()
              << " depends on " << (*target_input_iterator)->name();
    }
  }
  return triggers;
}

string RecomputedOrOriginalNodeName(
    const std::unordered_set<string>& recomputed_node_names,
    const string& original_node_name) {
  if (recomputed_node_names.find(original_node_name) ==
      recomputed_node_names.end()) {
    return original_node_name;
  } else {
    return AddPrefixToNodeName(original_node_name, kRecomputedNodePrefix);
  }
}

// Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
// from recomputed_source_nodes to target_nodes are changed to start from the
// recomputed nodes.
void RecomputeSubgraph(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components,
    GraphDef* graph) {
  std::unordered_set<string> recomputed_node_names;
  VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
          << " node subgraph";
  std::unordered_map<const NodeDef*, int> recomputed_node_components =
      GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
                                 node_map, components);
  for (const NodeDef* original_node : recomputed_source_nodes) {
    VLOG(2) << " " << original_node->name();
    recomputed_node_names.insert(original_node->name());
  }
  std::unordered_map<const NodeDef*, const NodeDef*> triggers =
      AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
                                         node_map, components,
                                         recomputed_node_components, graph);
  // Create the recomputed sub-graph
  for (const NodeDef* original_node : recomputed_source_nodes) {
    NodeDef* copied_node = graph->add_node();
    copied_node->set_name(
        AddPrefixToNodeName(original_node->name(), kRecomputedNodePrefix));
    copied_node->set_op(original_node->op());
    *copied_node->mutable_attr() = original_node->attr();
    copied_node->set_device(original_node->device());
    for (const string& original_input_name : original_node->input()) {
      // Set inputs which are internal to the copied subgraph to their copied
      // versions.
      *copied_node->add_input() = RecomputedOrOriginalNodeName(
          recomputed_node_names, original_input_name);
    }
    // Each recomputed node gets a control dependency to prevent it from being
    // recomputed immediately.
    *copied_node->add_input() =
        strings::StrCat("^", triggers[original_node]->name());
  }
  // Set the inputs of nodes in the target subgraph to the recomputed nodes
  // where applicable.
  for (NodeDef* target_node : target_nodes) {
    for (string& target_input_name : *target_node->mutable_input()) {
      target_input_name = RecomputedOrOriginalNodeName(recomputed_node_names,
                                                       target_input_name);
    }
  }
}

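// Recomputation pass: depending on `optimization_level`, selects subgraphs of
// cheap-to-recompute ops (and/or ops annotated with `_recompute_hint`) that
// feed nodes matching `recomputation_targets_name_scope`, duplicates them, and
// rewires the targets to read from the copies so the original activations do
// not have to stay resident until the gradient computation consumes them.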
void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
                                const string& recomputation_targets_name_scope,
                                GraphDef* graph, const GrapplerItem& item) {
  // The topological numberings and NodeMap will be stale as soon as we start
  // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
  // looks up nodes which were in the original graph, and preserves the graph
  // topology it's interested in.
  // We don't use the results of this topological sort until later, but this
  // call invalidates all NodeDef pointers, so it needs to be done before we
  // start collecting those.
  TF_CHECK_OK(TopologicalSort(graph));
  NodeMap node_map(graph);
  std::vector<RecomputedSubGraph> recomputed_subgraphs;
  // Do not recompute nodes which are fed, since the recomputed node would not
  // take on the fed value (i.e. gradients would be incorrect).
  std::unordered_set<string> feeds;
  for (const auto& feed : item.feed) {
    feeds.insert(NodeName(feed.first));
  }
  std::function<bool(const NodeDef&)> is_target =
      [&recomputation_targets_name_scope](const NodeDef& node) {
        // Nodes whose inputs we may want to recompute. This matches node names
        // that contain recomputation_targets_name_scope as a name scope,
        // meaning the name either begins with or contains the name scope.
        // Defaults to "gradients/", which matches any node name that begins
        // with "gradients/" or contains "/gradients/".
        return absl::StartsWith(node.name(),
                                recomputation_targets_name_scope) ||
               static_cast<int>(node.name().find(
                   "/" + recomputation_targets_name_scope)) != -1;
      };

  if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
      optimization_level == RewriterConfig::HEURISTICS) {
    // TODO(allenl): Handle ResNet-like architectures better. Right now all of
    // the cheap forward ops get grouped into a single subgraph which must
    // execute before gradients start executing (unless layers are manually
    // separated by identity ops).
    std::unordered_set<string> cheap_to_recompute_ops =
        GetCheapToRecomputeOps();
    recomputed_subgraphs = GetOpGroupsToRecompute(
        graph, node_map,
        [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
          return !is_target(node) && feeds.count(node.name()) == 0 &&
                 (cheap_to_recompute_ops.count(node.op()) > 0 ||
                  node.attr().count(kRecomputeHint) > 0);
        },
        is_target);
  } else if (optimization_level == RewriterConfig::MANUAL) {
    recomputed_subgraphs = GetOpGroupsToRecompute(
        graph, node_map,
        [&feeds, &is_target](const NodeDef& node) {
          return !is_target(node) && feeds.count(node.name()) == 0 &&
                 node.attr().count(kRecomputeHint) > 0;
        },
        is_target);
  }
  if (!recomputed_subgraphs.empty()) {
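    // Number nodes in reverse topological order: after TopologicalSort, the
    // first node receives the largest number, so producers always have larger
    // numbers than their consumers. These numbers are the `components` used by
    // RecomputeSubgraph to order trigger control dependencies without
    // introducing cycles.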
    std::unordered_map<const NodeDef*, int> topological_numbering;
    for (int node_number = 0; node_number < graph->node().size();
         ++node_number) {
      topological_numbering[graph->mutable_node(node_number)] =
          graph->node().size() - node_number - 1;
    }
    // Duplicate the indicated sub-graphs and set up control dependencies
    for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
      RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
                        node_map, topological_numbering, graph);
    }
  }
}

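// Scheduling pass: on devices whose estimated peak memory usage exceeds 80% of
// the available memory, rewrites large AddN/AccumulateNV2 nodes into an
// incremental in-place accumulation (see the rewrite loop below). Returns true
// if the graph was modified.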
bool SchedulingPass(Cluster* cluster, std::unique_ptr<GraphMemory>* memory_ptr,
                    GrapplerItem* item) {
  // Look for AddN nodes (and equivalent) and record input names.
  MutableGraphView view(&item->graph);

  std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
  for (NodeDef& node : *item->graph.mutable_node()) {
    if (!IsAddN(node) && node.op() != "AccumulateNV2") {
      continue;
    }
    // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
    if (view.NumFanins(node, false) <= 2) {
      continue;
    }
    for (const auto& input : view.GetFanins(node, false)) {
      if (input.node->device() == node.device()) {
        string tensor_name =
            strings::StrCat(input.node->name(), ":", input.port_id);
        addn_list[tensor_name].insert(&node);
      }
    }
  }

  if (addn_list.empty()) {
    return false;
  }

  if ((*memory_ptr) == nullptr) {
    memory_ptr->reset(new GraphMemory(*item));
    Status s = (*memory_ptr)->InferStatically(cluster->GetDevices());
    if (!s.ok()) {
      memory_ptr->reset();
      VLOG(1) << "Failed to infer memory usage: " << s.error_message();
      return false;
    }
  }
  const GraphMemory& memory = **memory_ptr;

  std::unordered_set<NodeDef*> addn_to_rewrite;
  for (const auto& device : cluster->GetDevices()) {
    const string& name = device.first;
    const DeviceProperties& prop = device.second;
    if (prop.memory_size() <= 0) {
      VLOG(1) << "Available memory unknown for device " << name;
      continue;
    }
    const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);

    if (mem_usage.used_memory <= prop.memory_size() * 0.8) {
      continue;
    }

    for (const auto& live : mem_usage.live_tensors) {
      string tensor_name = strings::StrCat(live.node, ":", live.output_id);
      auto it = addn_list.find(tensor_name);
      if (it != addn_list.end()) {
        addn_to_rewrite.insert(it->second.begin(), it->second.end());
      }
    }
  }

  if (addn_to_rewrite.empty()) {
    return false;
  }
  GraphProperties properties(*item);
  Status s = properties.InferStatically(/*assume_valid_feeds=*/false,
                                        /*aggressive_shape_inference=*/false,
                                        /*include_tensor_values=*/false);
  if (!s.ok()) {
    VLOG(1) << "Failed to infer shapes: " << s.error_message();
    return false;
  }

  // It's ok to use immutable GraphTopologyView here, because we do not destroy
  // any of the nodes in the underlying graph, we only add new nodes.
  GraphTopologyView graph_topology;
  Status initialized_topology = graph_topology.InitializeFromGraph(item->graph);
  if (!initialized_topology.ok()) {
    VLOG(1) << "Failed to initialize graph topology view: "
            << initialized_topology.error_message();
    return false;
  }

  bool updated_graph = false;
  // Rewrite the AddN.
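  // Each node selected above is replaced by an equivalent in-place
  // accumulation: a zero-initialized TemporaryVariable of the output shape,
  // one AssignAdd per data input, and a DestroyTemporaryVariable (reusing the
  // original node) that yields the final sum. This lets each input be
  // accumulated and released as it becomes available, instead of requiring all
  // inputs to be resident at the same time.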
  for (NodeDef* node : addn_to_rewrite) {
    if (!properties.HasOutputProperties(node->name())) {
      VLOG(1) << "Missing properties for " << node->name();
      continue;
    }
    const TensorShapeProto& shape =
        properties.GetOutputProperties(node->name())[0].shape();
    PartialTensorShape shp(shape);
    if (!shp.IsFullyDefined()) {
      VLOG(1) << "Shape not fully known for " << node->name();
      continue;
    }
    DataType dtype = node->attr().at("T").type();
    if (dtype != DT_HALF && dtype != DT_FLOAT && dtype != DT_DOUBLE &&
        dtype != DT_INT64) {  // Only GPU-supported TemporaryVariable types.
      VLOG(1) << "Unsupported dtype for " << node->name();
      continue;
    }

    // Compute a topological ordering for the node fanin.
    std::unordered_map<const NodeDef*, int> topo_order;
    DfsTraversal(graph_topology, {node}, TraversalDirection::kFollowInputs,
                 DfsCallbacks::PostOrder([&topo_order](const NodeDef* n) {
                   int topo_index = static_cast<int>(topo_order.size());
                   topo_order[n] = topo_index;
                 }));

    std::vector<int> input_topo_index;

    for (int i = 0; i < node->input_size(); ++i) {
      const string& input = node->input(i);
      const string node_name = NodeName(input);
      const NodeDef* node = view.GetNode(node_name);
      input_topo_index.push_back(topo_order.at(node));
    }
    int min_input_topo_index = INT_MAX;
    int min_input_id = -1;
    for (int i = 0; i < node->input_size(); ++i) {
      if (IsControlInput(node->input(i))) {
        // control inputs are always last.
        break;
      }
      const int current = input_topo_index[i];
      if (current < min_input_topo_index) {
        min_input_topo_index = current;
        min_input_id = i;
      }
    }
    CHECK_LE(0, min_input_id);
    std::vector<string> pre_ctrl_deps;
    std::vector<string> post_ctrl_deps;
    for (int i = node->input_size() - 1; i >= 0; --i) {
      if (!IsControlInput(node->input(i))) {
        // control inputs are always last.
        break;
      }
      if (input_topo_index[i] < min_input_topo_index) {
        // These control dependencies can be executed before the node.
        pre_ctrl_deps.push_back(node->input(i));
      } else {
        // These control dependencies should be executed after the node.
        post_ctrl_deps.push_back(node->input(i));
      }
    }

    const string& device = node->device();
    const string tmp_var_name = strings::StrCat(node->name(), "/tmp_var");
    if (view.GetNode(tmp_var_name) != nullptr) {
      VLOG(1) << "Temporary variable already exists " << tmp_var_name;
      return false;
    }

    // Create the temporary variable that will hold intermediate results
    NodeDef* tmp_var = item->graph.add_node();
    tmp_var->set_name(tmp_var_name);
    tmp_var->set_op("TemporaryVariable");
    tmp_var->set_device(device);
    (*tmp_var->mutable_attr())["dtype"].set_type(dtype);
    *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
    (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());

    for (const string& ctrl_dep : pre_ctrl_deps) {
      *tmp_var->add_input() = ctrl_dep;
    }
    *tmp_var->add_input() =
        AsControlDependency(NodeName(node->input(min_input_id)));

    // Initialize it to zero
    NodeDef* zeros = item->graph.add_node();
    zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
    zeros->set_op("ZerosLike");
    zeros->set_device(device);
    (*zeros->mutable_attr())["T"].set_type(dtype);
    *zeros->add_input() = node->input(min_input_id);

    NodeDef* initialize = item->graph.add_node();
    initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
    initialize->set_op("Assign");
    initialize->set_device(device);
    (*initialize->mutable_attr())["T"].set_type(dtype);
    (*initialize->mutable_attr())["use_locking"].set_b(false);
    (*initialize->mutable_attr())["validate_shape"].set_b(false);
    *initialize->add_input() = tmp_var->name();
    *initialize->add_input() = zeros->name();

    // Add the AssignAdd nodes
    std::vector<NodeDef*> accumulates;
    for (int i = 0; i < node->input_size(); ++i) {
      const string& input = node->input(i);
      if (!IsControlInput(input)) {
        NodeDef* accumulate = item->graph.add_node();
        accumulate->set_name(
            strings::StrCat(node->name(), "/tmp_var_accum_", i));
        accumulate->set_op("AssignAdd");
        accumulate->set_device(device);
        (*accumulate->mutable_attr())["T"].set_type(dtype);
        (*accumulate->mutable_attr())["use_locking"].set_b(true);
        *accumulate->add_input() = initialize->name();
        *accumulate->add_input() = input;
        accumulates.push_back(accumulate);
      }
    }

    // Rewrite the AddN node as a DestroyTemporaryVariable op
    node->set_op("DestroyTemporaryVariable");
    node->clear_input();
    EraseRegularNodeAttributes(node);
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["var_name"].set_s(tmp_var->name());
    *node->add_input() = initialize->name();
    for (const NodeDef* accum : accumulates) {
      *node->add_input() = AsControlDependency(accum->name());
    }
    for (const string& ctrl_dep : post_ctrl_deps) {
      *node->add_input() = ctrl_dep;
    }

    updated_graph = true;
  }

  return updated_graph;
}

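// Creates a _CopyFromGpuToHost/_CopyFromHostToGpu pair for input
// `input_to_swap` of `node`, colocated with `node`. The caller is responsible
// for rewiring the input through the pair and for adding the control
// dependencies that schedule the copies.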
Status BuildSwapPair(NodeDef* node, int input_to_swap,
                     const std::unordered_map<string, const NodeDef*>& name_map,
                     GraphDef* graph,
                     std::pair<NodeDef*, NodeDef*>* swap_pair) {
  string task, device;
  if (!DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) ||
      !absl::StrContains(device, DEVICE_GPU)) {
    return errors::InvalidArgument("Can't swap input ", input_to_swap,
                                   " of node ", node->name(),
                                   " since it is not on GPU");
  }
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
  DataType input_type;
  TF_RETURN_IF_ERROR(
      InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
  if (IsRefType(input_type)) {
    return errors::InvalidArgument("Can't swap input ", input_to_swap,
                                   " of node ", node->name(),
                                   " since it expects a reference");
  }

  string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
  string swap_out_name = strings::StrCat("swap_out_", tensor_to_swap);
  string swap_in_name = strings::StrCat("swap_in_", tensor_to_swap);
  if (name_map.find(swap_out_name) != name_map.end() ||
      name_map.find(swap_in_name) != name_map.end()) {
    return errors::InvalidArgument("Input ", input_to_swap, " of node ",
                                   node->name(), " is already swapped");
  }

  // Force the tensor to be copied to cpu.
  NodeDef* swap_out_node = graph->add_node();
  swap_out_node->set_name(swap_out_name);
  swap_out_node->set_op("_CopyFromGpuToHost");

  // Force the tensor to be restored to the device.
  NodeDef* swap_in_node = graph->add_node();
  swap_in_node->set_name(swap_in_name);
  swap_in_node->set_op("_CopyFromHostToGpu");
  *swap_in_node->add_input() = swap_out_node->name();

  // Colocate the swap_out_ and swap_in_ nodes with the node itself.
  swap_out_node->set_device(node->device());
  swap_in_node->set_device(node->device());
  string coloc_group = strings::StrCat("loc@", tensor_to_swap);
  (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
  (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
  (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);

  (*swap_in_node->mutable_attr())["T"].set_type(input_type);
  (*swap_out_node->mutable_attr())["T"].set_type(input_type);
  *swap_pair = std::make_pair(swap_out_node, swap_in_node);

  return OkStatus();
}

struct SwapInfo {
  std::vector<int> inputs_to_swap;
  Costs::NanoSeconds time_to_swap = 0;
};

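// Returns the latest-finishing node, taken from the transitive fanin of
// `node`'s non-swapped inputs, that still completes before the swap-in copy
// must start. Using it as the control trigger keeps the data on the host as
// long as possible without delaying `node`.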
static const NodeDef* FindSwapInTrigger(
    const NodeDef* node, const SwapInfo& swap_info,
    const std::unordered_map<string, const NodeDef*>& name_map,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  // max_trigger_time stores the time before which the swap operation needs to
  // be started in order to load the data back onto the accelerator without
  // delaying the downstream computation.
  Costs::NanoSeconds max_trigger_time(0);
  std::set<string> possible_inputs;
  for (int i = 0; i < node->input_size(); ++i) {
    const string input_node_name = NodeName(node->input(i));
    auto it1 = name_map.find(input_node_name);
    if (it1 == name_map.end()) {
      return nullptr;
    }
    const NodeDef* input_node = it1->second;

    auto it2 = execution_times.find(input_node);
    if (it2 == execution_times.end()) {
      return nullptr;
    }
    max_trigger_time = std::max(max_trigger_time, it2->second);
    possible_inputs.insert(input_node_name);
  }

  for (const int i : swap_info.inputs_to_swap) {
    const string input_node_name = NodeName(node->input(i));
    possible_inputs.erase(input_node_name);
  }
  if (possible_inputs.empty()) {
    return nullptr;
  }

  max_trigger_time -= swap_info.time_to_swap;

  std::map<Costs::NanoSeconds, const NodeDef*> candidates;
  std::set<string> already_processed;

  while (!possible_inputs.empty()) {
    const string input_node_name = *possible_inputs.begin();
    possible_inputs.erase(possible_inputs.begin());
    already_processed.insert(input_node_name);
    auto it1 = name_map.find(input_node_name);
    if (it1 == name_map.end()) {
      return nullptr;
    }
    const NodeDef* input_node = it1->second;
    // Don't jump over frames, since adding a control dependency from one frame
    // to the next isn't supported. Don't go through branches, since we don't
    // know whether they'll be executed or not.
    if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
        IsMerge(*input_node)) {
      continue;
    }
    auto it2 = execution_times.find(input_node);
    if (it2 == execution_times.end()) {
      return nullptr;
    }
    if (it2->second < max_trigger_time) {
      candidates[it2->second] = input_node;
    } else {
      for (const string& fanin : input_node->input()) {
        string name = NodeName(fanin);
        if (already_processed.find(name) == already_processed.end()) {
          possible_inputs.insert(name);
        }
      }
    }
  }

  // Select the candidate that will execute last, since we want to swap the data
  // back at the last minute while still allowing enough time for the data to be
  // swapped back in before the downstream nodes need it.
  if (!candidates.empty()) {
    return candidates.rbegin()->second;
  }
  return nullptr;
}

static bool IsSwappable(const MutableGraphView& graph,
                        MutableGraphView::OutputPort output) {
  const NodeDef& node = *output.node;
  // There is no point in swapping out persistent tensors, since the tensor will
  // continue to use memory.
  if (IsPersistent(node)) {
    return false;
  }

  const OpDef* op_def;
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    return false;
  }
  DataType dtype;
  if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
    return false;
  }
  // References can only refer to persistent memory: therefore the node isn't
  // swappable.
  if (IsRefType(dtype)) {
    return false;
  }

  if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
    // If placed on the same device, these nodes are just forwarding references
    // to their input. Therefore they are swappable iff their fanin is swappable
    // or it resides on a different device.
    MutableGraphView::InputPort input;
    input.node = output.node;
    input.port_id = 0;
    MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
    if (fanin.node->device() == node.device()) {
      return IsSwappable(graph, fanin);
    }
  }
  return true;
}

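// Returns the earliest-executing consumer (other than `node` itself) of the
// tensor feeding input `input_id` of `node`. Making this consumer depend on
// the swap-out node forces the copy to host to be scheduled no later than the
// tensor's first other use.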
static NodeDef* FindSwapOutTrigger(
    const NodeDef* node, int input_id, const MutableGraphView& view,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  // Find the output port that generated the tensor to swap.
  MutableGraphView::InputPort swap;
  swap.node = const_cast<NodeDef*>(node);
  swap.port_id = input_id;
  MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
  if (!generator.node) {
    return nullptr;
  }

  const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
      view.GetFanout(generator);
  NodeDef* trigger = nullptr;
  Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());

  for (const auto& port : fanout) {
    if (port.node == node) {
      continue;
    }
    auto it = execution_times.find(port.node);
    if (it != execution_times.end() && it->second < earliest_fanout) {
      earliest_fanout = it->second;
      trigger = port.node;
    }
  }

  return trigger;
}

static bool IsSwappable(MutableGraphView::InputPort input) {
  const NodeDef& node = *input.node;

  const OpDef* op_def;
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    return false;
  }

  DataType dtype;
  if (!InputTypeForNode(node, *op_def, input.port_id, &dtype).ok()) {
    return false;
  }

  return !IsRefType(dtype);
}

struct MemInfo {
  MutableGraphView::OutputPort port;
  int64_t memory_used;
  std::vector<MutableGraphView::InputPort> uses_left;
  double fitness;

  bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
};

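// For each GPU whose estimated peak memory usage exceeds its capacity, ranks
// the tensors live at the peak by the fitness score computed below and
// greedily marks their post-peak uses for swapping (recorded in
// `nodes_to_swap`) until the estimated savings cover the excess. Returns true
// if any candidates were identified.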
static bool IdentifySwappingCandidates(
    Cluster* cluster, GrapplerItem* item,
    std::unique_ptr<GraphMemory>* memory_ptr,
    std::unordered_set<string>* skip_list,
    std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
  if ((*memory_ptr) == nullptr) {
    memory_ptr->reset(new GraphMemory(*item));
    Status s = (*memory_ptr)->InferStatically(cluster->GetDevices());
    if (!s.ok()) {
      memory_ptr->reset();
      VLOG(1) << "Failed to infer memory usage: " << s.error_message();
      return false;
    }
  }
  const GraphMemory& memory = **memory_ptr;

  bool updated_graph = false;
  for (const auto& device : cluster->GetDevices()) {
    const string& name = device.first;
    const DeviceProperties& prop = device.second;
    if (prop.type() != "GPU") {
      continue;
    }
    if (prop.memory_size() <= 0) {
      VLOG(1) << "Peak memory usage unknown for device " << name;
      continue;
    }
    const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);

    if (mem_usage.used_memory <= prop.memory_size()) {
      continue;
    }
    int64_t required_savings = mem_usage.used_memory - prop.memory_size();

    std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
    {
      VirtualCluster vcluster(cluster->GetDevices());
      if (!vcluster.Provision().ok()) {
        return false;
      }
      if (!vcluster.Initialize(*item).ok()) {
        return false;
      }
      RunMetadata metadata;
      Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
      if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
        return false;
      }

      for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
        for (const auto& node_stats : dev_stats.node_stats()) {
          Costs::NanoSeconds exec_time =
              Costs::NanoSeconds(1) +
              Costs::MicroSeconds(node_stats.all_start_micros() +
                                  node_stats.op_end_rel_micros());
          op_completion_times.emplace(node_stats.node_name(), exec_time);
        }
      }
    }

    Costs::Duration peak_time = -1;
    for (const auto& live_tensor : mem_usage.live_tensors) {
      if (live_tensor.allocation_time > peak_time) {
        peak_time = live_tensor.allocation_time;
      }
    }

    std::vector<MemInfo> mem_state;

    MutableGraphView graph(&item->graph);
    for (const auto& live_tensor : mem_usage.live_tensors) {
      if (live_tensor.memory_used <= 1024) {
        // Don't bother with small tensors.
        continue;
      }
      if (live_tensor.deallocation_time - live_tensor.allocation_time <=
          Costs::Duration(1e6)) {
        // Not enough time to swap.
        VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
        continue;
      }

      if (skip_list->find(live_tensor.node) != skip_list->end()) {
        continue;
      }
      MutableGraphView::OutputPort port =
          graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
      if (!IsSwappable(graph, port)) {
        continue;
      }
      MemInfo mem_info;
      mem_info.port = port;
      mem_info.memory_used = live_tensor.memory_used;
      Costs::Duration allocation_time = live_tensor.allocation_time;
      Costs::Duration earliest_use(Costs::Duration::infinity());
      bool valid = true;
      for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
        // Get execution time.
        auto it = op_completion_times.find(input.node->name());
        if (it == op_completion_times.end()) {
          valid = false;
          break;
        }
        if (it->second <= peak_time) {
          continue;
        }

        if (skip_list->find(input.node->name()) != skip_list->end()) {
          valid = false;
          break;
        }
        string input_name =
            strings::StrCat(input.node->name(), ":", input.port_id);
        if (skip_list->find(input_name) != skip_list->end()) {
          valid = false;
          break;
        }
        if (!IsSwappable(input)) {
          valid = false;
          break;
        }

        // Set earliest use time that's after peak.
        mem_info.uses_left.emplace_back(input);
        earliest_use = std::min(earliest_use, it->second);
      }
      if (valid && !mem_info.uses_left.empty()) {
        // Compute the fitness: we need the tensor to be generated well before
        // the time of peak memory usage (to ensure there is enough time to swap
        // it out). We also need to ensure it's used well after the peak time, to
        // ensure that swapping the tensor back in won't recreate the memory
        // bottleneck. Last but not least, we want the tensor to have as few
        // remaining uses as possible.
        //
        // Note that we must perform the arithmetic inexactly as "double", since
        // the values do not fit into any integral type.
        mem_info.fitness =
            MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
                MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
            MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
        mem_info.fitness = -mem_info.fitness;
        mem_state.push_back(mem_info);
      }
    }

    // Sort by fitness
    std::sort(mem_state.begin(), mem_state.end());

    for (const MemInfo& mem_info : mem_state) {
      for (const MutableGraphView::InputPort fanout_to_swap :
           mem_info.uses_left) {
        VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
                << fanout_to_swap.port_id << " of tensor "
                << mem_info.port.node->name() << ":" << mem_info.port.port_id
                << " of size " << mem_info.memory_used;

        (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
            fanout_to_swap.port_id);
      }
      required_savings -= mem_info.memory_used;
      updated_graph = true;
      if (required_savings < 0) {
        break;
      }
    }
  }
  return updated_graph;
}

bool SwappingPass(RewriterConfig::MemOptType optimization_level,
                  Cluster* cluster, std::unique_ptr<GraphMemory>* memory,
                  GrapplerItem* item, std::unordered_set<string>* skip_list) {
  std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
  if (optimization_level == RewriterConfig::DEFAULT_MEM_OPT ||
      optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
      optimization_level == RewriterConfig::HEURISTICS) {
    // Use heuristics to figure out what needs to be swapped.
    IdentifySwappingCandidates(cluster, item, memory, skip_list,
                               &nodes_to_swap);
  }
  // Look for manual annotations in the graph.
  for (auto& node : *item->graph.mutable_node()) {
    if (node.attr().count("_swap_to_host") != 0) {
      SwapInfo& swap_info = nodes_to_swap[&node];
      const AttrValue& val = node.attr().at("_swap_to_host");
      if (val.has_list()) {
        for (int64_t input_id : val.list().i()) {
          swap_info.inputs_to_swap.push_back(input_id);
        }
      } else {
        int64_t input_id = val.i();
        swap_info.inputs_to_swap.push_back(input_id);
      }
    }
  }
  if (nodes_to_swap.empty()) {
    // Nothing to do.
    return false;
  }

  // Estimate the size of the data to swap for each node.
  GraphProperties properties(*item);
  if (!properties
           .InferStatically(/*assume_valid_feeds=*/true,
                            /*aggressive_shape_inference=*/false,
                            /*include_tensor_values=*/false)
           .ok()) {
    return false;
  }
  for (auto& swap : nodes_to_swap) {
    const NodeDef* node = swap.first;
    const std::vector<OpInfo::TensorProperties>& props =
        properties.GetInputProperties(node->name());
    SwapInfo& swap_info = swap.second;
    int64_t bytes_to_swap = 0;
    for (int64_t input_id : swap_info.inputs_to_swap) {
      const OpInfo::TensorProperties& t = props[input_id];
      bytes_to_swap += CalculateTensorSize(t);
    }
    // Let's assume we're going to swap over PCIe running at 16 GBps.
    swap_info.time_to_swap = bytes_to_swap / 16;
  }

  std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
  if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
    return false;
  }

  std::unordered_map<string, const NodeDef*> name_map;
  for (const auto& node : item->graph.node()) {
    name_map[node.name()] = &node;
  }
  MutableGraphView view(&item->graph);

  bool updated_graph = false;

  for (auto& swap : nodes_to_swap) {
    NodeDef* node = swap.first;
    const SwapInfo& swap_info = swap.second;
    if (skip_list->find(node->name()) != skip_list->end()) {
      continue;
    }

    // Make sure the tensor isn't swapped back in right away: look for a node
    // that will execute just before we need to swap the data back, and add a
    // control dependency from that node to the swap node.
    const NodeDef* in_trigger =
        FindSwapInTrigger(node, swap_info, name_map, execution_times);
    // If we failed, don't attempt to reprocess this node in a subsequent pass.
    if (!in_trigger) {
      skip_list->insert(node->name());
      continue;
    }

    // Swap all the tensors that are marked with the 'swap_to_host' attribute.
    for (int input_id : swap_info.inputs_to_swap) {
      string input_name = strings::StrCat(node->name(), ":", input_id);
      if (skip_list->find(input_name) != skip_list->end()) {
        continue;
      } else {
        // Don't attempt to reprocess this input in a subsequent pass.
        skip_list->insert(input_name);
      }

      // Make sure the tensor is swapped out quickly: look for a node that
      // will execute just after the tensor is generated and add a control
      // dependency from the swap out node to that node.
      NodeDef* out_trigger =
          FindSwapOutTrigger(node, input_id, view, execution_times);
      if (!out_trigger) {
        continue;
      }

      std::pair<NodeDef*, NodeDef*> swap_nodes;
      if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
               .ok()) {
        continue;
      }
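      // Rewire the data path through the swap pair: the swap-out node reads
      // the original input tensor, and `node` now reads it back from the
      // swap-in node.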
      *swap_nodes.first->add_input() = node->input(input_id);
      *node->mutable_input(input_id) = swap_nodes.second->name();

      // Add the control dependencies needed to delay the execution of the swap.
      out_trigger->add_input(strings::StrCat("^", swap_nodes.first->name()));
      swap_nodes.second->add_input(strings::StrCat("^", in_trigger->name()));

      // Make sure we won't try to swap the swap nodes in subsequent passes.
      skip_list->insert(swap_nodes.first->name());
      skip_list->insert(swap_nodes.second->name());
    }
  }
  return updated_graph;
}

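// Returns true if the two nodes are placed on different tasks, or if one is on
// CPU and the other on GPU; in either case data flowing between them may need
// specially allocated (e.g. pinned) memory.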
bool CrossesTaskOrCpuGpuBoundary(const NodeDef& node1, const NodeDef& node2) {
  string task1;
  string device1;
  DeviceNameUtils::SplitDeviceName(node1.device(), &task1, &device1);
  string task2;
  string device2;
  DeviceNameUtils::SplitDeviceName(node2.device(), &task2, &device2);
  return task1 != task2 ||
         (absl::StrContains(device1, DEVICE_CPU) &&
          absl::StrContains(device2, DEVICE_GPU)) ||
         (absl::StrContains(device1, DEVICE_GPU) &&
          absl::StrContains(device2, DEVICE_CPU));
}

void RelaxAssignNodes(const std::set<int>& nodes_to_relax,
                      GraphDef* optimized_graph) {
  for (int idx : nodes_to_relax) {
    // Set an attribute telling AssignOp to ignore allocator constraints.
    NodeDef* assign_node = optimized_graph->mutable_node(idx);
    (*assign_node->mutable_attr())["_grappler_relax_allocator_constraints"]
        .set_b(true);
  }
}

// TODO(rmlarsen): Add distributed TF test.
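// Finds Assign nodes whose outputs never reach a Send op or cross a task or
// CPU/GPU boundary (following only nodes that may forward their inputs). For
// these it is safe to set the attribute that lets AssignOp ignore allocator
// constraints.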
Status FindAssignNodesToRelax(const GraphDef& graph,
                              std::set<int>* nodes_to_relax) {
  std::unordered_set<string> devices;
  std::vector<int> assign_nodes;
  bool found_send = false;
  for (int i = 0; i < graph.node_size(); ++i) {
    const NodeDef& node = graph.node(i);
    devices.insert(node.device());
    if (IsAssign(node)) {
      assign_nodes.push_back(i);
    }
    if (IsSend(node)) {
      found_send = true;
      break;
    }
  }
  if (!found_send && devices.size() == 1) {
    nodes_to_relax->insert(assign_nodes.begin(), assign_nodes.end());
    return OkStatus();
  }

  GraphTopologyView graph_view;
  TF_RETURN_IF_ERROR(
      graph_view.InitializeFromGraph(graph, /*ignore_control_edges=*/true));
  std::unordered_set<const NodeDef*> optimized_nodes;

  for (int i : assign_nodes) {
    const NodeDef& assign_node = graph.node(i);

    if (optimized_nodes.find(&assign_node) == optimized_nodes.end()) {
      std::vector<const NodeDef*> assign_nodes_in_fanout;
      optimized_nodes.insert(&assign_node);
      assign_nodes_in_fanout.push_back(&assign_node);

      std::vector<const NodeDef*> transitive_fanout;
      // Find the nodes in transitive fanout. If a node is known to never
      // forward its inputs, we can skip its fanout.
      DfsTraversal(graph_view, {graph_view.GetNode(i)},
                   TraversalDirection::kFollowOutputs,
                   DfsPredicates::Advance([&](const NodeDef* node) {
                     return !NeverForwardsInputs(*node);
                   }),
                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
                     transitive_fanout.push_back(node);
                   }));

      bool relax_constraint = true;
      // If all nodes in the transitive fanout are on the same device as the
      // assign node, there is no need to allocate the output in pinned memory.
      for (const NodeDef* fanout_node : transitive_fanout) {
        if (relax_constraint &&
            (IsSend(*fanout_node) ||
             CrossesTaskOrCpuGpuBoundary(*fanout_node, assign_node))) {
          relax_constraint = false;
          break;
        }
        if (optimized_nodes.find(fanout_node) == optimized_nodes.end() &&
            IsAssign(*fanout_node)) {
          assign_nodes_in_fanout.push_back(fanout_node);
        }
      }

      if (relax_constraint) {
        for (const NodeDef* assign_node_in_fanout : assign_nodes_in_fanout) {
          // If all devices match in fanout of node(i) then, by transitivity,
          // they must also match in the fanout of other assign nodes
          // in the fanout of node(i), so we can process them here,
          // and save computing their transitive fanout later.
          optimized_nodes.insert(assign_node_in_fanout);

          // Set an attribute telling AssignOp to ignore allocator constraints.
          const absl::optional<int> assign_node_idx =
              graph_view.GetNodeIndex(*assign_node_in_fanout);
          nodes_to_relax->insert(assign_node_idx.value());
        }
      }
    }
  }
  return OkStatus();
}

}  // namespace

Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph) {
  std::set<int> nodes_to_relax;
  TF_RETURN_IF_ERROR(FindAssignNodesToRelax(item.graph, &nodes_to_relax));

  bool run_recomputation_pass =
      (optimization_level_ == RewriterConfig::RECOMPUTATION_HEURISTICS ||
       optimization_level_ == RewriterConfig::HEURISTICS ||
       optimization_level_ == RewriterConfig::MANUAL);
  if (!run_recomputation_pass && nodes_to_relax.empty() && item.fetch.empty()) {
    return errors::Aborted("Nothing to do.");
  }

  GrapplerItem optimized_item(item);
  RelaxAssignNodes(nodes_to_relax, &optimized_item.graph);

  if (run_recomputation_pass) {
    RecomputationRewritingPass(optimization_level_,
                               recomputation_targets_name_scope_,
                               &optimized_item.graph, item);
  }

  std::unordered_set<string> skip_list;
  // Bound the number of rewrite passes to avoid long processing times on graphs
  // that simply won't fit in memory.
  // SchedulingPass() and SwappingPass() rely on defined fetches in order to
  // infer the memory usage, so skip optimization if there are no fetches.
  std::unique_ptr<GraphMemory> memory;
  if (!item.fetch.empty() && cluster != nullptr) {
    bool updated_graph = true;
    for (int i = 0; i < 25 && updated_graph; ++i) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      updated_graph = false;
      if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
           optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
           optimization_level_ == RewriterConfig::HEURISTICS) &&
          cluster != nullptr) {
        if (SchedulingPass(cluster, &memory, &optimized_item)) {
          // Reset the inferred memory usage since the graph changed.
          memory.reset();
          updated_graph = true;
        }
      }

      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
           optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
           optimization_level_ == RewriterConfig::HEURISTICS ||
           optimization_level_ == RewriterConfig::MANUAL) &&
          cluster != nullptr) {
        if (SwappingPass(optimization_level_, cluster, &memory, &optimized_item,
                         &skip_list)) {
          // Reset the inferred memory usage since the graph changed.
          memory.reset();
          updated_graph = true;
        }
      }
    }
  }

  optimized_graph->Swap(&optimized_item.graph);
  return OkStatus();
}

}  // end namespace grappler
}  // end namespace tensorflow