xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/auto_shard.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/metrics.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/grappler/clusters/cluster.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/mutable_graph_view.h"
33 #include "tensorflow/core/grappler/op_types.h"
34 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
35 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
36 #include "tensorflow/core/grappler/utils/functions.h"
37 #include "tensorflow/core/kernels/data/shard_dataset_op.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/platform/errors.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 namespace {
44 
using tensorflow::data::AutoShardPolicy;

// Op names of individual dataset ops that the rewrite needs to match exactly.
constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
constexpr char kBatchDatasetOpName[] = "BatchDataset";
constexpr char kBatchDatasetV2OpName[] = "BatchDatasetV2";
constexpr char kMapAndBatchDatasetOpName[] = "MapAndBatchDataset";
constexpr char kMapDatasetOpName[] = "MapDataset";
constexpr char kShardDatasetOpName[] = "ShardDataset";
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
constexpr char kParallelBatchDatasetOpName[] = "ParallelBatchDataset";
constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
constexpr char kFinalizeDatasetOpName[] = "FinalizeDataset";
constexpr char kOptionsDatasetOpName[] = "OptionsDataset";
constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
constexpr char kTensorDatasetOpName[] = "TensorDataset";
constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
constexpr char kPlaceholderOpName[] = "Placeholder";
constexpr char kConstOpName[] = "Const";

// NodeDef attribute names read or written by this rewrite.
constexpr char kNumWorkersAttrName[] = "num_workers";
constexpr char kNumReplicasAttrName[] = "num_replicas";
constexpr char kIndexAttrName[] = "index";
constexpr char kAutoShardPolicyAttrName[] = "auto_shard_policy";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";

// clang-format off
// Source datasets that read data from files. File-based sharding inserts a
// ShardDataset above one of these so each worker reads a subset of files.
constexpr std::array<const char*, 5> kReaderDatasetOps = {
    "FixedLengthRecordDataset",
    "RecordIODataset",
    "SSTableDataset",
    "TextLineDataset",
    "TFRecordDataset"
};

// Datasets that combine several input datasets; each input is walked
// separately when searching for a shardable source.
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
    "ConcatenateDataset",
    "ZipDataset"
};

// Ops the graph walk may pass straight through (input 0 is the upstream
// dataset).
constexpr std::array<const char*, 31> kPassThroughOps = {
    "_Retval",
    "AssertNextDataset",
    "BatchDataset",
    "CacheDataset",
    "ExperimentalMapAndBatchDataset",
    "ExperimentalParseExampleDataset",
    "ExperimentalRebatchDataset",
    "FilterDataset",
    "FinalizeDataset",
    "Identity",
    "MapAndBatchDataset",
    "MapDataset",
    "MaxIntraOpParallelismDataset",
    "ModelDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "PaddedBatchDataset",
    "ParallelBatchDataset",
    "ParallelMapDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "PrivateThreadPoolDataset",
    "ReduceDataset",
    "RebatchDataset",
    "RepeatDataset",
    "ShardDataset",
    "ShuffleAndRepeatDataset",
    "ShuffleDataset",
    "SkipDataset",
    "TakeDataset",
    "WindowDataset",
};

// Datasets that apply a user-defined function to their input.
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 5> kFuncDatasetOps = {
    "ExperimentalParallelInterleaveDataset",
    "FlatMapDataset",
    "InterleaveDataset",
    "LegacyParallelInterleaveDataset",
    "ParallelInterleaveDataset",
};

// Source datasets that cannot be sharded by file; reaching one of these during
// the FILE-policy walk is an error.
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
    "GeneratorDataset",
    "RangeDataset",
    "SparseTensorsSliceDataset",
    "TensorDataset",
    "TensorSliceDataset",
};

// The semantics of these ops are not affected by the change of the batch
// size. There are three categories:
//   1. The op doesn't change the elements of the dataset, e.g. CacheDataset
//   and all ops that set options.
//   2. The op is a dataset-element-wise transformation which is orthogonal to
//   the batch size, e.g. ParseExampleDataset.
//   3. RebatchDataset. This is a special case. RebatchDataset is added by
//   tf.distribute at the end of the input pipeline and will be specially
//   handled.
constexpr std::array<const char*, 20> kBatchSizeOrthogonalDatasetOps = {
    "AssertCardinalityDataset",
    "AssertNextDataset",
    "BytesProducedStatsDataset",
    "CacheDataset",
    "FinalizeDataset",
    "Identity",
    "LatencyStatsDataset",
    "MaxIntraOpParallelismDataset",
    "ModelDataset",
    "NonSerializableDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "PrivateThreadPoolDataset",
    "RebatchDataset",
    "RepeatDataset",
    "SetStatsAggregatorDataset",
    "SleepDataset",
    "ThreadPoolDataset",
};

// All ops that produce batched output.
constexpr std::array<const char*, 3> kBatchDatasetOps = {
    kBatchDatasetOpName,
    kMapAndBatchDatasetOpName,
    kParallelBatchDatasetOpName,
};

