/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/make_deterministic.h"

#include <algorithm>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/split_utils.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kInterleaveOp[] = "InterleaveDataset";
constexpr char kParallelInterleaveOp[] = "ParallelInterleaveDataset";
constexpr char kLegacyParallelInterleaveOp[] =
    "LegacyParallelInterleaveDatasetV2";
constexpr char kMapOp[] = "MapDataset";
constexpr char kParallelMapOp[] = "ParallelMapDataset";
constexpr char kParallelMapOpV2[] = "ParallelMapDatasetV2";
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kBatchOp[] = "BatchDataset";
constexpr char kBatchV2Op[] = "BatchDatasetV2";
constexpr char kParallelBatchOp[] = "ParallelBatchDataset";
constexpr char kPrefetchOp[] = "PrefetchDataset";

// List of stateful ops which do not introduce nondeterminism when run as part
// of a Dataset function, e.g. within an InterleaveDataset's function. These
// are stateful dataset ops which do not read or modify TensorFlow state.
// Stateful ops not in this list can introduce nondeterminism, either due to
// the fact they are run in parallel (e.g. in a MapDataset with
// num_parallel_calls > 1) or because they can run asynchronously (e.g. a
// PrefetchDataset can cause ops in a MapDataset to run at the same time as ops
// outside a dataset).
//
// Ops in this list are allowed to read from files, as we do not make any
// guarantees on determinism if files are modified while a dataset is running.
// TODO(reedwm): Expand this list.
constexpr std::array<const char*, 9> kDeterministicStatefulOps = {
    "TextLineDataset", "FixedLengthRecordDataset", "TFRecordDataset",
    "TensorSliceDataset", "RangeDataset", "SSTableDataset", "RecordIODataset",
    // Because Print and Assert are on this list, the order of Print and Assert
    // ops may not be deterministic. This is acceptable, as it doesn't affect
    // model outputs or weights or other numeric values.
    "Print", "Assert"};
// List of stateful ops which do not introduce nondeterminism when run
// asynchronously as part of a Dataset function, but may introduce
// nondeterminism when run in parallel. All legacy random ops can be put on
// this list, since the state is internal to the op itself, and so there is no
// risk of ops outside the dataset reading or modifying the state.
constexpr std::array<const char*, 13> kDeterministicStatefulOpsWhenAsync = {
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomCrop",
    "SampleDistortedBoundingBox",
    "SampleDistortedBoundingBoxV2"};
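// For example, RandomUniform keeps its RNG state inside the op, so running it
// asynchronously (e.g. under a PrefetchDataset) cannot interleave with state
// used elsewhere in the graph, but running several copies of the same function
// in parallel can still consume that internal state in a nondeterministic
// order. This is why such ops appear only on this second list and not on
// kDeterministicStatefulOps.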

bool IsDeterministicWhenRunInParallel(const std::string& stateful_op) {
  for (auto op_in_array : kDeterministicStatefulOps) {
    if (data::MatchesAnyVersion(op_in_array, stateful_op)) {
      return true;
    }
  }
  return false;
}

bool IsDeterministicWhenRunAsynchronously(const std::string& stateful_op) {
  for (auto op_in_array : kDeterministicStatefulOps) {
    if (data::MatchesAnyVersion(op_in_array, stateful_op)) {
      return true;
    }
  }
  for (auto op_in_array : kDeterministicStatefulOpsWhenAsync) {
    if (data::MatchesAnyVersion(op_in_array, stateful_op)) {
      return true;
    }
  }
  return false;
}

bool IsParallelInterleave(const std::string& op) {
  return data::MatchesAnyVersion(kParallelInterleaveOp, op) ||
         op == kLegacyParallelInterleaveOp;
}

bool IsParallelMap(const std::string& op) {
  return data::MatchesAnyVersion(kParallelMapOp, op);
}

bool IsParallelBatch(const std::string& op) {
  return data::MatchesAnyVersion(kParallelBatchOp, op);
}

bool IsMapAndBatch(const std::string& op) {
  return data::MatchesAnyVersion(kMapAndBatchOp, op);
}

bool IsPrefetch(const std::string& op) {
  return data::MatchesAnyVersion(kPrefetchOp, op);
}

// Returns whether the op is a dataset op which runs a function multiple times
// in parallel.
bool IntroducesFunctionParallelism(const std::string& op) {
  return IsParallelInterleave(op) || IsParallelMap(op) || IsMapAndBatch(op);
}

// Returns whether the op is a dataset op which can cause functions in the
// input pipeline to run asynchronously.
bool IntroducesAsynchrony(const std::string& op) {
  // Currently, every op that introduces parallelism also introduces
  // asynchrony.
  return IntroducesFunctionParallelism(op) || IsPrefetch(op) ||
         IsParallelBatch(op);
}

// Returns map from node name to NodeDef in a function.
absl::flat_hash_map<absl::string_view, const NodeDef*> NameToNode(
    const FunctionDef& function) {
  absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node;
  for (const NodeDef& node : function.node_def()) {
    name_to_node.insert({node.name(), &node});
  }
  return name_to_node;
}

NodeDef* GetMutableNode(const string& node_name, MutableGraphView* graph) {
  int index = graph_utils::FindGraphNodeWithName(node_name, *graph->graph());
  DCHECK_NE(index, -1) << "Failed to find node " << node_name
                       << " in the optimized graph.";
  return graph->graph()->mutable_node(index);
}

// Converts a ParallelInterleaveDataset or ParallelMapDataset to the equivalent
// non-parallel version, to make it deterministic.
Status ConvertMapOrInterleave(const string& node_name,
                              MutableGraphView* graph) {
  NodeDef* node = GetMutableNode(node_name, graph);

  auto Targuments = node->attr().find("Targuments");
  if (Targuments == node->attr().end()) {
    return errors::Internal("Failed to find Targuments attribute for node ",
                            node_name);
  }

  int num_inputs_after_rewrite;
  if (IsParallelInterleave(node->op())) {
    node->set_op(kInterleaveOp);
    num_inputs_after_rewrite = 3 + Targuments->second.list().type_size();
  } else {
    DCHECK(IsParallelMap(node->op()));
    node->set_op(kMapOp);
    num_inputs_after_rewrite = 1 + Targuments->second.list().type_size();
  }

  // ParallelInterleave and ParallelMap ops take in more inputs than the
  // corresponding non-parallel versions, so turn extra inputs into control
  // inputs. These extra inputs are for performance and are safe to ignore.
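  // For instance, ParallelMapDatasetV2 takes num_parallel_calls as its final
  // input, so a node with inputs {ds, captured_arg, num_parallel_calls}
  // becomes a MapDataset node with inputs {ds, captured_arg,
  // ^num_parallel_calls}.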
  int inputs_processed = 0;
  for (int i = 0; i < node->input_size(); i++) {
    std::string input = node->input(i);
    if (IsControlInput(input)) {
      continue;
    }
    if (inputs_processed >= num_inputs_after_rewrite) {
      node->set_input(i, absl::StrCat("^", input));
    }
    inputs_processed++;
  }
  if (inputs_processed < num_inputs_after_rewrite) {
    return errors::Internal("Found only ", inputs_processed, " inputs to node ",
                            node_name, ", but expected to find at least ",
                            num_inputs_after_rewrite);
  }

  // Remove extra attributes not in Interleave or Map.
  node->mutable_attr()->erase("deterministic");
  node->mutable_attr()->erase("sloppy");
  return OkStatus();
}

// Returns all transitive dependencies of a set of nodes, including the nodes
// themselves.
absl::flat_hash_set<absl::string_view> GetAllTransitiveDependencies(
    const FunctionDef& function_def,
    const absl::flat_hash_set<absl::string_view>& nodes) {
  std::vector<absl::string_view> nodes_to_process;
  std::copy(nodes.begin(), nodes.end(), std::back_inserter(nodes_to_process));

  absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node =
      NameToNode(function_def);
  absl::flat_hash_set<absl::string_view> dependencies;
  while (!nodes_to_process.empty()) {
    absl::string_view node_name = nodes_to_process.back();
    nodes_to_process.pop_back();
    if (dependencies.contains(node_name)) {
      continue;
    }
    dependencies.insert(node_name);
    auto iter = name_to_node.find(node_name);
    if (iter == name_to_node.end()) {
      // If the node doesn't exist, the function is malformed, so just ignore
      // the node for now.
      continue;
    }
    for (absl::string_view inp : iter->second->input()) {
      absl::string_view inp_node = inp.substr(0, inp.find(':'));
      if (inp_node.at(0) == '^') {
        inp_node = inp_node.substr(1);
      }
      // Input may be an argument instead of a node, so explicitly check if
      // name is in name_to_node.
      if (name_to_node.contains(inp_node)) {
        nodes_to_process.push_back(inp_node);
      }
    }
  }
  return dependencies;
}

// Makes a ParallelMapV2 op deterministic by splitting it into separate Map and
// ParallelMapV2 ops, or a MapAndBatch op deterministic by splitting it into
// separate Map and MapAndBatch ops. All the nondeterministic nodes and their
// dependencies are moved to the Map node.
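// For example, if the map function decodes an image and then applies
// RandomCrop, the nondeterministic RandomCrop node and the decode node it
// depends on end up in the sequential Map function, while any remaining
// deterministic nodes stay in the parallel op's function.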
Status SplitMap(
    const FunctionLibraryDefinition& library, const string& map_node_name,
    MutableGraphView* graph,
    const absl::flat_hash_set<absl::string_view>& nondeterministic_nodes) {
  NodeDef* map_node = GetMutableNode(map_node_name, graph);
  NameAttrList func = map_node->attr().at("f").func();
  const FunctionDef* function_def = library.Find(func.name());
  if (!function_def) {
    return errors::Internal("Could not look up function ", func.name(),
                            " in FunctionLibraryDefinition");
  }

  absl::flat_hash_set<absl::string_view> nodes_to_move =
      GetAllTransitiveDependencies(*function_def, nondeterministic_nodes);

  VLOG(2) << "Will move nodes to nonparallel function: "
          << absl::StrJoin(nodes_to_move, ", ");

  int64_t num_captured_arguments =
      map_node->attr().find("Targuments")->second.list().type_size();

  TF_ASSIGN_OR_RETURN(
      split_utils::SplitResults split_results,
      split_utils::SplitFunction(*function_def, nodes_to_move,
                                 num_captured_arguments, library));

  if (split_results.first_function_output_types.empty()) {
    // Map datasets require there to be at least one output.
    return errors::Unimplemented(
        "The case where the first function has no outputs is unimplemented.");
  }

  bool is_map_and_batch = map_node->op() == kMapAndBatchOp;

  NodeDef* first_map_node_ptr;
  {
    NodeDef first_map_node;
    graph_utils::SetUniqueGraphNodeName(
        strings::StrCat("make_deterministic_sequential_map/", map_node->name()),
        graph->graph(), &first_map_node);
    first_map_node.set_op(kMapOp);
    int num_control_deps = NumControlInputs(*map_node);
    // ParallelMap and MapAndBatch nodes have "num_extra_inputs" more inputs
    // than Map. All inputs are copied to the Map node, but the
    // "num_extra_inputs" inputs are converted to control dependencies.
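    // For ParallelMapDatasetV2 the extra input is num_parallel_calls; for
    // MapAndBatchDataset the extra inputs are batch_size, num_parallel_calls
    // and drop_remainder.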
    int num_extra_inputs = is_map_and_batch ? 3 : 1;
    int control_deps_index = map_node->input_size() - num_control_deps;
    int extra_inputs_index = control_deps_index - num_extra_inputs;
    for (int i = 0; i < extra_inputs_index; i++) {
      // Copy inputs that are also inputs to Map
      DCHECK(!IsControlInput(map_node->input(i)));
      first_map_node.add_input(map_node->input(i));
    }
    for (int i = extra_inputs_index; i < control_deps_index; i++) {
      // Copy the extra inputs, converting them to control dependencies
      DCHECK(!IsControlInput(map_node->input(i)));
      first_map_node.add_input(absl::StrCat("^", map_node->input(i)));
    }
    for (int i = control_deps_index; i < map_node->input_size(); i++) {
      // Copy the control dependencies
      DCHECK(IsControlInput(map_node->input(i)));
      first_map_node.add_input(map_node->input(i));
    }

    NameAttrList* name_attr_list =
        (*first_map_node.mutable_attr())["f"].mutable_func();
    // TODO(reedwm): Set attrs?
    name_attr_list->set_name(split_results.first_function.signature().name());

    graph_utils::CopyAttribute("Targuments", *map_node, &first_map_node);
    for (auto key : {"use_inter_op_parallelism", "preserve_cardinality"}) {
      if (gtl::FindOrNull(map_node->attr(), key)) {
        graph_utils::CopyAttribute(key, *map_node, &first_map_node);
      }
    }
    AddNodeAttr("output_types", split_results.first_function_output_types,
                &first_map_node);
    TensorShapeProto unknown_shape;
    unknown_shape.set_unknown_rank(true);
    std::vector<TensorShapeProto> output_shapes(
        split_results.first_function_output_types.size(), unknown_shape);
    AddNodeAttr("output_shapes", output_shapes, &first_map_node);
    first_map_node_ptr = graph->AddNode(std::move(first_map_node));
  }

  NodeDef* second_map_node_ptr;
  {
    NodeDef second_map_node;
    string node_name =
        map_node->op() == kMapAndBatchOp ? "map_and_batch" : "parallel_map";
    graph_utils::SetUniqueGraphNodeName(
        strings::StrCat("make_deterministic_parallel_", node_name, "/",
                        map_node->name()),
        graph->graph(), &second_map_node);
    second_map_node.set_op(map_node->op());
    second_map_node.add_input(first_map_node_ptr->name());
    for (int i = 1; i < map_node->input_size(); i++) {
      second_map_node.add_input(map_node->input(i));
    }
    NameAttrList* name_attr_list =
        (*second_map_node.mutable_attr())["f"].mutable_func();
    // TODO(reedwm): Set attrs?
    name_attr_list->set_name(split_results.second_function.signature().name());
    graph_utils::CopyAttribute("Targuments", *map_node, &second_map_node);
    graph_utils::CopyAttribute("output_types", *map_node, &second_map_node);
    graph_utils::CopyAttribute("output_shapes", *map_node, &second_map_node);
    if (!is_map_and_batch) {
      AddNodeAttr("deterministic", "true", &second_map_node);
    }
    for (auto key : {"use_inter_op_parallelism", "preserve_cardinality"}) {
      if (gtl::FindOrNull(map_node->attr(), key)) {
        graph_utils::CopyAttribute(key, *map_node, &second_map_node);
      }
    }
    second_map_node_ptr = graph->AddNode(std::move(second_map_node));
  }

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(map_node->name(), second_map_node_ptr->name()));
  *graph->graph()->mutable_library()->mutable_function()->Add() =
      split_results.first_function;
  *graph->graph()->mutable_library()->mutable_function()->Add() =
      split_results.second_function;
  return OkStatus();
}

// Converts a ParallelBatch dataset to a Batch dataset, to make it
// deterministic.
Status ConvertBatch(const string& node_name, MutableGraphView* graph) {
  NodeDef* node = GetMutableNode(node_name, graph);
  node->set_op(kBatchV2Op);
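  // ParallelBatchDataset takes (input_dataset, batch_size, num_parallel_calls,
  // drop_remainder) while BatchDatasetV2 takes (input_dataset, batch_size,
  // drop_remainder), so move drop_remainder into input slot 2 and keep
  // num_parallel_calls only as a control input so its producer stays
  // connected.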
  std::string num_parallel_calls_input = node->input(2);
  node->set_input(2, node->input(3));
  node->set_input(3, absl::StrCat("^", num_parallel_calls_input));
  node->mutable_attr()->erase("deterministic");
  return OkStatus();
}

// Convert a MapAndBatch node to a separate Map node and Batch node, to make it
// deterministic. Caller should delete the MapAndBatch node afterwards.
// TODO(reedwm): Handle 'metadata' attribute. Currently the Map node and Batch
// node will have an empty 'metadata' attribute.
Status ConvertMapAndBatch(const string& node_name, MutableGraphView* graph) {
  int index = graph_utils::FindGraphNodeWithName(node_name, *graph->graph());
  DCHECK_NE(index, -1) << "Failed to find node " << node_name
                       << " in the optimized graph.";
  const NodeDef& orig_node = graph->graph()->node(index);

  auto Targuments = orig_node.attr().find("Targuments");
  if (Targuments == orig_node.attr().end()) {
    return errors::Internal("Failed to find Targuments attribute for node ",
                            node_name);
  }

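  // MapAndBatchDataset inputs are (input_dataset, other_arguments...,
  // batch_size, num_parallel_calls, drop_remainder). The first
  // 1 + len(Targuments) inputs become the Map inputs; the remaining inputs are
  // kept as control inputs on the Map node, and batch_size/drop_remainder
  // become data inputs on the new Batch node below.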
  // Create map node
  NodeDef new_map_node;
  new_map_node.set_op(kMapOp);
  graph_utils::SetUniqueGraphNodeName(kMapOp, graph->graph(), &new_map_node);
  int num_map_inputs = 1 + Targuments->second.list().type_size();
  for (int i = 0; i < num_map_inputs; i++) {
    new_map_node.add_input(orig_node.input(i));
  }
  for (int i = num_map_inputs; i < orig_node.input_size(); i++) {
    if (IsControlInput(orig_node.input(i))) {
      new_map_node.add_input(orig_node.input(i));
    } else {
      new_map_node.add_input(absl::StrCat("^", orig_node.input(i)));
    }
  }
  for (auto key : {"f", "Targuments", "output_types"}) {
    graph_utils::CopyAttribute(key, orig_node, &new_map_node);
  }
  for (auto key : {"preserve_cardinality"}) {
    if (gtl::FindOrNull(new_map_node.attr(), key)) {
      graph_utils::CopyAttribute(key, orig_node, &new_map_node);
    }
  }
  auto orig_output_shapes = orig_node.attr().find("output_shapes");
  if (orig_output_shapes == orig_node.attr().end()) {
    return errors::Internal("Failed to find output_shapes attribute for node ",
                            node_name);
  }

  // Set "output_shapes" attr of Map to be "output_shapes" of MapAndBatch with
  // the leading dimension removed for each shape.
  AttrValue& map_output_shapes =
      (*new_map_node.mutable_attr())["output_shapes"];
  for (const TensorShapeProto& orig_shape :
       orig_output_shapes->second.list().shape()) {
    TensorShapeProto* new_shape = map_output_shapes.mutable_list()->add_shape();
    if (orig_shape.unknown_rank()) {
      new_shape->set_unknown_rank(true);
    } else if (orig_shape.dim_size() == 0) {
      return errors::Internal(
          "Output shape of MapAndBatch node cannot be scalar");
    } else {
      for (int i = 1; i < orig_shape.dim_size(); i++) {
        *new_shape->add_dim() = orig_shape.dim(i);
      }
    }
  }

  // Create batch node
  NodeDef new_batch_node;
  new_batch_node.set_op(kBatchV2Op);
  graph_utils::SetUniqueGraphNodeName(kBatchOp, graph->graph(),
                                      &new_batch_node);
  new_batch_node.add_input(new_map_node.name());
  new_batch_node.add_input(orig_node.input(num_map_inputs));  // batch_size
  new_batch_node.add_input(
      orig_node.input(num_map_inputs + 2));  // drop_remainder
  graph_utils::CopyShapesAndTypesAttrs(orig_node, &new_batch_node);

  graph->AddNode(std::move(new_map_node));
  NodeDef* graph_batch_node = graph->AddNode(std::move(new_batch_node));
  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(orig_node.name(), graph_batch_node->name()));
  return OkStatus();
}

// Change the buffer_size of a Prefetch node to zero, effectively disabling it,
// to make it deterministic.
Status ConvertPrefetch(const string& node_name, MutableGraphView* graph) {
  NodeDef* node = GetMutableNode(node_name, graph);
  constexpr int buffer_size_index = 1;
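  // Keep the original buffer_size tensor reachable as a control input, then
  // point the real buffer_size input at a new constant 0 so nothing is
  // prefetched.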
  node->add_input(absl::StrCat("^", node->input(buffer_size_index)));
  NodeDef* tmp = graph_utils::AddScalarConstNode<int64_t>(0, graph);
  node->set_input(buffer_size_index, tmp->name());
  return OkStatus();
}

// The two ways nondeterminism can occur in an input pipeline when there are
// stateful ops.
enum class NondeterminismType { PARALLELISM, ASYNCHRONY };

// Returns whether the stateful op is deterministic if run in parallel or
// asynchronously.
bool IsDeterministicStatefulOp(NondeterminismType type,
                               const std::string& stateful_op) {
  return type == NondeterminismType::PARALLELISM
             ? IsDeterministicWhenRunInParallel(stateful_op)
             : IsDeterministicWhenRunAsynchronously(stateful_op);
}

// Defined below. Mutually recursive with FunctionMayIntroduceNondeterminism.
bool FunctionNodeMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const NodeDef& node_def,
    NondeterminismType nondeterminism_type,
    absl::flat_hash_set<std::string>* functions_processed);

// Returns true if the function may introduce nondeterminism. Depending on
// 'nondeterminism_type', either checks if nondeterminism can occur when the
// function is run several times in parallel or when run asynchronously.
// Recursively checks any function attributes of ops within the function.
// "functions_processed" is the list of functions already processed, so that
// the same function is not recursively checked twice. If not null, nodes
// causing nondeterminism will be added to "nondeterministic_nodes".
bool FunctionMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const std::string& function_name,
    NondeterminismType nondeterminism_type,
    absl::flat_hash_set<std::string>* functions_processed,
    absl::flat_hash_set<absl::string_view>* nondeterministic_nodes) {
  if (functions_processed->contains(function_name)) {
    return false;
  }
  functions_processed->insert(function_name);
  const FunctionDef* function_def = library.Find(function_name);
  if (!function_def) {
    VLOG(2) << "Could not look up function " << function_name
            << " in FunctionLibraryDefinition, so rewriting op to be safe";
    return true;
  }
  bool found = false;
  for (const NodeDef& node_def : function_def->node_def()) {
    bool nondeterministic = FunctionNodeMayIntroduceNondeterminism(
        library, node_def, nondeterminism_type, functions_processed);
    if (nondeterministic) {
      if (nondeterministic_nodes) {
        nondeterministic_nodes->insert(node_def.name());
        found = true;
      } else {
        return true;
      }
    }
  }
  return found;
}

bool FunctionMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const std::string& function_name,
    NondeterminismType nondeterminism_type) {
  absl::flat_hash_set<string> functions_processed;
  return FunctionMayIntroduceNondeterminism(library, function_name,
                                            nondeterminism_type,
                                            &functions_processed, nullptr);
}

// Returns true if the given NodeDef inside a function may cause
// nondeterminism.
bool FunctionNodeMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const NodeDef& node_def,
    NondeterminismType nondeterminism_type,
    absl::flat_hash_set<std::string>* functions_processed) {
  const OpRegistrationData* op_reg_data = nullptr;
  Status s = library.LookUp(node_def.op(), &op_reg_data);
  if (!s.ok()) {
    VLOG(2) << "Could not look up op " << node_def.op()
            << " in FunctionLibraryDefinition, so rewriting op to be safe";
    return true;
  }
  bool is_function_op = op_reg_data->is_function_op;

  bool is_stateful = false;
  if (!is_function_op) {
    const OpDef* op_def;
    s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
    if (!s.ok()) {
      VLOG(2) << "Could not look up op " << node_def.op()
              << " in OpRegistry, so rewriting op to be safe";
      return true;
    }
    is_stateful = op_def->is_stateful();
  }

  // Rewrite nondeterministic stateful ops. Function ops and If/While ops are
  // skipped, since we instead look at the ops within the function(s).
  if (is_stateful && !IsStatefulPartitionedCall(node_def) &&
      !IsIf(node_def) && !IsWhile(node_def) &&
      !IsDeterministicStatefulOp(nondeterminism_type, node_def.op())) {
    VLOG(2) << "Will rewrite due to op: " << node_def.op();
    return true;
  }

  // Recursively check for nondeterminism in all function attributes.
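  // This covers single-function attrs such as "f" on dataset ops as well as
  // the branch and body functions of If/While, which were deliberately skipped
  // by the stateful-op check above.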
  std::vector<std::string> attr_func_names;
  for (const auto& attr : node_def.attr()) {
    if (attr.second.has_func()) {
      attr_func_names.push_back(attr.second.func().name());
    }
    for (const auto& name_attr_list : attr.second.list().func()) {
      attr_func_names.push_back(name_attr_list.name());
    }
  }
  if (is_function_op) {
    attr_func_names.push_back(node_def.op());
  }
  for (const std::string& inner_function_name : attr_func_names) {
    if (FunctionMayIntroduceNondeterminism(library, inner_function_name,
                                           nondeterminism_type,
                                           functions_processed, nullptr)) {
      return true;
    }
  }
  return false;
}

// Returns true if "node" is a dataset node whose function can introduce
// nondeterminism when run asynchronously.
bool NodeMayIntroduceNondeterminismWhenAsync(
    const FunctionLibraryDefinition& library, const NodeDef& node) {
  const OpDef* op_def;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (s.code() == error::NOT_FOUND) {
    return false;
  } else if (!s.ok()) {
    return true;
  }
  if (data::DatasetOpKernel::IsDatasetOp(*op_def)) {
    std::vector<std::string> attr_func_names;
    for (const auto& attr : node.attr()) {
      if (attr.second.has_func()) {
        attr_func_names.push_back(attr.second.func().name());
      }
      for (const auto& name_attr_list : attr.second.list().func()) {
        attr_func_names.push_back(name_attr_list.name());
      }
    }
    for (const std::string& inner_function_name : attr_func_names) {
      if (FunctionMayIntroduceNondeterminism(library, inner_function_name,
                                             NondeterminismType::ASYNCHRONY)) {
        return true;
      }
    }
  }
  return false;
}

// Returns true if the graph has any dataset node whose function can introduce
// nondeterminism when run asynchronously.
bool GraphMayHaveAsyncNondeterminism(const FunctionLibraryDefinition& library,
                                     const GraphDef& graph) {
  for (const NodeDef& node : graph.node()) {
    if (NodeMayIntroduceNondeterminismWhenAsync(library, node)) {
      return true;
    }
  }
  for (const string& function_name : library.ListFunctionNames()) {
    const FunctionDef* function_def = library.Find(function_name);
    CHECK(function_def);  // Crash Ok
    for (const NodeDef& node : function_def->node_def()) {
      if (NodeMayIntroduceNondeterminismWhenAsync(library, node)) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

Status MakeDeterministic::OptimizeAndCollectStats(Cluster* cluster,
                                                  const GrapplerItem& item,
                                                  GraphDef* output,
                                                  OptimizationStats* stats) {
  *output = item.graph;
  MutableGraphView graph(output);
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             item.graph.library());
  absl::flat_hash_set<string> nodes_to_delete;
  bool remove_async_nodes =
      GraphMayHaveAsyncNondeterminism(function_library, item.graph);

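  // For each node: force any "sloppy"/"deterministic" attrs to their
  // deterministic settings, then rewrite ops whose functions may introduce
  // nondeterminism, preferring to split a parallel map into sequential and
  // parallel halves and otherwise falling back to a fully sequential or
  // synchronous rewrite.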
  for (const NodeDef& node : item.graph.node()) {
    if (graph_utils::HasSloppyAttr(node.op())) {
      NodeDef* mutable_node = GetMutableNode(node.name(), &graph);
      (*mutable_node->mutable_attr())["sloppy"].set_b(false);
      stats->num_changes++;
    }
    if (graph_utils::HasDeterministicAttr(node.op())) {
      NodeDef* mutable_node = GetMutableNode(node.name(), &graph);
      (*mutable_node->mutable_attr())["deterministic"].set_s("true");
      stats->num_changes++;
    }

    bool rewrite_due_to_async =
        IntroducesAsynchrony(node.op()) && remove_async_nodes;
    absl::flat_hash_set<std::string> functions_processed;
    absl::flat_hash_set<absl::string_view> nondeterministic_nodes;
    bool rewrite_due_to_parallelism =
        IntroducesFunctionParallelism(node.op()) &&
        FunctionMayIntroduceNondeterminism(
            function_library, node.attr().at("f").func().name(),
            NondeterminismType::PARALLELISM, &functions_processed,
            &nondeterministic_nodes);
    if (!rewrite_due_to_async && !rewrite_due_to_parallelism) {
      continue;
    }

    VLOG(1) << "Rewriting node " << node.name() << " (" << node.op()
            << ") because it introduces nondeterminism through "
            << (rewrite_due_to_async ? "asynchrony" : "parallelism");

    bool maybe_can_split =
        !rewrite_due_to_async &&
        (node.op() == kParallelMapOpV2 || IsMapAndBatch(node.op()));
    if (maybe_can_split) {
      Status s = SplitMap(function_library, node.name(), &graph,
                          nondeterministic_nodes);
      if (s.ok()) {
        VLOG(1) << "Split node " << node.name() << " (" << node.op()
                << ") into two map nodes: a nonparallel version and a "
                   "parallel version.";
        nodes_to_delete.insert(node.name());
        continue;
      } else if (s.code() == error::UNIMPLEMENTED) {
        // If splitting the function is unimplemented, instead convert the node
        // to a nonparallel version below.
        VLOG(1) << "Could not move stateful ops to their own function, so will "
                   "convert node "
                << node.name()
                << " to a nonparallel version instead. Reason: " << s;
      } else {
        return s;
      }
    }

    if (IsPrefetch(node.op())) {
      TF_RETURN_IF_ERROR(ConvertPrefetch(node.name(), &graph));
    } else if (IsMapAndBatch(node.op())) {
      TF_RETURN_IF_ERROR(ConvertMapAndBatch(node.name(), &graph));
      nodes_to_delete.insert(node.name());
    } else if (IsParallelBatch(node.op())) {
      TF_RETURN_IF_ERROR(ConvertBatch(node.name(), &graph));
    } else {
      DCHECK(IsParallelInterleave(node.op()) || IsParallelMap(node.op()));
      TF_RETURN_IF_ERROR(ConvertMapOrInterleave(node.name(), &graph));
    }
    stats->num_changes++;
  }

  TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
  return OkStatus();
}

REGISTER_GRAPH_OPTIMIZER_AS(MakeDeterministic, "make_deterministic");

}  // namespace grappler
}  // namespace tensorflow