/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string, int64_t>* get_dataset_experiments() {
  static absl::flat_hash_map<string, int64_t>* experiments =
      new absl::flat_hash_map<string, int64_t>;
  return experiments;
}

// Use "Opt" suffix so that they are not confused with the enums in Options
// proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";

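// Given the pipeline `options`, partitions the tf.data graph-rewrite
// optimizations into three sets: rewrites the user explicitly enabled,
// rewrites the user explicitly disabled, and rewrites that apply by default
// because the user expressed no preference.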
void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
    if (optimization_options.optional_parallel_batch_case() !=
        OptimizationOptions::kParallelBatch) {
      optimization_default->insert(kParallelBatchOpt);
    }
  }
  if (OpDeterminismRequired()) {
    optimization_enabled->insert(kMakeDeterministicOpt);
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_filter_parallelization_case() ==
      OptimizationOptions::kFilterParallelization) {
    if (optimization_options.filter_parallelization()) {
      optimization_enabled->insert(kFilterParallelizationOpt);
    } else {
      optimization_disabled->insert(kFilterParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_inject_prefetch_case() ==
      OptimizationOptions::kInjectPrefetch) {
    if (optimization_options.inject_prefetch()) {
      optimization_enabled->insert(kInjectPrefetchOpt);
    } else {
      optimization_disabled->insert(kInjectPrefetchOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

}  // namespace

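// If both seeds are zero (the "unspecified" sentinel), replaces them with two
// freshly generated random seeds; otherwise returns the pair unchanged.
// Illustrative use: MaybeOverrideSeeds({0, 0}) yields fresh random seeds,
// while MaybeOverrideSeeds({7, 11}) returns {7, 11} as-is.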
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}

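// The `VerifyTypeMatch`/`VerifyTypesMatch` and `VerifyShapeCompatible`/
// `VerifyShapesCompatible` helpers below return InvalidArgument, naming the
// offending component index, when produced dtypes or shapes do not match what
// the dataset declared.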
Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return OkStatus();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return OkStatus();
}

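// The two `AddToFunctionLibrary` overloads below merge `to_add` into `base`.
// A function that already exists in `base` with an identical signature is
// replaced; a name collision with a different signature is an error.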
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

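// `IsFunctionStateful` and `IsNodeStateful` return OkStatus() if the function
// or node is considered stateless for tf.data purposes, and FailedPrecondition
// naming the offending op otherwise. Allowlisted ops, `Assert`, and the
// functional `If`/`While` ops (whose branch/body functions are checked
// recursively) are treated as stateless.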
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return OkStatus();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return OkStatus();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return OkStatus();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return OkStatus();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return OkStatus();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

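// Wraps `runner` so that every closure it executes first installs a
// `ScopedPerThreadMaxParallelism` guard, capping intra-op parallelism at
// `max_parallelism` for work scheduled through the returned runner.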
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return OkStatus();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}

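// Returns true if `op_to_match` is `op_prefix` itself or a versioned variant
// of it, i.e. `op_prefix` followed by "V<digits>". For example,
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") is true, while
// MatchesAnyVersion("BatchDataset", "BatchDatasetFoo") is false.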
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}

absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(port::JobName(),
                        [](const tstring& str) { return Hash64(str); });
}

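// Computes the set of tf.data experiments to apply for `job_name`. Selection
// combines the TF_DATA_EXPERIMENT_OPT_IN / TF_DATA_EXPERIMENT_OPT_OUT
// environment variables with per-experiment rollout percentages: an
// experiment is stochastically enabled when
// hash(job_name + experiment) % 100 falls below its rollout percentage.
// Opt-outs always take precedence over opt-ins and rollouts.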
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;
  if (job_name.empty()) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  absl::flat_hash_map<string, int64_t> live_experiments =
      DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  if (opt_outs_raw == "all_except_opt_in") {
    return experiments;
  }
  // Stochastically include live experiments unless they are opted out.
  for (const auto& pair : live_experiments) {
    auto& experiment = pair.first;
    if ((hash_func(strings::StrCat(job_name, experiment)) % 100 <
         pair.second) &&
        !opt_outs.contains(experiment)) {
      experiments.insert(experiment);
    }
  }

  return experiments;
}

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    constexpr float TEN_MINUTES = 60.0 * 10.0;
    LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
        << "The input pipeline is subject to the following tf.data "
        << "experiments: " << absl::StrJoin(experiments, ", ") << ". "
        << "See `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    metrics::RecordTFDataExperiment(experiment);
  }
}

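// Populates the enabled/disabled/default optimization sets from `options`,
// then applies two cross-cutting adjustments: enable "make_sloppy" when the
// user explicitly requested nondeterminism (and op determinism is not
// required), and record the user's explicit "slack" preference.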
void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (!OpDeterminismRequired() &&
      options.optional_deterministic_case() == Options::kDeterministic &&
      !options.deterministic()) {
    optimizations_enabled->insert(kMakeSloppyOpt);
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

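// Returns the `index`-th slice of `tensor` along dimension 0. `SubSlice`
// returns a view that shares the parent's buffer and may be unaligned; in
// that case the slice is deep-copied so callers always get an aligned tensor.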
Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) {
  Tensor slice = tensor.SubSlice(index);
  if (slice.IsAligned()) {
    return slice;
  } else {
    return tensorflow::tensor::DeepCopy(slice);
  }
}

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

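// Copies the first `num_elements` rows of `value` into the preallocated
// `output` tensor, one slice at a time, dispatching on the element dtype.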
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return OkStatus();                                            \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return OkStatus();
}

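// Restores a batch of tensors previously saved with `WriteBatch`. If a saved
// component has fewer than `batch_size` rows (a partial batch), it is copied
// into a freshly allocated tensor whose leading dimension is `batch_size`, so
// downstream code can rely on a fixed batch dimension.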
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the
    // slice read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return OkStatus();
}

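// Serializes a batch of tensors into iterator checkpoint state. For a partial
// batch, only the first `num_elements` rows of each component are stored.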
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized* and accessing that will
    // raise msan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return OkStatus();
}

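// `ReadStatus` and `WriteStatus` round-trip a Status through iterator
// checkpoint state as an error-code scalar plus, for non-OK statuses, an
// error-message scalar.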
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = OkStatus();
  }
  return OkStatus();
}

Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64_t>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return OkStatus();
}

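// Finalizes a batch after its elements have been produced: propagates real
// errors, treats OutOfRange as end of input, drops or truncates a partial
// final batch depending on `drop_remainder`, and otherwise moves the batch
// into `output`.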
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return OkStatus();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return OkStatus();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return OkStatus();
}

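// Assembles the output batch in two passes: first allocate one batch tensor
// per tuple component (invoking `allocation_callback`, if provided, once all
// components are allocated), then copy each element's slice into place. When
// `parallel_copy` is set and a component is larger than 32KB, the copy is
// sharded across the runner threadpool, splitting the element index range as
// evenly as possible over `params.runner_threadpool_size` tasks.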
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(params.allocator, first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = params.runner_threadpool_size;
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*params.runner)([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return OkStatus();
}

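// Builds the per-rewrite config strings passed to the tf.data graph
// optimizer. Autotune-dependent rewrites are tagged "autotune:true" or
// "autotune:false" depending on whether autotuning is enabled, and when slack
// is requested the slack rewrite is configured with the number of devices as
// its slack period.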
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt,
      kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt,
      kEnableGradientDescentOpt,
      kFilterParallelizationOpt,
      kMapParallelizationOpt,
      kInjectPrefetchOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options().optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options().optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) {
  return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         int64_t rollout_pct) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct));
}

// static
absl::flat_hash_map<string, int64_t> DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

REGISTER_DATASET_EXPERIMENT("allow_small_function_optimizations", 0);
REGISTER_DATASET_EXPERIMENT("autotune_buffer_optimization", 0);
REGISTER_DATASET_EXPERIMENT(kFilterParallelizationOpt, 0);
REGISTER_DATASET_EXPERIMENT("inject_prefetch", 100);
REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism", 0);
REGISTER_DATASET_EXPERIMENT("reduce_interleave_prefetch", 0);
REGISTER_DATASET_EXPERIMENT("serialize_input_cycle_length", 0);
REGISTER_DATASET_EXPERIMENT("stage_based_autotune", 0);
}  // namespace
}  // namespace data
}  // namespace tensorflow