// clang-format on
179 
// Forward declaration (the definition appears later in this translation
// unit): applies the auto-shard rewrite to `item`'s graph under `policy`,
// writing the rewritten graph to `output` and the policy that was actually
// applied to `policy_applied`.
Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
                     int64_t index, AutoShardPolicy policy,
                     int64_t num_replicas, GraphDef* output,
                     AutoShardPolicy* policy_applied);
184 
185 template <std::size_t SIZE>
IsDatasetNodeOfType(const NodeDef & node,const std::array<const char *,SIZE> & arr)186 bool IsDatasetNodeOfType(const NodeDef& node,
187                          const std::array<const char*, SIZE>& arr) {
188   for (const auto& dataset_op_name : arr) {
189     if (tensorflow::data::MatchesAnyVersion(/*op_prefix=*/dataset_op_name,
190                                             /*op_to_match=*/node.op())) {
191       return true;
192     }
193   }
194   return false;
195 }
196 
// Adds a ShardDataset node before `add_before`.
//
// The new node consumes `add_before`'s input(0) plus two freshly created
// scalar Const nodes (`num_workers` shards and shard `index`), then takes
// over all fanouts of the original input so every consumer reads the sharded
// stream. Returns NotFound if the input does not look like a dataset, since
// output types/shapes cannot be inferred then.
Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
                    int64_t num_workers, int64_t index) {
  NodeDef new_node;
  new_node.set_op(kShardDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
                                      &new_node);

  // Construct argument nodes
  NodeDef* num_shards_node =
      graph_utils::AddScalarConstNode<int64_t>(num_workers, graph);
  NodeDef* index_node = graph_utils::AddScalarConstNode<int64_t>(index, graph);

  // Add inputs to new node
  new_node.add_input(add_before.input(0));
  new_node.add_input(num_shards_node->name());
  new_node.add_input(index_node->name());

  // Ensure that each shard will have at least one element.
  (*(new_node.mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty].set_b(
      true);

  // Add shapes and other attributes
  NodeDef* add_after = graph->GetNode(add_before.input(0));

  if (absl::StrContains(add_after->op(), "Dataset")) {
    // We still may or may not have the right attributes because Datasets like
    // TFRecordDataset doesn't have a output type or shape, and by default we
    // set them to DT_STRING and an unknown shape.
    if (add_after->attr().count(kOutputShapes) > 0) {
      graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
    } else {
      // No shape information upstream: record a single unknown-rank shape.
      tensorflow::TensorShapeProto* shape =
          (*(new_node.mutable_attr()))[kOutputShapes]
              .mutable_list()
              ->add_shape();
      shape->set_unknown_rank(true);
    }

    if (add_after->attr().count(kOutputTypes) > 0) {
      graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
    } else if (add_after->attr().count("Toutput_types") > 0) {
      // Some ops carry their element types under "Toutput_types" instead.
      (*(new_node.mutable_attr()))[kOutputTypes] =
          add_after->attr().at("Toutput_types");
    } else {
      // Fall back to DT_STRING, matching the reader-dataset default.
      (*(new_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type(
          tensorflow::DataType::DT_STRING);
    }
  } else {
    // TODO(frankchn): Make this work for datasets where input(0) is a Const,
    // and we need to shard the Const.
    // This is probably not a dataset, so we bail because we can't infer the
    // output types and shape.
    return errors::NotFound(
        "Unable to shard this input. You may need to wrap the inputs to your "
        "reader dataset in a TensorSliceDataset. Input node is ",
        add_after->DebugString());
  }

  // Add new node into graph and update edges
  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));

  return OkStatus();
}
263 
AddShuffleDataset(MutableGraphView * graph,const NodeDef & add_before,const string & buffer_size_node,const string & seed_node,const string & seed2_node,bool reshuffle_each_iteration)264 Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before,
265                          const string& buffer_size_node,
266                          const string& seed_node, const string& seed2_node,
267                          bool reshuffle_each_iteration) {
268   NodeDef* add_after = graph->GetNode(add_before.input(0));
269   NodeDef new_node;
270   new_node.set_op(kShuffleDatasetOpName);
271   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
272                                       &new_node);
273 
274   new_node.add_input(add_before.input(0));
275   new_node.add_input(buffer_size_node);
276   new_node.add_input(seed_node);
277   new_node.add_input(seed2_node);
278 
279   graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
280   graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
281 
282   AttrValue reshuffle_attr;
283   reshuffle_attr.set_b(reshuffle_each_iteration);
284   (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;
285 
286   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
287 
288   TF_RETURN_IF_ERROR(
289       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
290   return OkStatus();
291 }
292 
AddShuffleDatasetV2(MutableGraphView * graph,const NodeDef & add_before,const string & buffer_size_node,const string & seed_generator_node)293 Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before,
294                            const string& buffer_size_node,
295                            const string& seed_generator_node) {
296   NodeDef* add_after = graph->GetNode(add_before.input(0));
297   NodeDef new_node;
298   new_node.set_op(kShuffleDatasetV2OpName);
299   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV2OpName, graph->graph(),
300                                       &new_node);
301 
302   new_node.add_input(add_before.input(0));
303   new_node.add_input(buffer_size_node);
304   new_node.add_input(seed_generator_node);
305 
306   graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
307   graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
308 
309   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
310 
311   TF_RETURN_IF_ERROR(
312       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
313   return OkStatus();
314 }
315 
AddShuffleDatasetV3(MutableGraphView * graph,const NodeDef & add_before,const string & buffer_size_node,const string & seed_node,const string & seed2_node,const string & seed_generator_node,bool reshuffle_each_iteration)316 Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before,
317                            const string& buffer_size_node,
318                            const string& seed_node, const string& seed2_node,
319                            const string& seed_generator_node,
320                            bool reshuffle_each_iteration) {
321   NodeDef* add_after = graph->GetNode(add_before.input(0));
322   NodeDef new_node;
323   new_node.set_op(kShuffleDatasetV3OpName);
324   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV3OpName, graph->graph(),
325                                       &new_node);
326 
327   new_node.add_input(add_before.input(0));
328   new_node.add_input(buffer_size_node);
329   new_node.add_input(seed_node);
330   new_node.add_input(seed2_node);
331   new_node.add_input(seed_generator_node);
332 
333   graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
334   graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
335 
336   AttrValue reshuffle_attr;
337   reshuffle_attr.set_b(reshuffle_each_iteration);
338   (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;
339 
340   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
341 
342   TF_RETURN_IF_ERROR(
343       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
344   return OkStatus();
345 }
346 
ReaderOpInFunction(const NodeDef & node,const FunctionLibraryDefinition & flib)347 bool ReaderOpInFunction(const NodeDef& node,
348                         const FunctionLibraryDefinition& flib) {
349   const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
350   for (int i = 0; i < func->node_def_size(); i++) {
351     NodeDef node_in_func = func->node_def(i);
352     if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
353         node_in_func.input_size() > 0 &&
354         absl::StartsWith(node_in_func.input(0), "args_0")) {
355       return true;
356     }
357     if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
358         ReaderOpInFunction(func->node_def(i), flib)) {
359       return true;
360     }
361   }
362   return false;
363 }
364 
// Searches `node` and everything upstream of it for a ShuffleDataset (V1)
// node. When one is found, its configuration (op name, buffer-size/seed
// inputs, `reshuffle_each_iteration` attr) is captured through the
// out-parameters, the node is spliced out of the graph (fanouts rewired to
// its input dataset), and its name is queued in `nodes_to_delete`. The walk
// recurses through every fanin, so multiple shuffles would each be removed
// (the out-params then hold the last one visited).
Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
                            absl::flat_hash_set<string>* nodes_to_delete,
                            string* op_name, string* buffer_size_node,
                            string* seed_node, string* seed2_node,
                            bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetOpName) {
    // Remember how the shuffle was configured so it can be re-added later
    // (inputs: 0=dataset, 1=buffer_size, 2=seed, 3=seed2).
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDataset(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
  return OkStatus();
}
389 
// V2 counterpart of RemoveShuffleDataset: finds ShuffleDatasetV2 nodes
// upstream of `node` (inclusive), captures their buffer-size and
// seed-generator inputs through the out-parameters, splices them out of the
// graph, and queues them for deletion.
Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_generator_node) {
  if (node.op() == kShuffleDatasetV2OpName) {
    // Inputs: 0=dataset, 1=buffer_size, 2=seed_generator.
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_generator_node = node.input(2);
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(
        RemoveShuffleDatasetV2(graph, *fanin.node, nodes_to_delete, op_name,
                               buffer_size_node, seed_generator_node));
  }

  // TODO(frankchn): Traverse functions too.
  return OkStatus();
}
411 
// V3 counterpart of RemoveShuffleDataset: finds ShuffleDatasetV3 nodes
// upstream of `node` (inclusive), captures their buffer-size, seed, seed2 and
// seed-generator inputs plus the `reshuffle_each_iteration` attr through the
// out-parameters, splices them out of the graph, and queues them for
// deletion.
Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_node, string* seed2_node,
                              string* seed_generator_node,
                              bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetV3OpName) {
    // Inputs: 0=dataset, 1=buffer_size, 2=seed, 3=seed2, 4=seed_generator.
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *seed_generator_node = node.input(4);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, seed_generator_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
  return OkStatus();
}
438 
ProcessDatasetSourceNode(MutableGraphView * graph,const NodeDef & node,absl::flat_hash_set<string> * nodes_to_delete,int64_t num_workers,int64_t index)439 Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
440                                 absl::flat_hash_set<string>* nodes_to_delete,
441                                 int64_t num_workers, int64_t index) {
442   string shuffle_op_name = "";
443   string buffer_size_node = "";
444   string seed_node = "";
445   string seed2_node = "";
446   string seed_generator_node = "";
447   bool reshuffle_each_iteration;
448 
449   TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
450   TF_RETURN_IF_ERROR(RemoveShuffleDataset(
451       graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
452       &seed_node, &seed2_node, &reshuffle_each_iteration));
453   if (shuffle_op_name.empty()) {
454     TF_RETURN_IF_ERROR(
455         RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
456                                &buffer_size_node, &seed_generator_node));
457   }
458   if (shuffle_op_name.empty()) {
459     TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
460         graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
461         &seed_node, &seed2_node, &seed_generator_node,
462         &reshuffle_each_iteration));
463   }
464 
465   if (shuffle_op_name == kShuffleDatasetOpName) {
466     TF_RETURN_IF_ERROR(AddShuffleDataset(graph, node, buffer_size_node,
467                                          seed_node, seed2_node,
468                                          reshuffle_each_iteration));
469   } else if (shuffle_op_name == kShuffleDatasetV2OpName) {
470     TF_RETURN_IF_ERROR(AddShuffleDatasetV2(graph, node, buffer_size_node,
471                                            seed_generator_node));
472   } else if (shuffle_op_name == kShuffleDatasetV3OpName) {
473     TF_RETURN_IF_ERROR(AddShuffleDatasetV3(
474         graph, node, buffer_size_node, seed_node, seed2_node,
475         seed_generator_node, reshuffle_each_iteration));
476   }
477 
478   return OkStatus();
479 }
480 
// Walks upward from `node` through pass-through datasets, looking for a
// function dataset (e.g. FlatMapDataset) whose input is a TensorDataset or
// TensorSliceDataset that is itself fed by a Placeholder. Returns that
// function-dataset node if found, or nullptr once a non-pass-through node is
// reached. (`num_workers`, `index`, `flib` and `nodes_to_delete` are carried
// along unchanged.)
const NodeDef* FindFuncAndTensorSliceDataset(
    const NodeDef* node, int64_t num_workers, int64_t index,
    FunctionLibraryDefinition* flib, MutableGraphView* graph,
    absl::flat_hash_set<string>* nodes_to_delete) {
  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
    const NodeDef* upstream = graph_utils::GetInputNode(*node, *graph, 0);
    const bool feeds_from_tensor_source =
        upstream->op() == kTensorSliceDatasetOpName ||
        upstream->op() == kTensorDatasetOpName;
    if (feeds_from_tensor_source) {
      const NodeDef* tensor_source_input =
          graph_utils::GetInputNode(*upstream, *graph, 0);
      if (tensor_source_input->op() == kPlaceholderOpName) {
        return node;
      }
    }
  }

  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
    return nullptr;
  }

  // Sometimes there are other nodes between the last InterleaveDataset and
  // the second to last FlatMapDataset, so we need to skip over those.
  return FindFuncAndTensorSliceDataset(
      graph_utils::GetInputNode(*node, *graph, 0), num_workers, index, flib,
      graph, nodes_to_delete);
}
507 
// Tri-state result of statically inspecting a batch node's `drop_remainder`
// input; kUnknown when the input is not a constant or cannot be located.
enum class DropRemainderValue { kUnknown, kTrue, kFalse };

