1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/dataset_utils.h"
17
18 #include <functional>
19 #include <string>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/data/dataset_test_base.h"
23 #include "tensorflow/core/data/serialization_utils.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/node_def_builder.h"
27 #include "tensorflow/core/framework/op.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/framework/variant.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/platform/str_util.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/error_codes.pb.h"
34 #include "tensorflow/core/util/determinism_test_util.h"
35 #include "tensorflow/core/util/work_sharder.h"
36
37 namespace tensorflow {
38 namespace data {
39 namespace {
40
// MatchesAnyVersion(prefix, op) accepts `op` when it is exactly `prefix` or
// `prefix` followed by a version suffix (V2 and V3 are exercised here). It
// rejects other suffixes, a version marker inside the name, and extra leading
// text.
TEST(DatasetUtilsTest, MatchesAnyVersion) {
  constexpr char kPrefix[] = "BatchDataset";

  // Accepted: the bare name and its versioned variants.
  EXPECT_TRUE(MatchesAnyVersion(kPrefix, "BatchDataset"));
  EXPECT_TRUE(MatchesAnyVersion(kPrefix, "BatchDatasetV2"));
  EXPECT_TRUE(MatchesAnyVersion(kPrefix, "BatchDatasetV3"));

  // Rejected: unexpected suffix, version in the middle, different op.
  EXPECT_FALSE(MatchesAnyVersion(kPrefix, "BatchDatasetXV3"));
  EXPECT_FALSE(MatchesAnyVersion(kPrefix, "BatchV2Dataset"));
  EXPECT_FALSE(MatchesAnyVersion(kPrefix, "PaddedBatchDataset"));
}
49
// Merging functions into a base library should:
//   - replace a base function when an incoming function of the same name has
//     the same signature but a different body ("0"),
//   - leave a base function alone when the incoming one is identical ("1"),
//   - keep base-only functions ("2"),
//   - add functions that are new ("3").
// Both AddToFunctionLibrary overloads are exercised: one takes a
// FunctionDefLibrary proto, the other a FunctionLibraryDefinition.
TEST(DatasetUtilsTest, AddToFunctionLibrary) {
  // int64 -> int64 via a single Identity node.
  auto make_fn_a = [](const string& fn_name) {
    return FunctionDefHelper::Create(
        /*function_name=*/fn_name,
        /*in_def=*/{"arg: int64"},
        /*out_def=*/{"ret: int64"},
        /*attr_def=*/{},
        /*node_def=*/{{{"node"}, "Identity", {"arg"}, {{"T", DT_INT64}}}},
        /*ret_def=*/{{"ret", "node:output:0"}});
  };

  // Same signature as make_fn_a, but the body chains two Identity nodes.
  auto make_fn_b = [](const string& fn_name) {
    return FunctionDefHelper::Create(
        /*function_name=*/fn_name,
        /*in_def=*/{"arg: int64"},
        /*out_def=*/{"ret: int64"},
        /*attr_def=*/{},
        /*node_def=*/
        {{{"node"}, "Identity", {"arg"}, {{"T", DT_INT64}}},
         {{"node2"}, "Identity", {"node:output:0"}, {{"T", DT_INT64}}}},
        /*ret_def=*/{{"ret", "node2:output:0"}});
  };

  FunctionDefLibrary fdef_base;
  *fdef_base.add_function() = make_fn_a("0");
  *fdef_base.add_function() = make_fn_a("1");
  *fdef_base.add_function() = make_fn_a("2");

  FunctionDefLibrary fdef_to_add;
  *fdef_to_add.add_function() = make_fn_b("0");  // Override
  *fdef_to_add.add_function() = make_fn_a("1");  // Do nothing
  *fdef_to_add.add_function() = make_fn_b("3");  // Add new function

  // Overload taking a FunctionDefLibrary proto.
  FunctionLibraryDefinition flib_0(OpRegistry::Global(), fdef_base);
  TF_ASSERT_OK(AddToFunctionLibrary(&flib_0, fdef_to_add));

  // Overload taking a FunctionLibraryDefinition.
  FunctionLibraryDefinition flib_1(OpRegistry::Global(), fdef_base);
  FunctionLibraryDefinition flib_to_add(OpRegistry::Global(), fdef_to_add);
  TF_ASSERT_OK(AddToFunctionLibrary(&flib_1, flib_to_add));

  // Checks that `flib` holds a function `name` equal to `expected`. A missing
  // function is reported as a test failure instead of dereferencing the null
  // pointer returned by Find().
  auto expect_function = [](const FunctionLibraryDefinition& flib,
                            const string& name, const FunctionDef& expected) {
    const FunctionDef* actual = flib.Find(name);
    ASSERT_NE(actual, nullptr) << "function '" << name << "' not found";
    EXPECT_TRUE(FunctionDefsEqual(*actual, expected))
        << "function '" << name << "' differs from the expected definition";
  };

  // Iterate over pointers: a braced list of the libraries themselves would
  // copy-construct each FunctionLibraryDefinition into a temporary.
  for (const FunctionLibraryDefinition* flib : {&flib_0, &flib_1}) {
    expect_function(*flib, "0", make_fn_b("0"));  // Overridden.
    expect_function(*flib, "1", make_fn_a("1"));  // Unchanged.
    expect_function(*flib, "2", make_fn_a("2"));  // Base-only, kept.
    expect_function(*flib, "3", make_fn_b("3"));  // Newly added.
  }
}
97
// Adding a function whose name already exists in the base library, but whose
// signature differs (two outputs instead of one), must fail with
// INVALID_ARGUMENT. This holds for both AddToFunctionLibrary overloads.
TEST(DatasetUtilsTest, AddToFunctionLibraryWithConflictingSignatures) {
  // Base: "0" returns its single int64 argument.
  FunctionDefLibrary fdef_base;
  *fdef_base.add_function() = FunctionDefHelper::Create(
      /*function_name=*/"0",
      /*in_def=*/{"arg: int64"},
      /*out_def=*/{"ret: int64"},
      /*attr_def=*/{},
      /*node_def=*/{},
      /*ret_def=*/{{"ret", "arg"}});

  // Incoming: "0" again, now with two outputs, i.e. a different signature.
  FunctionDefLibrary fdef_to_add;
  *fdef_to_add.add_function() = FunctionDefHelper::Create(
      /*function_name=*/"0",
      /*in_def=*/{"arg: int64"},
      /*out_def=*/{"ret: int64", "ret2: int64"},
      /*attr_def=*/{},
      /*node_def=*/{},
      /*ret_def=*/{{"ret", "arg"}, {"ret2", "arg"}});

  const string expected_message =
      "Cannot add function '0' because a different function with the same "
      "signature already exists.";

  // Overload taking a FunctionDefLibrary proto.
  FunctionLibraryDefinition flib_0(OpRegistry::Global(), fdef_base);
  const Status proto_status = AddToFunctionLibrary(&flib_0, fdef_to_add);
  EXPECT_EQ(error::Code::INVALID_ARGUMENT, proto_status.code());
  EXPECT_EQ(expected_message, proto_status.error_message());

  // Overload taking a FunctionLibraryDefinition.
  FunctionLibraryDefinition flib_1(OpRegistry::Global(), fdef_base);
  FunctionLibraryDefinition flib_to_add(OpRegistry::Global(), fdef_to_add);
  const Status library_status = AddToFunctionLibrary(&flib_1, flib_to_add);
  EXPECT_EQ(error::Code::INVALID_ARGUMENT, library_status.code());
  EXPECT_EQ(expected_message, library_status.error_message());
}
134
// StripDevicePlacement clears the device assigned to a node inside a
// function of a FunctionDefLibrary.
TEST(DatasetUtilsTest, StripDevicePlacement) {
  const char kDevice[] = "device:CPU:0";

  FunctionDefLibrary library;
  *library.add_function() = FunctionDefHelper::Create(
      /*function_name=*/"0",
      /*in_def=*/{"arg: int64"},
      /*out_def=*/{"ret: int64"},
      /*attr_def=*/{},
      /*node_def=*/
      {{{"node"},
        "Identity",
        {"arg"},
        {{"T", DT_INT64}},
        /*dep=*/{},
        /*device=*/kDevice}},
      /*ret_def=*/{{"ret", "arg"}});

  // Reads the device of the only node in the only function, re-querying the
  // proto on every call.
  auto first_node_device = [&library]() {
    return library.function(0).node_def(0).device();
  };

  EXPECT_EQ(first_node_device(), kDevice);
  StripDevicePlacement(&library);
  EXPECT_EQ(first_node_device(), "");
}
154
// A runner wrapped by RunnerWithMaxParallelism(runner, 2) must execute
// closures with the per-thread max parallelism set to 2.
TEST(DatasetUtilsTest, RunnerWithMaxParallelism) {
  // The base runner invokes the closure inline, so capturing locals by
  // reference below is safe.
  auto runner =
      RunnerWithMaxParallelism([](const std::function<void()> fn) { fn(); }, 2);
  bool invoked = false;
  auto fn = [&invoked]() {
    invoked = true;
    ASSERT_EQ(GetPerThreadMaxParallelism(), 2);
  };
  runner(fn);
  // Without this check, a runner that silently dropped `fn` would make the
  // test pass without ever checking the parallelism.
  EXPECT_TRUE(invoked);
}
161
// DeterminismPolicy::FromString maps "true", "false" and "default" to the
// deterministic, nondeterministic and default policies respectively. One
// policy object is reused across all parses, so each parse must overwrite the
// previous result.
TEST(DatasetUtilsTest, ParseDeterminismPolicy) {
  DeterminismPolicy policy;
  auto parse = [&policy](const char* text) {
    return DeterminismPolicy::FromString(text, &policy);
  };

  TF_ASSERT_OK(parse("true"));
  EXPECT_TRUE(policy.IsDeterministic());

  TF_ASSERT_OK(parse("false"));
  EXPECT_TRUE(policy.IsNondeterministic());

  TF_ASSERT_OK(parse("default"));
  EXPECT_TRUE(policy.IsDefault());
}
171
// Round trip: parsing each accepted spelling and calling String() returns the
// original spelling.
TEST(DatasetUtilsTest, DeterminismString) {
  for (const char* s : {"true", "false", "default"}) {
    DeterminismPolicy determinism;
    TF_ASSERT_OK(DeterminismPolicy::FromString(s, &determinism));
    // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both values when the
    // round trip fails.
    EXPECT_EQ(s, determinism.String());
  }
}
179
// Constructing from a bool yields exactly one policy state:
// true -> deterministic, false -> nondeterministic; neither is the default.
TEST(DatasetUtilsTest, BoolConstructor) {
  DeterminismPolicy from_true(true);
  EXPECT_TRUE(from_true.IsDeterministic());
  EXPECT_FALSE(from_true.IsNondeterministic());
  EXPECT_FALSE(from_true.IsDefault());

  DeterminismPolicy from_false(false);
  EXPECT_TRUE(from_false.IsNondeterministic());
  EXPECT_FALSE(from_false.IsDeterministic());
  EXPECT_FALSE(from_false.IsDefault());
}
189
// Test-only experiments registered with a spread of integer values, used by the
// GetExperiments* parameterized tests in this file.
// NOTE(review): the expectations of those tests are consistent with the value
// being a rollout percentage (an experiment is selected when hash % 100 <
// value); confirm against the GetExperiments implementation.
REGISTER_DATASET_EXPERIMENT("test_only_experiment_0", 0);
REGISTER_DATASET_EXPERIMENT("test_only_experiment_1", 1);
REGISTER_DATASET_EXPERIMENT("test_only_experiment_5", 5);
REGISTER_DATASET_EXPERIMENT("test_only_experiment_10", 10);
REGISTER_DATASET_EXPERIMENT("test_only_experiment_50", 50);
REGISTER_DATASET_EXPERIMENT("test_only_experiment_99", 99);
REGISTER_DATASET_EXPERIMENT("test_only_experiment_100", 100);
197
// Parameters for GetExperimentsHashTest.
struct GetExperimentsHashTestCase {
  // Value returned by the stubbed hash function passed to GetExperiments.
  uint64 hash;
  // Experiments that must be present in the GetExperiments result.
  std::vector<string> expected_in;
  // Experiments that must be absent from the GetExperiments result.
  std::vector<string> expected_out;
};
203
// Value-parameterized fixture for hash-based experiment selection; each
// parameter is a GetExperimentsHashTestCase.
class GetExperimentsHashTest
    : public ::testing::TestWithParam<GetExperimentsHashTestCase> {};
206
// With the hash function stubbed to return the case's `hash`, GetExperiments
// must select every `expected_in` experiment and none of `expected_out`.
TEST_P(GetExperimentsHashTest, DatasetUtils) {
  const GetExperimentsHashTestCase& test_case = GetParam();
  const uint64 hash_result = test_case.hash;
  const auto stub_hash = [hash_result](const string&) { return hash_result; };
  const auto experiments = GetExperiments("job", stub_hash);

  const absl::flat_hash_set<string> selected(experiments.begin(),
                                             experiments.end());
  for (const auto& name : test_case.expected_in) {
    EXPECT_TRUE(selected.contains(name))
        << "experiment=" << name << " hash=" << hash_result;
  }
  for (const auto& name : test_case.expected_out) {
    EXPECT_FALSE(selected.contains(name))
        << "experiment=" << name << " hash=" << hash_result;
  }
}
225
// Hash-based selection cases. The expectations repeat with period 100
// (0 and 100, 5 and 105, 95 and 195 expect the same sets), consistent with
// selection depending only on hash % 100.
INSTANTIATE_TEST_SUITE_P(
    Test, GetExperimentsHashTest,
    ::testing::Values<GetExperimentsHashTestCase>(
        GetExperimentsHashTestCase{
            /*hash=*/0,
            /*expected_in=*/
            {"test_only_experiment_1", "test_only_experiment_5",
             "test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/{"test_only_experiment_0"},
        },
        GetExperimentsHashTestCase{
            /*hash=*/5,
            /*expected_in=*/
            {"test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/
            {
                "test_only_experiment_0",
                "test_only_experiment_1",
                "test_only_experiment_5",
            },
        },
        GetExperimentsHashTestCase{
            /*hash=*/95,
            /*expected_in=*/
            {"test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50"},
        },
        GetExperimentsHashTestCase{
            /*hash=*/99,
            /*expected_in=*/{"test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_99"},
        },
        GetExperimentsHashTestCase{
            /*hash=*/100,
            /*expected_in=*/
            {"test_only_experiment_1", "test_only_experiment_5",
             "test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/{"test_only_experiment_0"},
        },
        GetExperimentsHashTestCase{
            /*hash=*/105,
            /*expected_in=*/
            {"test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/
            {
                "test_only_experiment_0",
                "test_only_experiment_1",
                "test_only_experiment_5",
            },
        },
        GetExperimentsHashTestCase{
            /*hash=*/195,
            /*expected_in=*/
            {"test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50"},
        }));
295
// Parameters for GetExperimentsOptTest.
struct GetExperimentsOptTestCase {
  // Joined with "," into TF_DATA_EXPERIMENT_OPT_IN; the test does not set the
  // variable when this is empty.
  std::vector<string> opt_ins;
  // Joined with "," into TF_DATA_EXPERIMENT_OPT_OUT; the test does not set the
  // variable when this is empty.
  std::vector<string> opt_outs;
  // Experiments that must be present in the GetExperiments result.
  std::vector<string> expected_in;
  // Experiments that must be absent from the GetExperiments result.
  std::vector<string> expected_out;
};
302
// Value-parameterized fixture for opt-in/opt-out experiment selection; each
// parameter is a GetExperimentsOptTestCase.
class GetExperimentsOptTest
    : public ::testing::TestWithParam<GetExperimentsOptTestCase> {};
305
// Checks how the TF_DATA_EXPERIMENT_OPT_IN / TF_DATA_EXPERIMENT_OPT_OUT
// environment variables (comma-joined lists) change the experiments selected
// by GetExperiments. The hash is stubbed to 0, so the selection without any
// opt-in/opt-out is fixed.
TEST_P(GetExperimentsOptTest, DatasetUtils) {
  const GetExperimentsOptTestCase test_case = GetParam();
  auto opt_ins = test_case.opt_ins;
  auto opt_outs = test_case.opt_outs;
  const string opt_ins_str = str_util::Join(opt_ins, ",");
  const string opt_outs_str = str_util::Join(opt_outs, ",");

  // Start from a clean environment. Otherwise a value inherited from the
  // process (or left by another test) would leak into the cases that leave
  // one of the lists empty.
  unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
  unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");
  if (!opt_ins.empty()) {
    setenv("TF_DATA_EXPERIMENT_OPT_IN", opt_ins_str.c_str(), 1);
  }
  if (!opt_outs.empty()) {
    setenv("TF_DATA_EXPERIMENT_OPT_OUT", opt_outs_str.c_str(), 1);
  }

  auto job_name = "job";
  auto hash_func = [](const string&) { return 0; };
  auto experiments = GetExperiments(job_name, hash_func);

  // Clear both variables right after use, so nothing leaks into later tests
  // however the checks below turn out.
  unsetenv("TF_DATA_EXPERIMENT_OPT_IN");
  unsetenv("TF_DATA_EXPERIMENT_OPT_OUT");

  absl::flat_hash_set<string> experiment_set(experiments.begin(),
                                             experiments.end());
  for (const auto& experiment : test_case.expected_in) {
    EXPECT_TRUE(experiment_set.find(experiment) != experiment_set.end())
        << "experiment=" << experiment << " opt_ins={" << opt_ins_str
        << "} opt_outs={" << opt_outs_str << "}";
  }
  for (const auto& experiment : test_case.expected_out) {
    EXPECT_TRUE(experiment_set.find(experiment) == experiment_set.end())
        << "experiment=" << experiment << " opt_ins={" << opt_ins_str
        << "} opt_outs={" << opt_outs_str << "}";
  }
}
344
// Opt-in/opt-out cases; an empty list means the variable is left unset.
// As encoded by the expectations:
//   - "all" as an opt-out removes every experiment, including opted-in ones.
//   - "all" as an opt-in selects every experiment.
//   - "all_except_opt_in" as an opt-out keeps only the opted-in experiments.
//   - Experiments named in the opt-out list are removed.
INSTANTIATE_TEST_SUITE_P(
    Test, GetExperimentsOptTest,
    ::testing::Values<GetExperimentsOptTestCase>(
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"all"},
            /*opt_outs=*/{"all"},
            /*expected_in=*/{},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_99",
             "test_only_experiment_100"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"all"},
            /*opt_outs=*/{},
            /*expected_in=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_99",
             "test_only_experiment_100"},
            /*expected_out=*/{}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"all"},
            /*opt_outs=*/{"test_only_experiment_1", "test_only_experiment_99"},
            /*expected_in=*/
            {"test_only_experiment_0", "test_only_experiment_5",
             "test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_1", "test_only_experiment_99"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{},
            /*opt_outs=*/{"all"},
            /*expected_in=*/{},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_99",
             "test_only_experiment_100"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{},
            /*opt_outs=*/{},
            /*expected_in=*/
            {"test_only_experiment_1", "test_only_experiment_5",
             "test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_99", "test_only_experiment_100"},
            /*expected_out=*/{"test_only_experiment_0"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{},
            /*opt_outs=*/{"test_only_experiment_1", "test_only_experiment_99"},
            /*expected_in=*/
            {"test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_99"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
            /*opt_outs=*/{"all"},
            /*expected_in=*/{},
            /*expected_out=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_99",
             "test_only_experiment_100"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
            /*opt_outs=*/{"all_except_opt_in"},
            /*expected_in=*/
            {"test_only_experiment_0", "test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_1", "test_only_experiment_5",
             "test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_99"}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
            /*opt_outs=*/{},
            /*expected_in=*/
            {"test_only_experiment_0", "test_only_experiment_1",
             "test_only_experiment_5", "test_only_experiment_10",
             "test_only_experiment_50", "test_only_experiment_99",
             "test_only_experiment_100"},
            /*expected_out=*/{}},
        GetExperimentsOptTestCase{
            /*opt_ins=*/{"test_only_experiment_0", "test_only_experiment_100"},
            /*opt_outs=*/{"test_only_experiment_1", "test_only_experiment_99"},
            /*expected_in=*/
            {"test_only_experiment_0", "test_only_experiment_5",
             "test_only_experiment_10", "test_only_experiment_50",
             "test_only_experiment_100"},
            /*expected_out=*/
            {"test_only_experiment_1", "test_only_experiment_99"}}));
437
// Parameters for GetExperimentsJobNameTest.
struct GetExperimentsJobNameTestCase {
  // Job name passed to GetExperiments.
  string job_name;
  // Experiments that must be present in the GetExperiments result.
  std::vector<string> expected_in;
  // Experiments that must be absent from the GetExperiments result.
  std::vector<string> expected_out;
};
443
// Value-parameterized fixture for job-name-dependent experiment selection;
// each parameter is a GetExperimentsJobNameTestCase.
class GetExperimentsJobNameTest
    : public ::testing::TestWithParam<GetExperimentsJobNameTestCase> {};
446
// With the hash stubbed to 0, GetExperiments called with the case's job name
// must select every `expected_in` experiment and none of `expected_out`.
TEST_P(GetExperimentsJobNameTest, DatasetUtils) {
  const GetExperimentsJobNameTestCase& test_case = GetParam();
  const string job_name = test_case.job_name;
  const auto zero_hash = [](const string&) { return 0; };
  const auto experiments = GetExperiments(job_name, zero_hash);

  const absl::flat_hash_set<string> selected(experiments.begin(),
                                             experiments.end());
  for (const auto& name : test_case.expected_in) {
    EXPECT_TRUE(selected.contains(name))
        << "experiment=" << name << " job_name=" << job_name;
  }
  for (const auto& name : test_case.expected_out) {
    EXPECT_FALSE(selected.contains(name))
        << "experiment=" << name << " job_name=" << job_name;
  }
}
464
// An empty job name is expected to select no experiments at all. A non-empty
// name gets the ordinary hash-0 selection (everything except the 0-valued
// experiment).
INSTANTIATE_TEST_SUITE_P(
    Test, GetExperimentsJobNameTest,
    ::testing::Values(GetExperimentsJobNameTestCase{
                          /*job_name=*/"",
                          /*expected_in=*/{},
                          /*expected_out=*/
                          {"test_only_experiment_0", "test_only_experiment_1",
                           "test_only_experiment_5", "test_only_experiment_10",
                           "test_only_experiment_50", "test_only_experiment_99",
                           "test_only_experiment_100"}},
                      GetExperimentsJobNameTestCase{
                          /*job_name=*/"job_name",
                          /*expected_in=*/
                          {"test_only_experiment_1", "test_only_experiment_5",
                           "test_only_experiment_10", "test_only_experiment_50",
                           "test_only_experiment_99",
                           "test_only_experiment_100"},
                          /*expected_out=*/{"test_only_experiment_0"}}));
483
// Parameters for GetOptimizationsTest: the Options passed to GetOptimizations
// and the exact contents (compared order-insensitively) expected in each of
// its three output sets.
struct GetOptimizationsTestCase {
  Options options;
  // Optimizations expected in the "enabled" output set.
  std::vector<string> expected_enabled;
  // Optimizations expected in the "disabled" output set.
  std::vector<string> expected_disabled;
  // Optimizations expected in the "default" output set.
  std::vector<string> expected_default;
};
490
491 // Tests the default.
GetOptimizationTestCase1()492 GetOptimizationsTestCase GetOptimizationTestCase1() {
493 return {
494 /*options=*/Options(),
495 /*expected_enabled=*/{},
496 /*expected_disabled=*/{},
497 /*expected_default=*/
498 {"noop_elimination", "map_and_batch_fusion", "shuffle_and_repeat_fusion",
499 "map_parallelization", "parallel_batch"}};
500 }
501
502 // Tests disabling application of default optimizations.
GetOptimizationTestCase2()503 GetOptimizationsTestCase GetOptimizationTestCase2() {
504 Options options;
505 options.mutable_optimization_options()->set_apply_default_optimizations(
506 false);
507 return {options, /*expected_enabled=*/{}, /*expected_disabled=*/{},
508 /*expected_default=*/{}};
509 }
510
511 // Tests explicitly enabling / disabling some default and non-default
512 // optimizations.
GetOptimizationTestCase3()513 GetOptimizationsTestCase GetOptimizationTestCase3() {
514 Options options;
515 options.set_deterministic(false);
516 options.mutable_optimization_options()->set_map_and_batch_fusion(true);
517 options.mutable_optimization_options()->set_map_parallelization(false);
518 options.mutable_optimization_options()->set_parallel_batch(false);
519 return {options,
520 /*expected_enabled=*/{"make_sloppy", "map_and_batch_fusion"},
521 /*expected_disabled=*/{"parallel_batch", "map_parallelization"},
522 /*expected_default=*/
523 {"noop_elimination", "shuffle_and_repeat_fusion"}};
524 }
525
526 // Test enabling all / most available optimizations.
GetOptimizationTestCase4()527 GetOptimizationsTestCase GetOptimizationTestCase4() {
528 Options options;
529 options.set_deterministic(false);
530 options.mutable_optimization_options()->set_filter_fusion(true);
531 options.mutable_optimization_options()->set_filter_parallelization(true);
532 options.mutable_optimization_options()->set_map_and_batch_fusion(true);
533 options.mutable_optimization_options()->set_map_and_filter_fusion(true);
534 options.mutable_optimization_options()->set_map_fusion(true);
535 options.mutable_optimization_options()->set_map_parallelization(true);
536 options.mutable_optimization_options()->set_noop_elimination(true);
537 options.mutable_optimization_options()->set_parallel_batch(true);
538 options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
539 options.mutable_optimization_options()->set_inject_prefetch(true);
540 options.set_slack(true);
541 return {options,
542 /*expected_enabled=*/
543 {"filter_fusion", "filter_parallelization", "make_sloppy",
544 "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
545 "map_parallelization", "noop_elimination", "parallel_batch",
546 "shuffle_and_repeat_fusion", "slack", "inject_prefetch"},
547 /*expected_disabled=*/{},
548 /*expected_default=*/{}};
549 }
550
// Value-parameterized fixture for GetOptimizations; each parameter is a
// GetOptimizationsTestCase.
class GetOptimizationsTest
    : public ::testing::TestWithParam<GetOptimizationsTestCase> {};
553
// GetOptimizations must produce exactly the expected enabled, disabled and
// default optimization sets (in any order) for the case's Options.
TEST_P(GetOptimizationsTest, DatasetUtils) {
  const GetOptimizationsTestCase& test_case = GetParam();
  Options options = test_case.options;

  absl::flat_hash_set<tstring> enabled;
  absl::flat_hash_set<tstring> disabled;
  absl::flat_hash_set<tstring> by_default;
  GetOptimizations(options, &enabled, &disabled, &by_default);

  // Convert each tstring set into a std::vector<string> for matching against
  // the expected vectors.
  const auto as_strings = [](const absl::flat_hash_set<tstring>& names) {
    return std::vector<string>(names.begin(), names.end());
  };
  EXPECT_THAT(as_strings(enabled),
              ::testing::UnorderedElementsAreArray(test_case.expected_enabled));
  EXPECT_THAT(
      as_strings(disabled),
      ::testing::UnorderedElementsAreArray(test_case.expected_disabled));
  EXPECT_THAT(as_strings(by_default),
              ::testing::UnorderedElementsAreArray(test_case.expected_default));
}
569
// Runs GetOptimizationsTest over the four hand-built cases defined above.
INSTANTIATE_TEST_SUITE_P(Test, GetOptimizationsTest,
                         ::testing::Values(GetOptimizationTestCase1(),
                                           GetOptimizationTestCase2(),
                                           GetOptimizationTestCase3(),
                                           GetOptimizationTestCase4()));
575
// While test::DeterministicOpsScope is active (presumably enabling op
// determinism for its lifetime), GetOptimizations is expected to enable only
// "make_deterministic" and disable nothing, even though the options request
// nondeterminism.
TEST(DeterministicOpsTest, GetOptimizations) {
  test::DeterministicOpsScope det_scope;
  Options options;
  // options.deterministic should be ignored when deterministic ops are enabled.
  options.set_deterministic(false);
  absl::flat_hash_set<tstring> actual_enabled, actual_disabled, actual_default;
  GetOptimizations(options, &actual_enabled, &actual_disabled, &actual_default);
  EXPECT_THAT(std::vector<string>(actual_enabled.begin(), actual_enabled.end()),
              ::testing::UnorderedElementsAreArray({"make_deterministic"}));
  // A matcher instead of EXPECT_EQ(size(), 0) avoids a size_t/int
  // signed-unsigned comparison, and prints any unexpected entries on failure.
  EXPECT_THAT(
      std::vector<string>(actual_disabled.begin(), actual_disabled.end()),
      ::testing::IsEmpty());
}
587
// Registered so that DatasetUtilsTest.DatasetExperimentRegistry can look it up.
REGISTER_DATASET_EXPERIMENT("test_only_experiment", 42);
589
// An experiment registered with REGISTER_DATASET_EXPERIMENT is visible through
// DatasetExperimentRegistry::Experiments(); an unregistered name is not.
TEST(DatasetUtilsTest, DatasetExperimentRegistry) {
  const auto registered = DatasetExperimentRegistry::Experiments();
  auto is_registered = [&registered](const char* name) {
    return registered.find(name) != registered.end();
  };
  EXPECT_TRUE(is_registered("test_only_experiment"));
  EXPECT_FALSE(is_registered("non_existing_experiment"));
}
595
596 } // namespace
597 } // namespace data
598 } // namespace tensorflow
599