1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/metrics.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/grappler/clusters/cluster.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/mutable_graph_view.h"
33 #include "tensorflow/core/grappler/op_types.h"
34 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
35 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
36 #include "tensorflow/core/grappler/utils/functions.h"
37 #include "tensorflow/core/kernels/data/shard_dataset_op.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/platform/errors.h"
40
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44
using tensorflow::data::AutoShardPolicy;

// Op names of the dataset (and plain) nodes that this rewrite matches,
// inserts, or rewrites in place.
constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
constexpr char kBatchDatasetOpName[] = "BatchDataset";
constexpr char kBatchDatasetV2OpName[] = "BatchDatasetV2";
constexpr char kMapAndBatchDatasetOpName[] = "MapAndBatchDataset";
constexpr char kMapDatasetOpName[] = "MapDataset";
constexpr char kShardDatasetOpName[] = "ShardDataset";
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
constexpr char kParallelBatchDatasetOpName[] = "ParallelBatchDataset";
constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
constexpr char kFinalizeDatasetOpName[] = "FinalizeDataset";
constexpr char kOptionsDatasetOpName[] = "OptionsDataset";
constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
constexpr char kTensorDatasetOpName[] = "TensorDataset";
constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
constexpr char kPlaceholderOpName[] = "Placeholder";
constexpr char kConstOpName[] = "Const";

// Attribute names read from / written to nodes by this rewrite.
constexpr char kNumWorkersAttrName[] = "num_workers";
constexpr char kNumReplicasAttrName[] = "num_replicas";
constexpr char kIndexAttrName[] = "index";
constexpr char kAutoShardPolicyAttrName[] = "auto_shard_policy";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";
74
// clang-format off

// Source datasets that read records from files; file-based sharding inserts
// a ShardDataset in front of these nodes' inputs.
constexpr std::array<const char*, 5> kReaderDatasetOps = {
    "FixedLengthRecordDataset",
    "RecordIODataset",
    "SSTableDataset",
    "TextLineDataset",
    "TFRecordDataset"
};

// Datasets with more than one input pipeline; each input is walked and
// sharded independently.
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
    "ConcatenateDataset",
    "ZipDataset"
};

// Ops that the upstream walk may pass through (input 0 is the upstream
// dataset) while searching for a shardable source.
constexpr std::array<const char*, 31> kPassThroughOps = {
    "_Retval",
    "AssertNextDataset",
    "BatchDataset",
    "CacheDataset",
    "ExperimentalMapAndBatchDataset",
    "ExperimentalParseExampleDataset",
    "ExperimentalRebatchDataset",
    "FilterDataset",
    "FinalizeDataset",
    "Identity",
    "MapAndBatchDataset",
    "MapDataset",
    "MaxIntraOpParallelismDataset",
    "ModelDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "PaddedBatchDataset",
    "ParallelBatchDataset",
    "ParallelMapDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "PrivateThreadPoolDataset",
    "ReduceDataset",
    "RebatchDataset",
    "RepeatDataset",
    "ShardDataset",
    "ShuffleAndRepeatDataset",
    "ShuffleDataset",
    "SkipDataset",
    "TakeDataset",
    "WindowDataset",
};

// Datasets that apply a user-defined function (which may itself contain a
// reader dataset).
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 5> kFuncDatasetOps = {
    "ExperimentalParallelInterleaveDataset",
    "FlatMapDataset",
    "InterleaveDataset",
    "LegacyParallelInterleaveDataset",
    "ParallelInterleaveDataset",
};

// Source datasets that cannot be sharded by file; reaching one of these
// makes FILE-based sharding fail with NotFound.
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
    "GeneratorDataset",
    "RangeDataset",
    "SparseTensorsSliceDataset",
    "TensorDataset",
    "TensorSliceDataset",
};

// The semantics of these ops are not affected by the change of the batch
// size. There are three categories:
//   1. The op doesn't change the elements of the dataset, e.g. CacheDataset
//   and all ops that set options.
//   2. The op is a dataset-element-wise transformation which is orthogonal to
//   the batch size, e.g. ParseExampleDataset.
//   3. RebatchDataset. This is a special case. RebatchDataset is added by
//   tf.distribute at the end of the input pipeline and will be specially
//   handled.
constexpr std::array<const char*, 20> kBatchSizeOrthogonalDatasetOps = {
    "AssertCardinalityDataset",
    "AssertNextDataset",
    "BytesProducedStatsDataset",
    "CacheDataset",
    "FinalizeDataset",
    "Identity",
    "LatencyStatsDataset",
    "MaxIntraOpParallelismDataset",
    "ModelDataset",
    "NonSerializableDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "PrivateThreadPoolDataset",
    "RebatchDataset",
    "RepeatDataset",
    "SetStatsAggregatorDataset",
    "SleepDataset",
    "ThreadPoolDataset",
};