// Statically determines the `drop_remainder` argument of `batch_node`.
// Returns kUnknown if `batch_node` is not a recognized batch op, the
// drop_remainder input cannot be located, it is not a Const, or its value
// attr cannot be read as a bool.
DropRemainderValue GetDropRemainder(const MutableGraphView& graph,
                                    const NodeDef& batch_node) {
  const NodeDef* drop_remainder = nullptr;
  if (batch_node.op() == kBatchDatasetOpName ||
      batch_node.op() == kBatchDatasetV2OpName) {
    // BatchDataset(V2) inputs: input_dataset, batch_size, drop_remainder.
    drop_remainder = graph.GetNode(batch_node.input(2));
  } else if (batch_node.op() == kParallelBatchDatasetOpName) {
    // ParallelBatchDataset has an extra input before drop_remainder.
    drop_remainder = graph.GetNode(batch_node.input(3));
  } else if (batch_node.op() == kMapAndBatchDatasetOpName) {
    // MapAndBatchDataset packs the map function's captured arguments between
    // the input dataset and the batch arguments, so the drop_remainder
    // position depends on how many there are.
    // NOTE(review): the count is taken from `Targuments`' shapes list —
    // presumably one shape per captured argument; verify against the op def.
    int drop_remainder_index =
        3 + batch_node.attr().at("Targuments").list().shape_size();
    if (drop_remainder_index >= batch_node.input_size()) {
      LOG(ERROR) << "Fail to find the drop_remainder of op: "
                 << batch_node.DebugString();
      return DropRemainderValue::kUnknown;
    }
    drop_remainder = graph.GetNode(batch_node.input(drop_remainder_index));
  } else {
    LOG(ERROR) << "Expect a batch node but get " << batch_node.DebugString();
    return DropRemainderValue::kUnknown;
  }
  // Only a constant drop_remainder can be evaluated at rewrite time.
  if (!IsConstant(*drop_remainder)) {
    return DropRemainderValue::kUnknown;
  }
  bool drop_remainder_value;
  if (!GetNodeAttr(*drop_remainder, "value", &drop_remainder_value).ok()) {
    return DropRemainderValue::kUnknown;
  }
  return drop_remainder_value ? DropRemainderValue::kTrue
                              : DropRemainderValue::kFalse;
}
541 
// Walks the dataset graph from `node` toward its sources, looking for a
// place to shard by file. Reader datasets (including readers instantiated
// inside dataset functions) get a ShardDataset inserted above them via
// ProcessDatasetSourceNode; pass-through datasets are traversed; multi-input
// datasets are walked per input; unshardable sources or non-dataset nodes
// produce a NotFound error. AssertCardinalityDataset nodes encountered on the
// way are removed.
Status RecursivelyHandleOp(const NodeDef& node, int64_t num_workers,
                           int64_t index, FunctionLibraryDefinition* flib,
                           MutableGraphView* graph,
                           absl::flat_hash_set<string>* nodes_to_delete) {
  if (node.op() == kAssertCardinalityDatasetOpName) {
    // Sharding changes the cardinality, so the assertion cannot be kept.
    LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
                    "handled by the auto-shard rewrite and will be removed.";
    nodes_to_delete->insert(node.name());
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                               nodes_to_delete);
  }

  if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
    return errors::NotFound("Found an unshardable source dataset: ",
                            node.DebugString());
  }

  if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
    // Each input of e.g. Zip/Concatenate must be sharded independently.
    for (int i = 0; i < node.input_size(); ++i) {
      const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
      TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
                                             flib, graph, nodes_to_delete));
    }
    return OkStatus();
  }

  // This handles the case for the following subgraph:
  //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
  //   (other preprocessing datasets) -> InterleaveDataset
  // and then inserting the shard node immediately after the FlatMapDataset.
  //
  // This is used for some training pipelines where a dataset is created with
  // the following code:
  //
  // def make_dataset_pipeline():
  //   file_globs = [...]
  //   datasets = []
  //   for file_glob in file_globs:
  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
  //   dataset = Dataset.from_tensor_slices(datasets)
  //   dataset = dataset.flat_map(lambda x: x)
  //   dataset = ...  # additional preprocessing
  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
  //   return dataset
  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
        input_node, num_workers, index, flib, graph, nodes_to_delete);

    if (flat_map_node != nullptr) {
      auto fanouts = graph->GetFanouts(*flat_map_node, false);
      // FlatMapDataset should only be the input to one other dataset.
      if (fanouts.size() == 1) {
        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
                                        nodes_to_delete, num_workers, index);
      }
    }
  }

  // This handles the case where a reader Dataset is contained within a
  // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
  //
  // dataset = Dataset.list_files("/path/to/data")
  // dataset = dataset.flat_map(core_readers.TFRecordDataset)
  //
  // where the list of files is passed in one-by-one as an argument to the
  // function in flat_map.
  if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
      ReaderOpInFunction(node, *flib)) {
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
    // We reached a reader dataset directly and we try to shard input 0.
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
    return errors::NotFound(
        "Did not find a shardable source, walked to ",
        "a node which is not a dataset: ", node.DebugString(),
        ". Consider either turning off auto-sharding or switching the "
        "auto_shard_policy to DATA to shard this dataset. You can do this by "
        "creating a new `tf.data.Options()` object then setting "
        "`options.experimental_distribute.auto_shard_policy = "
        "AutoShardPolicy.DATA` before applying the options object to the "
        "dataset via `dataset.with_options(options)`.");
  }

  // Pass-through op: keep walking toward the source via input 0.
  const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
  return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                             nodes_to_delete);
}
639 
640 // Recursively walk the dataset graph from sink to source, searching for
641 // the first (i.e. closest to the sink) occurrence of a ReaderDataset, such as
642 // CSVDataset, TFRecordDataset, etc. We then insert a ShardDataset op before
643 // that nodes input, so that each worker only reads a subset of files.
644 // Additionally, we remove sources of randomness (e.g. ShuffleDataset) that
645 // occur upstream of the ShardDataset transformation to ensure that sharding
646 // returns a sensible result.
ShardByFile(const NodeDef & sink_node,int64_t num_workers,int64_t index,FunctionLibraryDefinition * flib,MutableGraphView * graph)647 Status ShardByFile(const NodeDef& sink_node, int64_t num_workers, int64_t index,
648                    FunctionLibraryDefinition* flib, MutableGraphView* graph) {
649   absl::flat_hash_set<string> nodes_to_delete;
650   TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, flib,
651                                          graph, &nodes_to_delete));
652   return graph->DeleteNodes(nodes_to_delete);
653 }
654 
// If the node directly feeding `sink_node` is a RebatchDatasetV2, rewrites
// it in place into a legacy RebatchDataset with `num_replicas` replicas and
// marks every known-rank output shape's leading (batch) dimension as
// unknown. A no-op when the sink's input is any other op. Returns
// InvalidArgument when `num_replicas` < 1 or the node has no
// `output_shapes` attr.
Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64_t num_replicas,
                            MutableGraphView* graph) {
  // The final node before AutoShardDataset is RebatchDataset.
  // This is always the case as RebatchDataset and AutoShardDataset are internal
  // APIs used directly by tf.distribute's input_lib. As such, instead of
  // walking the entire dataset graph, we can walk up directly from the
  // sink_node to get the RebatchDataset.
  NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
  if (input_node->op() != kRebatchDatasetV2OpName) {
    return OkStatus();
  }

  NodeDef* rebatch_node = input_node;
  // Update RebatchDatasetV2 in place. Since Rebatch is an internal API, no
  // other nodes should have it as an input.
  rebatch_node->set_op(kRebatchDatasetOpName);
  // Delete the `batch_sizes` and `drop_remainder` input.
  rebatch_node->mutable_input()->DeleteSubrange(/*start=*/1, /*num=*/2);
  // Add the `num_replicas` input.
  if (num_replicas < 1) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 to legacy RebatchDataset with invalid "
        "num_replicas argument. `num_replicas` is ",
        num_replicas, ", but expected to be >= 1.");
  }
  auto num_replicas_node = graph_utils::AddScalarConstNode(num_replicas, graph);
  rebatch_node->add_input(num_replicas_node->name());

  // Set `use_fallback` attr. This attr is not used anywhere, so its value
  // does not matter
  (*rebatch_node->mutable_attr())["use_fallback"].set_b(true);

  // Update the output_shapes attr to set all its batch dimensions to -1
  // (unknown).
  auto* shapes_attr =
      gtl::FindOrNull(*rebatch_node->mutable_attr(), "output_shapes");
  if (shapes_attr == nullptr) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 with missing `output_shapes` attr.");
  }
  for (int i = 0; i < shapes_attr->list().shape_size(); ++i) {
    auto* shape = shapes_attr->mutable_list()->mutable_shape(i);
    // Fully-unknown shapes need no change; otherwise the leading dimension is
    // the batch dimension, which rebatching makes unknown.
    if (shape->unknown_rank()) continue;
    shape->mutable_dim(0)->set_size(-1);
  }

  return OkStatus();
}
703 
ShardByData(const NodeDef & sink_node,int64_t num_workers,int64_t index,int64_t num_replicas,MutableGraphView * graph)704 Status ShardByData(const NodeDef& sink_node, int64_t num_workers, int64_t index,
705                    int64_t num_replicas, MutableGraphView* graph) {
706   const NodeDef* shard_before = &sink_node;
707   // We sometimes insert a PrefetchDataset, OptionsDataset, and FinalizeDataset
708   // at the end of the input pipeline before autosharding. When sharding by
709   // data, we should insert the shard before the these datasets so that the
710   // right number of elements is prefetched.
711   NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
712   while (input_node->op() == kPrefetchDatasetOpName ||
713          input_node->op() == kOptionsDatasetOpName ||
714          input_node->op() == kFinalizeDatasetOpName) {
715     shard_before = input_node;
716     input_node = graph_utils::GetInputNode(*input_node, *graph);
717   }
718   // Sharding by data only works with legacy RebatchDataset. As such, we rewrite
719   // all instances of RebatchDatasetV2 to RebatchDataset.
720   TF_RETURN_IF_ERROR(RewriteRebatchV2ToV1(*shard_before, num_replicas, graph));
721   return AddShardNode(graph, *shard_before, num_workers, index);
722 }
723 
724 // Searches the dataset graph replacing any occurrence of `shard(1, 0)` with
725 // `shard(num_workers, index)`.
ShardByHint(const NodeDef & sink_node,int64_t num_workers,int64_t index,int64_t num_replicas,MutableGraphView * graph)726 Status ShardByHint(const NodeDef& sink_node, int64_t num_workers, int64_t index,
727                    int64_t num_replicas, MutableGraphView* graph) {
728   auto get_shard_node = [graph](const NodeDef& node) -> const NodeDef* {
729     if (node.op() != kShardDatasetOpName) return nullptr;
730     auto num_workers_node = graph->GetNode(node.input(1));
731     if (num_workers_node->op() != kConstOpName) return nullptr;
732     if (num_workers_node->attr().at("value").tensor().int64_val(0) !=
733         tensorflow::data::kShardHint)
734       return nullptr;
735     return &node;
736   };
737 
738   auto* num_workers_node =
739       graph_utils::AddScalarConstNode(static_cast<int64_t>(num_workers), graph);
740   auto* worker_index_node =
741       graph_utils::AddScalarConstNode(static_cast<int64_t>(index), graph);
742 
743   for (const NodeDef& node : graph->graph()->node()) {
744     const NodeDef* shard_node = get_shard_node(node);
745     if (!shard_node) continue;
746     auto mutable_node = graph->GetNode(shard_node->name());
747     *mutable_node->mutable_input(1) = num_workers_node->name();
748     *mutable_node->mutable_input(2) = worker_index_node->name();
749     // Ensure that each shard will have at least one element.
750     (*(mutable_node->mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty]
751         .set_b(true);
752   }
753   return OkStatus();
754 }
755 
ApplyAutoShard(const NodeDef & sink_node,int64_t num_workers,int64_t index,AutoShardPolicy policy,int64_t num_replicas,MutableGraphView * graph,AutoShardPolicy * policy_applied)756 Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers,
757                       int64_t index, AutoShardPolicy policy,
758                       int64_t num_replicas, MutableGraphView* graph,
759                       AutoShardPolicy* policy_applied) {
760   *policy_applied = policy;
761   FunctionLibraryDefinition flib(OpRegistry::Global(),
762                                  graph->graph()->library());
763   switch (policy) {
764     case AutoShardPolicy::OFF:
765       return OkStatus();
766     case AutoShardPolicy::FILE:
767       return ShardByFile(sink_node, num_workers, index, &flib, graph);
768     case AutoShardPolicy::DATA:
769       return ShardByData(sink_node, num_workers, index, num_replicas, graph);
770     case AutoShardPolicy::HINT:
771       return ShardByHint(sink_node, num_workers, index, num_replicas, graph);
772     case AutoShardPolicy::AUTO:
773     default:
774       Status s = ShardByFile(sink_node, num_workers, index, &flib, graph);
775       if (errors::IsNotFound(s)) {
776         LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
777                         "as it failed to apply FILE sharding policy because of "
778                         "the following reason: "
779                      << s.error_message();
780         *policy_applied = AutoShardPolicy::DATA;
781         return ShardByData(sink_node, num_workers, index, num_replicas, graph);
782       }
783       *policy_applied = AutoShardPolicy::FILE;
784       return s;
785   }
786 }
787 
OptimizeGraph(const GrapplerItem & item,int64_t num_workers,int64_t index,AutoShardPolicy policy,int64_t num_replicas,GraphDef * output)788 Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
789                      int64_t index, AutoShardPolicy policy,
790                      int64_t num_replicas, GraphDef* output) {
791   *output = item.graph;
792   MutableGraphView graph(output);
793   NodeDef* sink_node;
794   TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
795 
796   // id for telemetry purpose. item.id is always the same so we use the address
797   // of the output as id.
798   string id = strings::StrCat(reinterpret_cast<uint64>(output));
799   // Only record metrics on the first shard to avoid duplication.
800   if (index == 0) {
801     std::vector<std::string> ineligible_reason;
802     bool is_eligible = internal::IsEligibleRewriteBatchSize(*sink_node, graph,
803                                                             &ineligible_reason);
804     metrics::RecordTFDataAutoShardRewriteBatchSize(is_eligible,
805                                                    ineligible_reason);
806   }
807 
808   AutoShardPolicy policy_applied = policy;
809   if (policy != AutoShardPolicy::OFF &&
810       !(policy == AutoShardPolicy::FILE && num_workers == 1 && index == 0)) {
811     TF_RETURN_IF_ERROR(ApplyAutoShard(*sink_node, num_workers, index, policy,
812                                       num_replicas, &graph, &policy_applied));
813   }
814   // Only record metrics on the first shard to avoid duplication.
815   if (index == 0) {
816     metrics::RecordTFDataAutoShard(id, policy_applied, num_workers,
817                                    num_replicas);
818   }
819   return OkStatus();
820 }
821 
822 }  // anonymous namespace
823 
824 namespace internal {
// Determines whether the pipeline ending at `sink_node` is eligible for the
// batch-size rewrite by walking from the sink toward the source until a batch
// dataset is found. Fills `ineligible_reason` with reason strings (empty iff
// eligible); callers use it for telemetry. The traversal keeps going past
// unsupported ops so that *all* reasons are collected before returning.
bool IsEligibleRewriteBatchSize(const NodeDef& sink_node,
                                const MutableGraphView& graph,
                                std::vector<std::string>* ineligible_reason) {
  ineligible_reason->clear();
  NodeDef* input_node = graph_utils::GetInputNode(sink_node, graph);
  // We always traverse the graph until we arrive at a batch node to collect all
  // ineligible reasons;
  while (input_node != nullptr) {
    // 1. Skip RebatchDataset and the MapDataset immediately before it. That map
    // is added by tf.data Python code.
    if (input_node->op() == kRebatchDatasetOpName ||
        input_node->op() == kRebatchDatasetV2OpName) {
      input_node = graph_utils::GetInputNode(*input_node, graph);
      // A Rebatch without the expected preceding Map indicates a malformed
      // pipeline; report it as a bug and bail out.
      if (input_node == nullptr || input_node->op() != kMapDatasetOpName) {
        ineligible_reason->push_back("BUG_NO_MAP_BEFORE_REBATCH");
        return false;
      }
      input_node = graph_utils::GetInputNode(*input_node, graph);
      continue;
    }
    // 2. If the node is insensitive to the batch size of the input, we continue
    // looking at the input dataset of the node.
    if (IsDatasetNodeOfType(*input_node, kBatchSizeOrthogonalDatasetOps)) {
      input_node = graph_utils::GetInputNode(*input_node, graph);
      continue;
    }
    // 3. We arrive at a batch node. Examine its drop_remainder input and
    // cardinality to determine eligibility.
    if (IsDatasetNodeOfType(*input_node, kBatchDatasetOps)) {
      DropRemainderValue drop_remainder = GetDropRemainder(graph, *input_node);
      int64_t cardinality = data::kUnknownCardinality;
      bool cardinality_available = true;
      AttrSlice attrs(*input_node);
      // The cardinality attr is only present when it was recorded for this
      // rewrite; absence means it cannot be checked here.
      if (!TryGetNodeAttr(attrs, data::kCardinalityAttrForRewrite,
                          &cardinality)) {
        cardinality_available = false;
      }

      // Eligible when the batch keeps partial batches (drop_remainder false)
      // or the dataset is known to be infinite -- provided no unsupported ops
      // were encountered on the way down (ineligible_reason still empty).
      if (drop_remainder == DropRemainderValue::kFalse ||
          (cardinality_available &&
           cardinality == data::kInfiniteCardinality)) {
        return ineligible_reason->empty();
      } else {
        if (drop_remainder == DropRemainderValue::kUnknown) {
          ineligible_reason->push_back("BATCH_DROP_REMAINDER_UNKNOWN");
        }
        if (!cardinality_available) {
          ineligible_reason->push_back("BATCH_CARDINALITY_NOT_AVAILABLE");
        }
        if (drop_remainder == DropRemainderValue::kTrue &&
            cardinality_available &&
            cardinality != data::kInfiniteCardinality) {
          ineligible_reason->push_back("BATCH_DROP_REMAINDER_NOT_INFINITE");
        }
        return false;
      }
    }
    // 4. We encountered other nodes before arriving at a batch node. We don't
    // know whether this node is sensitive to the batch size or not and we err
    // on the safe side.
    ineligible_reason->push_back(
        strings::StrCat("OP_NOT_SUPPORTED_", input_node->op()));
    input_node = graph_utils::GetInputNode(*input_node, graph);
  }
  // If we don't find a batch node, only records BATCH_NOT_FOUND as the reason.
  ineligible_reason->clear();
  ineligible_reason->push_back("BATCH_NOT_FOUND");
  return false;
}
894 }  // namespace internal
895 
Init(const tensorflow::RewriterConfig_CustomGraphOptimizer * config)896 Status AutoShard::Init(
897     const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
898   if (!config) return errors::InvalidArgument("RewriterConfig not found.");
899 
900   if ((config->parameter_map().find(kNumWorkersAttrName) ==
901        config->parameter_map().end())) {
902     return errors::InvalidArgument(kNumWorkersAttrName, " parameter missing.");
903   }
904 
905   if ((config->parameter_map().find(kIndexAttrName) ==
906        config->parameter_map().end())) {
907     return errors::InvalidArgument(kIndexAttrName, " parameter missing.");
908   }
909 
910   num_workers_ = config->parameter_map().at(kNumWorkersAttrName).i();
911   index_ = config->parameter_map().at(kIndexAttrName).i();
912   auto_shard_policy_ =
913       AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
914   num_replicas_ = config->parameter_map().at(kNumReplicasAttrName).i();
915 
916   if (auto_shard_policy_ != AutoShardPolicy::OFF &&
917       auto_shard_policy_ != AutoShardPolicy::AUTO &&
918       auto_shard_policy_ != AutoShardPolicy::DATA &&
919       auto_shard_policy_ != AutoShardPolicy::FILE &&
920       auto_shard_policy_ != AutoShardPolicy::HINT) {
921     return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
922   }
923 
924   if (num_workers_ < 1) {
925     return errors::InvalidArgument(kNumWorkersAttrName,
926                                    " should be >= 1, currently ", num_workers_);
927   }
928 
929   if (index_ < 0 || index_ >= num_workers_) {
930     return errors::InvalidArgument(kIndexAttrName, " should be >= 0 and < ",
931                                    num_workers_, ", currently ", index_);
932   }
933 
934   if (num_replicas_ < 0) {
935     return errors::InvalidArgument(kNumReplicasAttrName, " should be >= 0");
936   }
937 
938   return OkStatus();
939 }
940 
OptimizeAndCollectStats(Cluster * cluster,const GrapplerItem & item,GraphDef * output,OptimizationStats * stats)941 Status AutoShard::OptimizeAndCollectStats(Cluster* cluster,
942                                           const GrapplerItem& item,
943                                           GraphDef* output,
944                                           OptimizationStats* stats) {
945   *output = item.graph;
946   TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_,
947                                    auto_shard_policy_, num_replicas_, output));
948   stats->num_changes++;
949   return OkStatus();
950 }
951 
952 REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");
953 
954 }  // namespace grappler
955 }  // namespace tensorflow
956