/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string, int64_t>* get_dataset_experiments() {
  static absl::flat_hash_map<string, int64_t>* experiments =
      new absl::flat_hash_map<string, int64_t>;
  return experiments;
}

// Use "Opt" suffix so that they are not confused with the enums in Options
// proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";

void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
    if (optimization_options.optional_parallel_batch_case() !=
        OptimizationOptions::kParallelBatch) {
      optimization_default->insert(kParallelBatchOpt);
    }
  }
  if (OpDeterminismRequired()) {
    optimization_enabled->insert(kMakeDeterministicOpt);
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_filter_parallelization_case() ==
      OptimizationOptions::kFilterParallelization) {
    if (optimization_options.filter_parallelization()) {
      optimization_enabled->insert(kFilterParallelizationOpt);
    } else {
      optimization_disabled->insert(kFilterParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_inject_prefetch_case() ==
      OptimizationOptions::kInjectPrefetch) {
    if (optimization_options.inject_prefetch()) {
      optimization_enabled->insert(kInjectPrefetchOpt);
    } else {
      optimization_disabled->insert(kInjectPrefetchOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

}  // namespace

std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}
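
// Editor's note: a minimal usage sketch (not part of the upstream file),
// showing how a zero-valued seed pair is replaced with fresh random seeds
// while explicitly chosen seeds pass through unchanged:
//
//   auto random_seeds = MaybeOverrideSeeds({0, 0});   // two fresh New64() values
//   auto fixed_seeds = MaybeOverrideSeeds({42, 7});   // returned as-is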

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return OkStatus();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return OkStatus();
}
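
// Editor's note: an illustrative sketch (not in the upstream file) of how the
// verification helpers above are typically combined to check a produced
// element against a dataset's declared signature; `dataset` and `element` are
// hypothetical locals:
//
//   TF_RETURN_IF_ERROR(VerifyTypesMatch(dataset->output_dtypes(), element));
//   TF_RETURN_IF_ERROR(
//       VerifyShapesCompatible(dataset->output_shapes(), element));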

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return OkStatus();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return OkStatus();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return OkStatus();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return OkStatus();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return OkStatus();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}
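
// Editor's note: illustrative only (not part of the upstream file). The
// wrapper returned above schedules work through the original runner but caps
// intra-op parallelism inside each scheduled closure; `ctx` is a hypothetical
// IteratorContext:
//
//   auto capped = RunnerWithMaxParallelism(*ctx->runner(),
//                                          /*max_parallelism=*/2);
//   capped([] { /* runs with ScopedPerThreadMaxParallelism(2) in effect */ });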

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return OkStatus();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}
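
// Editor's note: a small usage sketch (not part of the upstream file) for the
// DeterminismPolicy helpers above; parsing the string form and converting back
// is expected to round-trip:
//
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(
//       DeterminismPolicy::FromString(DeterminismPolicy::kDefault, &policy));
//   DCHECK_EQ(policy.String(), DeterminismPolicy::kDefault);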

bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}
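
// Editor's note: illustrative behavior of MatchesAnyVersion (not part of the
// upstream file); an op name matches its own prefix as well as "V<digits>"
// suffixed versions of it:
//
//   MatchesAnyVersion("BatchDataset", "BatchDataset");    // true
//   MatchesAnyVersion("BatchDataset", "BatchDatasetV2");  // true
//   MatchesAnyVersion("BatchDataset", "BatchDatasetXY");  // false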

absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(port::JobName(),
                        [](const tstring& str) { return Hash64(str); });
}

absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;
  if (job_name.empty()) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  absl::flat_hash_map<string, int64_t> live_experiments =
      DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  if (opt_outs_raw == "all_except_opt_in") {
    return experiments;
  }
  // Stochastically include live experiments unless they are opted out.
  for (const auto& pair : live_experiments) {
    auto& experiment = pair.first;
    if ((hash_func(strings::StrCat(job_name, experiment)) % 100 <
         pair.second) &&
        !opt_outs.contains(experiment)) {
      experiments.insert(experiment);
    }
  }

  return experiments;
}
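
// Editor's note: an illustrative sketch (not part of the upstream file) of how
// the opt-in/opt-out environment variables above interact. Assuming a
// non-empty job name, a process launched as below would see only
// "inject_prefetch" selected, regardless of rollout percentages; the script
// name is hypothetical:
//
//   TF_DATA_EXPERIMENT_OPT_IN=inject_prefetch \
//   TF_DATA_EXPERIMENT_OPT_OUT=all_except_opt_in \
//   python my_training_job.py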

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    constexpr float TEN_MINUTES = 60.0 * 10.0;
    LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
        << "The input pipeline is subject to the following tf.data experiments:"
        << " " << absl::StrJoin(experiments, ", ") << ". "
        << "See `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    metrics::RecordTFDataExperiment(experiment);
  }
}

void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (!OpDeterminismRequired() &&
      options.optional_deterministic_case() == Options::kDeterministic &&
      !options.deterministic()) {
    optimizations_enabled->insert(kMakeSloppyOpt);
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) {
  Tensor slice = tensor.SubSlice(index);
  if (slice.IsAligned()) {
    return slice;
  } else {
    return tensorflow::tensor::DeepCopy(slice);
  }
}

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return OkStatus();                                            \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return OkStatus();
}

Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the slice
    // read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return OkStatus();
}

Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized* and accessing that will
    // raise msan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return OkStatus();
}

Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = OkStatus();
  }
  return OkStatus();
}

Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64_t>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return OkStatus();
}
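
// Editor's note: illustrative only (not part of the upstream file). The
// Read*/Write* status helpers above key their checkpoint entries off the
// iterator prefix plus a caller-supplied prefix; a hypothetical prefix
// "batch" would produce scalar keys such as:
//
//   FullName(iterator_prefix, "batch_code")  // status code
//   FullName(iterator_prefix, "batch_msg")   // error message, only if !ok()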

Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return OkStatus();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return OkStatus();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return OkStatus();
}

Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(params.allocator, first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = params.runner_threadpool_size;
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*params.runner)([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return OkStatus();
}

absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt,
      kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt,
      kEnableGradientDescentOpt,
      kFilterParallelizationOpt,
      kMapParallelizationOpt,
      kInjectPrefetchOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}
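
// Editor's note: an illustrative sketch (not part of the upstream file) of the
// config strings produced above. Each entry has the form
// "<rewrite>:<option>:<value>"; with autotuning enabled and slack requested on
// 4 devices, the returned set would contain entries such as:
//
//   "map_parallelization:autotune:true"
//   "slack:slack_period:4"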

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options().optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options().optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) {
  return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         int64_t rollout_pct) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct));
}

// static
absl::flat_hash_map<string, int64_t> DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

REGISTER_DATASET_EXPERIMENT("allow_small_function_optimizations", 0);
REGISTER_DATASET_EXPERIMENT("autotune_buffer_optimization", 0);
REGISTER_DATASET_EXPERIMENT(kFilterParallelizationOpt, 0);
REGISTER_DATASET_EXPERIMENT("inject_prefetch", 100);
REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism", 0);
REGISTER_DATASET_EXPERIMENT("reduce_interleave_prefetch", 0);
REGISTER_DATASET_EXPERIMENT("serialize_input_cycle_length", 0);
REGISTER_DATASET_EXPERIMENT("stage_based_autotune", 0);
}  // namespace
}  // namespace data
}  // namespace tensorflow