// Ops that materialize a batch dimension.
constexpr std::array<const char*, 3> kBatchDatasetOps = {
    kBatchDatasetOpName,
    kMapAndBatchDatasetOpName,
    kParallelBatchDatasetOpName,
};

// clang-format on
179
// Forward declaration. Presumably the file-level rewrite entry point defined
// later in this file (definition not visible in this chunk).
Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
                     int64_t index, AutoShardPolicy policy,
                     int64_t num_replicas, GraphDef* output,
                     AutoShardPolicy* policy_applied);
184
185 template <std::size_t SIZE>
IsDatasetNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & arr)186 bool IsDatasetNodeOfType(const NodeDef& node,
187 const std::array<const char*, SIZE>& arr) {
188 for (const auto& dataset_op_name : arr) {
189 if (tensorflow::data::MatchesAnyVersion(/*op_prefix=*/dataset_op_name,
190 /*op_to_match=*/node.op())) {
191 return true;
192 }
193 }
194 return false;
195 }
196
// Adds a ShardDataset node before `add_before`: the shard node consumes
// `add_before`'s current input (input 0), and all fanouts of that input are
// rewired to read from the new shard node. Fails with NotFound when the
// input does not look like a dataset, since output types/shapes cannot be
// inferred then.
Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
                    int64_t num_workers, int64_t index) {
  NodeDef new_node;
  new_node.set_op(kShardDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
                                      &new_node);

  // Construct argument nodes (num_shards and index scalar constants).
  NodeDef* num_shards_node =
      graph_utils::AddScalarConstNode<int64_t>(num_workers, graph);
  NodeDef* index_node = graph_utils::AddScalarConstNode<int64_t>(index, graph);

  // Add inputs to new node: (input_dataset, num_shards, index).
  new_node.add_input(add_before.input(0));
  new_node.add_input(num_shards_node->name());
  new_node.add_input(index_node->name());

  // Ensure that each shard will have at least one element.
  (*(new_node.mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty].set_b(
      true);

  // Add shapes and other attributes
  NodeDef* add_after = graph->GetNode(add_before.input(0));

  if (absl::StrContains(add_after->op(), "Dataset")) {
    // We still may or may not have the right attributes because Datasets like
    // TFRecordDataset doesn't have a output type or shape, and by default we
    // set them to DT_STRING and an unknown shape.
    if (add_after->attr().count(kOutputShapes) > 0) {
      graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
    } else {
      // No output_shapes on the input: record a single unknown-rank shape.
      tensorflow::TensorShapeProto* shape =
          (*(new_node.mutable_attr()))[kOutputShapes]
              .mutable_list()
              ->add_shape();
      shape->set_unknown_rank(true);
    }

    if (add_after->attr().count(kOutputTypes) > 0) {
      graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
    } else if (add_after->attr().count("Toutput_types") > 0) {
      // Some dataset ops carry their element types under "Toutput_types"
      // instead of "output_types"; copy those when present.
      (*(new_node.mutable_attr()))[kOutputTypes] =
          add_after->attr().at("Toutput_types");
    } else {
      // Fall back to DT_STRING (matches file-based reader datasets).
      (*(new_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type(
          tensorflow::DataType::DT_STRING);
    }
  } else {
    // TODO(frankchn): Make this work for datasets where input(0) is a Const,
    // and we need to shard the Const.
    // This is probably not a dataset, so we bail because we can't infer the
    // output types and shape.
    return errors::NotFound(
        "Unable to shard this input. You may need to wrap the inputs to your "
        "reader dataset in a TensorSliceDataset. Input node is ",
        add_after->DebugString());
  }

  // Add new node into graph and update edges
  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));

  return OkStatus();
}
263
// Inserts a ShuffleDataset node between `add_before` and its first input,
// wiring it to the given buffer_size/seed/seed2 nodes and setting
// `reshuffle_each_iteration`. Fanouts of the old input are redirected to the
// new shuffle node. Used to re-insert a shuffle that was removed upstream of
// a newly added shard node.
Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before,
                         const string& buffer_size_node,
                         const string& seed_node, const string& seed2_node,
                         bool reshuffle_each_iteration) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  NodeDef new_node;
  new_node.set_op(kShuffleDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
                                      &new_node);

  // Inputs: (input_dataset, buffer_size, seed, seed2).
  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_node);
  new_node.add_input(seed2_node);

  // Shapes/types are taken from the node the shuffle is inserted after.
  graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

  AttrValue reshuffle_attr;
  reshuffle_attr.set_b(reshuffle_each_iteration);
  (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return OkStatus();
}
292
// Same as AddShuffleDataset but for ShuffleDatasetV2, which takes a seed
// generator resource instead of explicit seed/seed2 inputs.
Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before,
                           const string& buffer_size_node,
                           const string& seed_generator_node) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  NodeDef new_node;
  new_node.set_op(kShuffleDatasetV2OpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV2OpName, graph->graph(),
                                      &new_node);

  // Inputs: (input_dataset, buffer_size, seed_generator).
  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_generator_node);

  // Shapes/types are taken from the node the shuffle is inserted after.
  graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return OkStatus();
}
315
// Same as AddShuffleDataset but for ShuffleDatasetV3, which takes both the
// explicit seed/seed2 inputs and a seed generator resource.
Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before,
                           const string& buffer_size_node,
                           const string& seed_node, const string& seed2_node,
                           const string& seed_generator_node,
                           bool reshuffle_each_iteration) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  NodeDef new_node;
  new_node.set_op(kShuffleDatasetV3OpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV3OpName, graph->graph(),
                                      &new_node);

  // Inputs: (input_dataset, buffer_size, seed, seed2, seed_generator).
  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_node);
  new_node.add_input(seed2_node);
  new_node.add_input(seed_generator_node);

  // Shapes/types are taken from the node the shuffle is inserted after.
  graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

  AttrValue reshuffle_attr;
  reshuffle_attr.set_b(reshuffle_each_iteration);
  (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return OkStatus();
}
346
ReaderOpInFunction(const NodeDef & node,const FunctionLibraryDefinition & flib)347 bool ReaderOpInFunction(const NodeDef& node,
348 const FunctionLibraryDefinition& flib) {
349 const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
350 for (int i = 0; i < func->node_def_size(); i++) {
351 NodeDef node_in_func = func->node_def(i);
352 if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
353 node_in_func.input_size() > 0 &&
354 absl::StartsWith(node_in_func.input(0), "args_0")) {
355 return true;
356 }
357 if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
358 ReaderOpInFunction(func->node_def(i), flib)) {
359 return true;
360 }
361 }
362 return false;
363 }
364
// Walks upstream (through fanins) from `node` and bypasses any
// ShuffleDataset nodes found, recording their names in `nodes_to_delete` for
// later removal. The removed shuffle's op name and input node names are
// written to the output parameters so an equivalent shuffle can be
// re-inserted after the shard node. The out-params are only written when a
// ShuffleDataset is actually found.
Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
                            absl::flat_hash_set<string>* nodes_to_delete,
                            string* op_name, string* buffer_size_node,
                            string* seed_node, string* seed2_node,
                            bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetOpName) {
    // ShuffleDataset inputs: (input_dataset, buffer_size, seed, seed2).
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    // Rewire consumers directly to the shuffle's input; actual deletion is
    // deferred to the caller.
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDataset(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
  return OkStatus();
}
389
// Same as RemoveShuffleDataset, but for ShuffleDatasetV2 (which has a seed
// generator resource input instead of seed/seed2).
Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_generator_node) {
  if (node.op() == kShuffleDatasetV2OpName) {
    // ShuffleDatasetV2 inputs: (input_dataset, buffer_size, seed_generator).
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_generator_node = node.input(2);
    // Rewire consumers directly to the shuffle's input; actual deletion is
    // deferred to the caller.
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(
        RemoveShuffleDatasetV2(graph, *fanin.node, nodes_to_delete, op_name,
                               buffer_size_node, seed_generator_node));
  }

  // TODO(frankchn): Traverse functions too.
  return OkStatus();
}
411
// Same as RemoveShuffleDataset, but for ShuffleDatasetV3 (which carries both
// seed/seed2 inputs and a seed generator resource input).
Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_node, string* seed2_node,
                              string* seed_generator_node,
                              bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetV3OpName) {
    // ShuffleDatasetV3 inputs:
    //   (input_dataset, buffer_size, seed, seed2, seed_generator).
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *seed_generator_node = node.input(4);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    // Rewire consumers directly to the shuffle's input; actual deletion is
    // deferred to the caller.
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, seed_generator_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
  return OkStatus();
}
438
ProcessDatasetSourceNode(MutableGraphView * graph,const NodeDef & node,absl::flat_hash_set<string> * nodes_to_delete,int64_t num_workers,int64_t index)439 Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
440 absl::flat_hash_set<string>* nodes_to_delete,
441 int64_t num_workers, int64_t index) {
442 string shuffle_op_name = "";
443 string buffer_size_node = "";
444 string seed_node = "";
445 string seed2_node = "";
446 string seed_generator_node = "";
447 bool reshuffle_each_iteration;
448
449 TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
450 TF_RETURN_IF_ERROR(RemoveShuffleDataset(
451 graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
452 &seed_node, &seed2_node, &reshuffle_each_iteration));
453 if (shuffle_op_name.empty()) {
454 TF_RETURN_IF_ERROR(
455 RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
456 &buffer_size_node, &seed_generator_node));
457 }
458 if (shuffle_op_name.empty()) {
459 TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
460 graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
461 &seed_node, &seed2_node, &seed_generator_node,
462 &reshuffle_each_iteration));
463 }
464
465 if (shuffle_op_name == kShuffleDatasetOpName) {
466 TF_RETURN_IF_ERROR(AddShuffleDataset(graph, node, buffer_size_node,
467 seed_node, seed2_node,
468 reshuffle_each_iteration));
469 } else if (shuffle_op_name == kShuffleDatasetV2OpName) {
470 TF_RETURN_IF_ERROR(AddShuffleDatasetV2(graph, node, buffer_size_node,
471 seed_generator_node));
472 } else if (shuffle_op_name == kShuffleDatasetV3OpName) {
473 TF_RETURN_IF_ERROR(AddShuffleDatasetV3(
474 graph, node, buffer_size_node, seed_node, seed2_node,
475 seed_generator_node, reshuffle_each_iteration));
476 }
477
478 return OkStatus();
479 }
480
// Searches upstream from `node` (skipping pass-through datasets) for a
// function dataset (e.g. FlatMapDataset) whose input chain is
//   Placeholder -> TensorSliceDataset/TensorDataset -> <func dataset>.
// Returns that function dataset node, or nullptr if the pattern is not
// found. `num_workers`, `index`, `flib` and `nodes_to_delete` are threaded
// through the recursion but not otherwise used here.
const NodeDef* FindFuncAndTensorSliceDataset(
    const NodeDef* node, int64_t num_workers, int64_t index,
    FunctionLibraryDefinition* flib, MutableGraphView* graph,
    absl::flat_hash_set<string>* nodes_to_delete) {
  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
    if (input_node->op() == kTensorSliceDatasetOpName ||
        input_node->op() == kTensorDatasetOpName) {
      const NodeDef* next_input_node =
          graph_utils::GetInputNode(*input_node, *graph, 0);
      if (next_input_node->op() == kPlaceholderOpName) {
        return node;
      }
    }
  }

  // Anything that is neither the pattern above nor pass-through ends the
  // search.
  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
    return nullptr;
  }

  // Sometimes there are other nodes between the last InterleaveDataset and the
  // second to last FlatMapDataset, so we need to skip over those.
  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
                                       graph, nodes_to_delete);
}
507
508 enum class DropRemainderValue { kUnknown, kTrue, kFalse };
509
GetDropRemainder(const MutableGraphView & graph,const NodeDef & batch_node)510 DropRemainderValue GetDropRemainder(const MutableGraphView& graph,
511 const NodeDef& batch_node) {
512 const NodeDef* drop_remainder = nullptr;
513 if (batch_node.op() == kBatchDatasetOpName ||
514 batch_node.op() == kBatchDatasetV2OpName) {
515 drop_remainder = graph.GetNode(batch_node.input(2));
516 } else if (batch_node.op() == kParallelBatchDatasetOpName) {
517 drop_remainder = graph.GetNode(batch_node.input(3));
518 } else if (batch_node.op() == kMapAndBatchDatasetOpName) {
519 int drop_remainder_index =
520 3 + batch_node.attr().at("Targuments").list().shape_size();
521 if (drop_remainder_index >= batch_node.input_size()) {
522 LOG(ERROR) << "Fail to find the drop_remainder of op: "
523 << batch_node.DebugString();
524 return DropRemainderValue::kUnknown;
525 }
526 drop_remainder = graph.GetNode(batch_node.input(drop_remainder_index));
527 } else {
528 LOG(ERROR) << "Expect a batch node but get " << batch_node.DebugString();
529 return DropRemainderValue::kUnknown;
530 }
531 if (!IsConstant(*drop_remainder)) {
532 return DropRemainderValue::kUnknown;
533 }
534 bool drop_remainder_value;
535 if (!GetNodeAttr(*drop_remainder, "value", &drop_remainder_value).ok()) {
536 return DropRemainderValue::kUnknown;
537 }
538 return drop_remainder_value ? DropRemainderValue::kTrue
539 : DropRemainderValue::kFalse;
540 }
541
// Recursive driver for FILE-based sharding. Walks upstream from `node`
// through pass-through datasets until it finds a shardable point — a reader
// dataset, a function dataset that reads files inside its function, or the
// FlatMap-over-TensorSliceDataset pattern — and then delegates to
// ProcessDatasetSourceNode. Returns NotFound if an unshardable source or a
// non-dataset node is reached.
Status RecursivelyHandleOp(const NodeDef& node, int64_t num_workers,
                           int64_t index, FunctionLibraryDefinition* flib,
                           MutableGraphView* graph,
                           absl::flat_hash_set<string>* nodes_to_delete) {
  // AssertCardinalityDataset is dropped from the pipeline (sharding changes
  // cardinality), and the walk continues from its input.
  if (node.op() == kAssertCardinalityDatasetOpName) {
    LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
                    "handled by the auto-shard rewrite and will be removed.";
    nodes_to_delete->insert(node.name());
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                               nodes_to_delete);
  }

  if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
    return errors::NotFound("Found an unshardable source dataset: ",
                            node.DebugString());
  }

  // Multi-input datasets (Zip, Concatenate): shard each input pipeline
  // independently.
  if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
    for (int i = 0; i < node.input_size(); ++i) {
      const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
      TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
                                             flib, graph, nodes_to_delete));
    }
    return OkStatus();
  }

  // This handles the case for the following subgraph:
  //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
  //   (other preprocessing datasets) -> InterleaveDataset
  // and then inserting the shard node immediately after the FlatMapDataset.
  //
  // This is used for some training pipelines where a dataset is created with
  // the following code:
  //
  // def make_dataset_pipeline():
  //   file_globs = [...]
  //   datasets = []
  //   for file_glob in file_globs:
  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
  //   dataset = Dataset.from_tensor_slices(datasets)
  //   dataset = dataset.flat_map(lambda x: x)
  //   dataset = ...  # additional preprocessing
  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
  //   return dataset
  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
        input_node, num_workers, index, flib, graph, nodes_to_delete);

    if (flat_map_node != nullptr) {
      auto fanouts = graph->GetFanouts(*flat_map_node, false);
      // FlatMapDataset should only be the input to one other dataset.
      if (fanouts.size() == 1) {
        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
                                        nodes_to_delete, num_workers, index);
      }
    }
  }

  // This handles the case where a reader Dataset is contained within a
  // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
  //
  // dataset = Dataset.list_files("/path/to/data")
  // dataset = dataset.flat_map(core_readers.TFRecordDataset)
  //
  // where the list of files is passed in one-by-one as an argument to the
  // function in flat_map.
  if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
      ReaderOpInFunction(node, *flib)) {
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
    // We reached a reader dataset directly and we try to shard input 0.
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
    return errors::NotFound(
        "Did not find a shardable source, walked to ",
        "a node which is not a dataset: ", node.DebugString(),
        ". Consider either turning off auto-sharding or switching the "
        "auto_shard_policy to DATA to shard this dataset. You can do this by "
        "creating a new `tf.data.Options()` object then setting "
        "`options.experimental_distribute.auto_shard_policy = "
        "AutoShardPolicy.DATA` before applying the options object to the "
        "dataset via `dataset.with_options(options)`.");
  }

  // Pass-through op: continue walking upstream through input 0.
  const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
  return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                             nodes_to_delete);
}
639
640 // Recursively walk the dataset graph from sink to source, searching for
641 // the first (i.e. closest to the sink) occurrence of a ReaderDataset, such as
642 // CSVDataset, TFRecordDataset, etc. We then insert a ShardDataset op before
643 // that nodes input, so that each worker only reads a subset of files.
644 // Additionally, we remove sources of randomness (e.g. ShuffleDataset) that
645 // occur upstream of the ShardDataset transformation to ensure that sharding
646 // returns a sensible result.
ShardByFile(const NodeDef & sink_node,int64_t num_workers,int64_t index,FunctionLibraryDefinition * flib,MutableGraphView * graph)647 Status ShardByFile(const NodeDef& sink_node, int64_t num_workers, int64_t index,
648 FunctionLibraryDefinition* flib, MutableGraphView* graph) {
649 absl::flat_hash_set<string> nodes_to_delete;
650 TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, flib,
651 graph, &nodes_to_delete));
652 return graph->DeleteNodes(nodes_to_delete);
653 }
654
RewriteRebatchV2ToV1(const NodeDef & sink_node,int64_t num_replicas,MutableGraphView * graph)655 Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64_t num_replicas,
656 MutableGraphView* graph) {
657 // The final node before AutoShardDataset is RebatchDataset.
658 // This is always the case as RebatchDataset and AutoShardDataset are internal
659 // APIs used directly by tf.distribute's input_lib. As such, instead of
660 // walking the entire dataset graph, we can walk up directly from the
661 // sink_node to get the RebatchDataset.
662 NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
663 if (input_node->op() != kRebatchDatasetV2OpName) {
664 return OkStatus();
665 }
666
667 NodeDef* rebatch_node = input_node;
668 // Update RebatchDatasetV2 in place. Since Rebatch is an internal API, no
669 // other nodes should have it as an input.
670 rebatch_node->set_op(kRebatchDatasetOpName);
671 // Delete the `batch_sizes` and `drop_remainder` input.
672 rebatch_node->mutable_input()->DeleteSubrange(/*start=*/1, /*num=*/2);
673 // Add the `num_replicas` input.
674 if (num_replicas < 1) {
675 return errors::InvalidArgument(
676 "Cannot rewrite RebatchDatasetV2 to legacy RebatchDataset with invalid "
677 "num_replicas argument. `num_replicas` is ",
678 num_replicas, ", but expected to be >= 1.");
679 }
680 auto num_replicas_node = graph_utils::AddScalarConstNode(num_replicas, graph);
681 rebatch_node->add_input(num_replicas_node->name());
682
683 // Set `use_fallback` attr. This attr is not used anywhere, so its value
684 // does not matter
685 (*rebatch_node->mutable_attr())["use_fallback"].set_b(true);
686
687 // Update the output_shapes attr to set all its batch dimensions to -1
688 // (unknown).
689 auto* shapes_attr =
690 gtl::FindOrNull(*rebatch_node->mutable_attr(), "output_shapes");
691 if (shapes_attr == nullptr) {
692 return errors::InvalidArgument(
693 "Cannot rewrite RebatchDatasetV2 with missing `output_shapes` attr.");
694 }
695 for (int i = 0; i < shapes_attr->list().shape_size(); ++i) {
696 auto* shape = shapes_attr->mutable_list()->mutable_shape(i);
697 if (shape->unknown_rank()) continue;
698 shape->mutable_dim(0)->set_size(-1);
699 }
700
701 return OkStatus();
702 }
703
ShardByData(const NodeDef & sink_node,int64_t num_workers,int64_t index,int64_t num_replicas,MutableGraphView * graph)704 Status ShardByData(const NodeDef& sink_node, int64_t num_workers, int64_t index,
705 int64_t num_replicas, MutableGraphView* graph) {
706 const NodeDef* shard_before = &sink_node;
707 // We sometimes insert a PrefetchDataset, OptionsDataset, and FinalizeDataset
708 // at the end of the input pipeline before autosharding. When sharding by
709 // data, we should insert the shard before the these datasets so that the
710 // right number of elements is prefetched.
711 NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
712 while (input_node->op() == kPrefetchDatasetOpName ||
713 input_node->op() == kOptionsDatasetOpName ||
714 input_node->op() == kFinalizeDatasetOpName) {
715 shard_before = input_node;
716 input_node = graph_utils::GetInputNode(*input_node, *graph);
717 }
718 // Sharding by data only works with legacy RebatchDataset. As such, we rewrite
719 // all instances of RebatchDatasetV2 to RebatchDataset.
720 TF_RETURN_IF_ERROR(RewriteRebatchV2ToV1(*shard_before, num_replicas, graph));
721 return AddShardNode(graph, *shard_before, num_workers, index);
722 }
723
// Searches the dataset graph replacing any occurrence of `shard(1, 0)` with
// `shard(num_workers, index)`. (`sink_node` and `num_replicas` are unused
// here; presumably kept for signature parity with the other Shard*
// functions.)
Status ShardByHint(const NodeDef& sink_node, int64_t num_workers, int64_t index,
                   int64_t num_replicas, MutableGraphView* graph) {
  // Returns `node` if it is a ShardDataset whose `num_shards` input is a
  // constant holding the kShardHint sentinel (a user-placed SHARD_HINT);
  // otherwise nullptr.
  auto get_shard_node = [graph](const NodeDef& node) -> const NodeDef* {
    if (node.op() != kShardDatasetOpName) return nullptr;
    auto num_workers_node = graph->GetNode(node.input(1));
    if (num_workers_node->op() != kConstOpName) return nullptr;
    if (num_workers_node->attr().at("value").tensor().int64_val(0) !=
        tensorflow::data::kShardHint)
      return nullptr;
    return &node;
  };

  // Shared constants for the real num_workers/index values.
  auto* num_workers_node =
      graph_utils::AddScalarConstNode(static_cast<int64_t>(num_workers), graph);
  auto* worker_index_node =
      graph_utils::AddScalarConstNode(static_cast<int64_t>(index), graph);

  for (const NodeDef& node : graph->graph()->node()) {
    const NodeDef* shard_node = get_shard_node(node);
    if (!shard_node) continue;
    // Rewire inputs 1 and 2 (num_shards, index) to the real values.
    auto mutable_node = graph->GetNode(shard_node->name());
    *mutable_node->mutable_input(1) = num_workers_node->name();
    *mutable_node->mutable_input(2) = worker_index_node->name();
    // Ensure that each shard will have at least one element.
    (*(mutable_node->mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty]
        .set_b(true);
  }
  return OkStatus();
}
755
ApplyAutoShard(const NodeDef & sink_node,int64_t num_workers,int64_t index,AutoShardPolicy policy,int64_t num_replicas,MutableGraphView * graph,AutoShardPolicy * policy_applied)756 Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers,
757 int64_t index, AutoShardPolicy policy,
758 int64_t num_replicas, MutableGraphView* graph,
759 AutoShardPolicy* policy_applied) {
760 *policy_applied = policy;
761 FunctionLibraryDefinition flib(OpRegistry::Global(),
762 graph->graph()->library());
763 switch (policy) {
764 case AutoShardPolicy::OFF:
765 return OkStatus();
766 case AutoShardPolicy::FILE:
767 return ShardByFile(sink_node, num_workers, index, &flib, graph);
768 case AutoShardPolicy::DATA:
769 return ShardByData(sink_node, num_workers, index, num_replicas, graph);
770 case AutoShardPolicy::HINT:
771 return ShardByHint(sink_node, num_workers, index, num_replicas, graph);
772 case AutoShardPolicy::AUTO:
773 default:
774 Status s = ShardByFile(sink_node, num_workers, index, &flib, graph);
775 if (errors::IsNotFound(s)) {
776 LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
777 "as it failed to apply FILE sharding policy because of "
778 "the following reason: "
779 << s.error_message();
780 *policy_applied = AutoShardPolicy::DATA;
781 return ShardByData(sink_node, num_workers, index, num_replicas, graph);
782 }
783 *policy_applied = AutoShardPolicy::FILE;
784 return s;
785 }
786 }
787
OptimizeGraph(const GrapplerItem & item,int64_t num_workers,int64_t index,AutoShardPolicy policy,int64_t num_replicas,GraphDef * output)788 Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
789 int64_t index, AutoShardPolicy policy,
790 int64_t num_replicas, GraphDef* output) {
791 *output = item.graph;
792 MutableGraphView graph(output);
793 NodeDef* sink_node;
794 TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
795
796 // id for telemetry purpose. item.id is always the same so we use the address
797 // of the output as id.
798 string id = strings::StrCat(reinterpret_cast<uint64>(output));
799 // Only record metrics on the first shard to avoid duplication.
800 if (index == 0) {
801 std::vector<std::string> ineligible_reason;
802 bool is_eligible = internal::IsEligibleRewriteBatchSize(*sink_node, graph,
803 &ineligible_reason);
804 metrics::RecordTFDataAutoShardRewriteBatchSize(is_eligible,
805 ineligible_reason);
806 }
807
808 AutoShardPolicy policy_applied = policy;
809 if (policy != AutoShardPolicy::OFF &&
810 !(policy == AutoShardPolicy::FILE && num_workers == 1 && index == 0)) {
811 TF_RETURN_IF_ERROR(ApplyAutoShard(*sink_node, num_workers, index, policy,
812 num_replicas, &graph, &policy_applied));
813 }
814 // Only record metrics on the first shard to avoid duplication.
815 if (index == 0) {
816 metrics::RecordTFDataAutoShard(id, policy_applied, num_workers,
817 num_replicas);
818 }
819 return OkStatus();
820 }
821
822 } // anonymous namespace
823
824 namespace internal {
// Walks the pipeline upstream from `sink_node` looking for a batch dataset
// whose batch size could be rewritten. Returns true iff such a batch node is
// found and no ineligible op was seen on the way; on a false return,
// `ineligible_reason` lists every reason collected during the walk (or just
// "BATCH_NOT_FOUND" if no batch node was reached).
bool IsEligibleRewriteBatchSize(const NodeDef& sink_node,
                                const MutableGraphView& graph,
                                std::vector<std::string>* ineligible_reason) {
  ineligible_reason->clear();
  NodeDef* input_node = graph_utils::GetInputNode(sink_node, graph);
  // We always traverse the graph until we arrive at a batch node so that all
  // ineligible reasons are collected, not just the first.
  while (input_node != nullptr) {
    // 1. Skip RebatchDataset and the MapDataset immediately before it. That map
    // is added by tf.data Python code.
    if (input_node->op() == kRebatchDatasetOpName ||
        input_node->op() == kRebatchDatasetV2OpName) {
      input_node = graph_utils::GetInputNode(*input_node, graph);
      // A rebatch without its companion map indicates an unexpected graph
      // shape, hence the BUG_ prefix.
      if (input_node == nullptr || input_node->op() != kMapDatasetOpName) {
        ineligible_reason->push_back("BUG_NO_MAP_BEFORE_REBATCH");
        return false;
      }
      input_node = graph_utils::GetInputNode(*input_node, graph);
      continue;
    }
    // 2. If the node is insensitive to the batch size of the input, we continue
    // looking at the input dataset of the node.
    if (IsDatasetNodeOfType(*input_node, kBatchSizeOrthogonalDatasetOps)) {
      input_node = graph_utils::GetInputNode(*input_node, graph);
      continue;
    }
    // 3. We arrive at a batch node. Examine its drop_remainder input and
    // cardinality to determine eligibility.
    if (IsDatasetNodeOfType(*input_node, kBatchDatasetOps)) {
      DropRemainderValue drop_remainder = GetDropRemainder(graph, *input_node);
      int64_t cardinality = data::kUnknownCardinality;
      bool cardinality_available = true;
      AttrSlice attrs(*input_node);
      if (!TryGetNodeAttr(attrs, data::kCardinalityAttrForRewrite,
                          &cardinality)) {
        cardinality_available = false;
      }

      // Eligible when the batch keeps partial batches (drop_remainder false)
      // or the pipeline is known to be infinite (a remainder never occurs).
      if (drop_remainder == DropRemainderValue::kFalse ||
          (cardinality_available &&
           cardinality == data::kInfiniteCardinality)) {
        // Still ineligible if earlier steps recorded any reason.
        return ineligible_reason->empty();
      } else {
        if (drop_remainder == DropRemainderValue::kUnknown) {
          ineligible_reason->push_back("BATCH_DROP_REMAINDER_UNKNOWN");
        }
        if (!cardinality_available) {
          ineligible_reason->push_back("BATCH_CARDINALITY_NOT_AVAILABLE");
        }
        if (drop_remainder == DropRemainderValue::kTrue &&
            cardinality_available &&
            cardinality != data::kInfiniteCardinality) {
          ineligible_reason->push_back("BATCH_DROP_REMAINDER_NOT_INFINITE");
        }
        return false;
      }
    }
    // 4. We encountered other nodes before arriving at a batch node. We don't
    // know whether this node is sensitive to the batch size or not and we err
    // on the safe side.
    ineligible_reason->push_back(
        strings::StrCat("OP_NOT_SUPPORTED_", input_node->op()));
    input_node = graph_utils::GetInputNode(*input_node, graph);
  }
  // If we don't find a batch node, only record BATCH_NOT_FOUND as the reason,
  // discarding any per-op reasons collected along the way.
  ineligible_reason->clear();
  ineligible_reason->push_back("BATCH_NOT_FOUND");
  return false;
}
894 } // namespace internal
895
Init(const tensorflow::RewriterConfig_CustomGraphOptimizer * config)896 Status AutoShard::Init(
897 const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
898 if (!config) return errors::InvalidArgument("RewriterConfig not found.");
899
900 if ((config->parameter_map().find(kNumWorkersAttrName) ==
901 config->parameter_map().end())) {
902 return errors::InvalidArgument(kNumWorkersAttrName, " parameter missing.");
903 }
904
905 if ((config->parameter_map().find(kIndexAttrName) ==
906 config->parameter_map().end())) {
907 return errors::InvalidArgument(kIndexAttrName, " parameter missing.");
908 }
909
910 num_workers_ = config->parameter_map().at(kNumWorkersAttrName).i();
911 index_ = config->parameter_map().at(kIndexAttrName).i();
912 auto_shard_policy_ =
913 AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
914 num_replicas_ = config->parameter_map().at(kNumReplicasAttrName).i();
915
916 if (auto_shard_policy_ != AutoShardPolicy::OFF &&
917 auto_shard_policy_ != AutoShardPolicy::AUTO &&
918 auto_shard_policy_ != AutoShardPolicy::DATA &&
919 auto_shard_policy_ != AutoShardPolicy::FILE &&
920 auto_shard_policy_ != AutoShardPolicy::HINT) {
921 return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
922 }
923
924 if (num_workers_ < 1) {
925 return errors::InvalidArgument(kNumWorkersAttrName,
926 " should be >= 1, currently ", num_workers_);
927 }
928
929 if (index_ < 0 || index_ >= num_workers_) {
930 return errors::InvalidArgument(kIndexAttrName, " should be >= 0 and < ",
931 num_workers_, ", currently ", index_);
932 }
933
934 if (num_replicas_ < 0) {
935 return errors::InvalidArgument(kNumReplicasAttrName, " should be >= 0");
936 }
937
938 return OkStatus();
939 }
940
OptimizeAndCollectStats(Cluster * cluster,const GrapplerItem & item,GraphDef * output,OptimizationStats * stats)941 Status AutoShard::OptimizeAndCollectStats(Cluster* cluster,
942 const GrapplerItem& item,
943 GraphDef* output,
944 OptimizationStats* stats) {
945 *output = item.graph;
946 TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_,
947 auto_shard_policy_, num_replicas_, output));
948 stats->num_changes++;
949 return OkStatus();
950 }
951
// Registers AutoShard with grappler under the name "tf_auto_shard" so it can
// be selected as a custom graph optimizer in RewriterConfig.
REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");
953
954 } // namespace grappler
955 } // namespace tensorflow
956