1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for custom training loops.""" 16 17from absl.testing import parameterized 18 19from tensorflow.python import tf2 20from tensorflow.python.data.ops import dataset_ops 21from tensorflow.python.distribute import combinations 22from tensorflow.python.distribute import device_util 23from tensorflow.python.distribute import distribute_lib 24from tensorflow.python.distribute import reduce_util 25from tensorflow.python.distribute import strategy_combinations 26from tensorflow.python.distribute import test_util 27from tensorflow.python.eager import def_function 28from tensorflow.python.eager import test 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import map_fn 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import variables 38from tensorflow.python.ops.losses import losses 39from tensorflow.python.tpu import tpu 40from tensorflow.python.util import nest 41 42 43def get_dataset_from_tensor_slices(inp_array): 44 dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array) 45 # TODO(b/138326910): Remove Dataset V1 version once bug resolved. 46 if not tf2.enabled(): 47 dataset = dataset_ops.Dataset.from_tensor_slices(inp_array) 48 return dataset 49 50 51class AssertFlattenedMixin(object): 52 """Mixin for specialized asserts.""" 53 54 def assert_equal_flattened(self, expected_results, actual_results): 55 """Asserts that flattened results are equal. 56 57 Due to the number of replicas in the strategy, the output may have a 58 different structure and needs to be flattened for comparison. 59 60 Args: 61 expected_results: The results expected as a result of a computation. 62 actual_results: The actual results of a computation. 63 """ 64 self.assertEqual(len(expected_results), len(actual_results)) 65 66 for i, expected_result in enumerate(expected_results): 67 final_result = [] 68 actual_result = actual_results[i] 69 for val in actual_result: 70 final_result.extend(val.numpy()) 71 self.assertAllEqual(expected_result, final_result) 72 73 74class InputIterationTest(test.TestCase, parameterized.TestCase, 75 AssertFlattenedMixin): 76 77 @combinations.generate( 78 combinations.combine( 79 distribution=strategy_combinations.all_strategies, 80 mode=["eager"] 81 )) 82 def testConstantNumpyInput(self, distribution): 83 84 @def_function.function 85 def run(x): 86 87 def computation(x): 88 return math_ops.square(x) 89 90 outputs = distribution.experimental_local_results( 91 distribution.run(computation, args=(x,))) 92 return outputs 93 94 self.assertAllEqual( 95 constant_op.constant(4., shape=(distribution.num_replicas_in_sync)), 96 run(2.)) 97 98 @combinations.generate( 99 combinations.combine( 100 distribution=strategy_combinations.all_strategies, 101 mode=["eager"] 102 )) 103 def testStatefulExperimentalRunAlwaysExecute(self, distribution): 104 with distribution.scope(): 105 v = variables.Variable( 106 0.0, aggregation=variables.VariableAggregation.MEAN) 107 108 @def_function.function 109 def train_step(): 110 111 def assign_add(): 112 v.assign_add(1.0) 113 114 distribution.run(assign_add) 115 return array_ops.zeros([]) 116 117 train_step() 118 self.assertAllEqual(1.0, v.numpy()) 119 120 @combinations.generate( 121 combinations.combine( 122 distribution=strategy_combinations.strategies_minus_tpu, 123 mode=["eager"])) 124 def testFullEager(self, distribution): 125 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 126 127 def train_step(data): 128 return math_ops.square(data) 129 130 dist_dataset = distribution.experimental_distribute_dataset(dataset) 131 results = [] 132 for x in dist_dataset: 133 output = distribution.experimental_local_results( 134 distribution.run(train_step, args=(x,))) 135 results.append(output) 136 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 137 138 @combinations.generate( 139 combinations.combine( 140 distribution=strategy_combinations.all_strategies, mode=["eager"])) 141 def testGetNextAsOptional(self, distribution): 142 data = [5., 6., 7., 8.] 143 dataset = get_dataset_from_tensor_slices(data).batch(2) 144 dist_dataset = distribution.experimental_distribute_dataset(dataset) 145 iterator = iter(dist_dataset) 146 147 def train_step(data): 148 return math_ops.square(data) 149 150 @def_function.function 151 def run(iterator): 152 return distribution.experimental_local_results( 153 distribution.run( 154 train_step, args=(iterator.get_next_as_optional().get_value(),))) 155 156 self.assert_equal_flattened([[25., 36.]], [run(iterator)]) 157 158 @combinations.generate( 159 combinations.combine( 160 distribution=strategy_combinations.all_strategies, mode=["eager"])) 161 def testGetNextAsOptionalExampleUsage(self, distribution): 162 global_batch_size = 2 163 steps_per_loop = 6 164 dataset = dataset_ops.Dataset.range( 165 8, output_type=dtypes.int32).batch(global_batch_size) 166 distributed_iterator = iter( 167 distribution.experimental_distribute_dataset(dataset)) 168 169 @def_function.function 170 def train_fn(distributed_iterator): 171 172 def step_fn(x): 173 return x 174 175 for _ in math_ops.range(steps_per_loop): 176 optional_data = distributed_iterator.get_next_as_optional() 177 if not optional_data.has_value(): 178 break 179 distribution.run(step_fn, args=(optional_data.get_value(),)) 180 181 train_fn(distributed_iterator) 182 183 @combinations.generate( 184 combinations.combine( 185 distribution=strategy_combinations.tpu_strategies, mode=["eager"])) 186 def testFullEagerTPU(self, distribution): 187 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 188 189 def train_step(data): 190 return math_ops.square(data) 191 192 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 193 194 with self.assertRaisesRegex(NotImplementedError, 195 "does not support pure eager execution"): 196 distribution.run(train_step, args=(next(input_iterator),)) 197 198 @combinations.generate( 199 combinations.combine( 200 distribution=strategy_combinations.all_strategies, 201 mode=["eager"] 202 )) 203 def testStepInFunction(self, distribution): 204 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 205 206 @def_function.function 207 def train_step(data): 208 return math_ops.square(data) 209 210 dist_dataset = distribution.experimental_distribute_dataset(dataset) 211 results = [] 212 for x in dist_dataset: 213 output = distribution.experimental_local_results( 214 distribution.run(train_step, args=(x,))) 215 results.append(output) 216 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 217 218 @combinations.generate( 219 combinations.combine( 220 distribution=strategy_combinations.all_strategies, 221 mode=["eager"] 222 )) 223 def testRunInFunction(self, distribution): 224 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 225 226 def train_step(data): 227 return math_ops.square(data) 228 229 @def_function.function 230 def f_train_step(input_data): 231 return distribution.experimental_local_results( 232 distribution.run(train_step, args=(input_data,))) 233 234 dist_dataset = distribution.experimental_distribute_dataset(dataset) 235 results = [] 236 for x in dist_dataset: 237 output = f_train_step(x) 238 results.append(output) 239 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 240 241 @combinations.generate( 242 combinations.combine( 243 distribution=[ 244 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 245 strategy_combinations.tpu_strategy, 246 strategy_combinations.tpu_strategy_packed_var, 247 ], 248 mode=["eager"])) 249 def testNestedOutput(self, distribution): 250 dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2) 251 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 252 253 @def_function.function 254 def run(iterator): 255 256 def computation(x): 257 return [{ 258 "a": x - 1, 259 "b": x + 1 260 }] 261 262 inputs = next(iterator) 263 outputs = distribution.run(computation, args=(inputs,)) 264 return nest.map_structure(distribution.experimental_local_results, 265 outputs) 266 267 results = run(input_iterator) 268 for replica in range(distribution.num_replicas_in_sync): 269 # The input dataset is range(4), so the replica id is same as input. 270 self.assertAllEqual(results[0]["a"][replica], [replica - 1]) 271 self.assertAllEqual(results[0]["b"][replica], [replica + 1]) 272 273 @combinations.generate( 274 combinations.combine( 275 distribution=strategy_combinations.all_strategies, 276 mode=["eager"] 277 )) 278 def testRunInFunctionAutoGraphApplication(self, distribution): 279 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 280 281 def train_step(data): 282 return math_ops.square(data) 283 284 @def_function.function 285 def f_train_step(input_data): 286 return distribution.experimental_local_results( 287 distribution.run(train_step, args=(input_data,))) 288 289 dist_dataset = distribution.experimental_distribute_dataset(dataset) 290 results = [] 291 for x in dist_dataset: 292 output = f_train_step(x) 293 results.append(output) 294 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 295 296 @combinations.generate( 297 combinations.combine( 298 distribution=strategy_combinations.all_strategies, 299 mode=["eager"] 300 )) 301 def testDatasetIterationInFunction(self, distribution): 302 with distribution.scope(): 303 a = variables.Variable( 304 1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) 305 306 def train_step(_): 307 a.assign_add(1.0) 308 309 @def_function.function 310 def f_train_step(dist_dataset): 311 number_of_steps = constant_op.constant(0.0) 312 product_of_means = constant_op.constant(2.0) 313 for x in dist_dataset: # loop with values modified each iteration 314 number_of_steps += 1 315 product_of_means *= math_ops.cast( 316 distribution.reduce("MEAN", x, axis=0), product_of_means.dtype) 317 318 for y in dist_dataset: # loop with no intermediate state 319 distribution.run(train_step, args=(y,)) 320 321 return number_of_steps, product_of_means 322 323 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 324 dist_dataset = distribution.experimental_distribute_dataset(dataset) 325 326 number_of_steps, product_of_means = f_train_step(dist_dataset) 327 self.assertEqual(2, number_of_steps.numpy()) 328 self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3) 329 330 # We set the initial value of `a` to 1 and iterate through the dataset 2 331 # times(4/2 where 4 is the number of dataset elements and 2 is the batch 332 # size). Hence the final result is 3. 333 self.assertEqual(3.0, (a.numpy())) 334 335 @combinations.generate( 336 combinations.combine( 337 distribution=strategy_combinations.all_strategies, 338 mode=["eager"] 339 )) 340 def testDatasetAssertWithDynamicBatch(self, distribution): 341 # Regression test for github issue 33517. 342 def step_fn(data): 343 assert_op = control_flow_ops.Assert(math_ops.less_equal( 344 math_ops.reduce_max(data), 100.), [data]) 345 with ops.control_dependencies([assert_op]): 346 return math_ops.square(data) 347 348 @def_function.function 349 def train(dataset): 350 results = [] 351 iterator = iter(dataset) 352 # we iterate through the loop 5 times since we have 3 elements and a 353 # global batch of 2. 354 for _ in range(2): 355 elem = next(iterator) 356 output = distribution.experimental_local_results( 357 distribution.run(step_fn, args=(elem,))) 358 results.append(output) 359 return results 360 361 dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.,]).batch(2) 362 # TODO(b/138326910): Remove Dataset V1 version once bug resolved. 363 if not tf2.enabled(): 364 dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.,]).batch(2) 365 dist_dataset = distribution.experimental_distribute_dataset(dataset) 366 results = train(dist_dataset) 367 368 expected_results = [[25., 36.], [49.]] 369 self.assertEqual(len(expected_results), len(results)) 370 371 # Need to expand results since output will be grouped differently depending 372 # on the number of replicas. 373 for i, expected_result in enumerate(expected_results): 374 final_result = [] 375 actual_result = results[i] 376 for val in actual_result: 377 final_result.extend(val.numpy()) 378 self.assertAllEqual(expected_result, final_result) 379 380 @combinations.generate( 381 combinations.combine( 382 distribution=strategy_combinations.all_strategies, 383 mode=["eager"] 384 )) 385 def testDistributeDatasetIteratorWithoutFunction(self, distribution): 386 data = [5., 6., 7., 8.] 387 input_iterator = iter( 388 distribution.distribute_datasets_from_function( 389 lambda _: get_dataset_from_tensor_slices(data))) 390 391 self.assertAllEqual( 392 distribution.experimental_local_results(input_iterator.get_next()), 393 data[0:distribution.num_replicas_in_sync]) 394 395 @combinations.generate( 396 combinations.combine( 397 distribution=strategy_combinations.multidevice_strategies, 398 mode=["eager"] 399 )) 400 def testDistributeDatasetIteratorWithFunction(self, distribution): 401 data = [5., 6., 7., 8.] 402 input_iterator = iter( 403 distribution.distribute_datasets_from_function( 404 lambda _: get_dataset_from_tensor_slices(data))) 405 406 @def_function.function 407 def run(iterator): 408 return distribution.experimental_local_results(iterator.get_next()) 409 410 local_results = run(input_iterator) 411 self.assertAllEqual(local_results, 412 data[0:distribution.num_replicas_in_sync]) 413 backing_devices = [result.backing_device for result in local_results] 414 self.assertAllEqual(backing_devices, distribution.extended.worker_devices) 415 416 @combinations.generate( 417 combinations.combine( 418 distribution=strategy_combinations.multidevice_strategies, 419 mode=["eager"] 420 )) 421 def testDistributeDatasetPrefetch(self, distribution): 422 data = [5., 6., 7., 8.] 423 input_iterator = iter( 424 distribution.experimental_distribute_dataset( 425 get_dataset_from_tensor_slices(data).batch(2))) 426 427 local_results = distribution.experimental_local_results( 428 input_iterator.get_next()) 429 430 backing_devices = [result.backing_device for result in local_results] 431 self.assertAllEqual(backing_devices, distribution.extended.worker_devices) 432 433 @combinations.generate( 434 combinations.combine( 435 distribution=strategy_combinations.multidevice_strategies, 436 mode=["eager"] 437 )) 438 def testDistributeDatasetFunctionPrefetch(self, distribution): 439 data = [5., 6., 7., 8.] 440 input_iterator = iter( 441 distribution.distribute_datasets_from_function( 442 lambda _: get_dataset_from_tensor_slices(data))) 443 444 local_results = distribution.experimental_local_results( 445 input_iterator.get_next()) 446 447 backing_devices = [result.backing_device for result in local_results] 448 self.assertAllEqual(backing_devices, distribution.extended.worker_devices) 449 450 @combinations.generate( 451 combinations.combine( 452 distribution=strategy_combinations.tpu_strategies, 453 mode=["eager"] 454 )) 455 def testDistributeDatasetHostPrefetch(self, distribution): 456 data = [5., 6., 7., 8.] 457 input_iterator = iter( 458 distribution.experimental_distribute_dataset( 459 get_dataset_from_tensor_slices(data).batch(2), 460 distribute_lib.InputOptions(experimental_fetch_to_device=False))) 461 462 local_results = distribution.experimental_local_results( 463 input_iterator.get_next()) 464 465 for result in local_results: 466 self.assertEqual(result.backing_device, 467 device_util.resolve("/device:CPU:0")) 468 469 @combinations.generate( 470 combinations.combine( 471 distribution=strategy_combinations.tpu_strategies, 472 mode=["eager"] 473 )) 474 def testDistributeDatasetFunctionHostPrefetch(self, distribution): 475 data = [5., 6., 7., 8.] 476 input_iterator = iter( 477 distribution.distribute_datasets_from_function( 478 lambda _: get_dataset_from_tensor_slices(data), 479 distribute_lib.InputOptions(experimental_fetch_to_device=False))) 480 481 local_results = distribution.experimental_local_results( 482 input_iterator.get_next()) 483 484 for result in local_results: 485 self.assertEqual(result.backing_device, 486 device_util.resolve("/device:CPU:0")) 487 488 @combinations.generate( 489 combinations.combine( 490 distribution=strategy_combinations.multidevice_strategies, 491 mode=["eager"] 492 )) 493 def testDynamicShapes(self, distribution): 494 dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4) 495 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 496 497 @def_function.function 498 def run(iterator): 499 def computation(x): 500 return math_ops.reduce_mean(x) 501 inputs = next(iterator) 502 outputs = distribution.experimental_local_results( 503 distribution.run(computation, args=(inputs,))) 504 return outputs 505 506 # This assumes that there are exactly 2 replicas 507 self.assertAllEqual([5.5, 7.], run(input_iterator)) 508 509 @combinations.generate( 510 combinations.combine( 511 distribution=strategy_combinations.tpu_strategy, mode=["eager"])) 512 def testDynamicShapesWithRunOptionsBucketizing(self, distribution): 513 dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4) 514 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 515 options = distribute_lib.RunOptions( 516 experimental_bucketizing_dynamic_shape=True) 517 518 @def_function.function 519 def run(iterator): 520 521 def computation(x): 522 return math_ops.reduce_mean(x) 523 524 inputs = next(iterator) 525 outputs = distribution.experimental_local_results( 526 distribution.run( 527 computation, args=(inputs,), options=options)) 528 return outputs 529 530 # This assumes that there are exactly 2 replicas 531 self.assertAllEqual([5.5, 7.], run(input_iterator)) 532 533 @combinations.generate( 534 combinations.combine( 535 distribution=strategy_combinations.tpu_strategy, mode=["eager"])) 536 def testDynamicShapesWithRunOptionsDisableDynamicPadder(self, distribution): 537 dataset = get_dataset_from_tensor_slices([5, 6, 7]).batch(4) 538 mask_dataset = get_dataset_from_tensor_slices([1, 0, 1]).batch(4) 539 dataset = dataset_ops.DatasetV2.zip((dataset, mask_dataset)) 540 541 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 542 options = distribute_lib.RunOptions( 543 experimental_xla_options=tpu.XLAOptions( 544 enable_xla_dynamic_padder=False)) 545 546 @def_function.function 547 def run(iterator): 548 549 def computation(inputs): 550 x, mask = inputs 551 y = x * mask 552 return math_ops.reduce_sum(y) 553 554 inputs = next(iterator) 555 outputs = distribution.experimental_local_results( 556 distribution.run(computation, args=(inputs,), options=options)) 557 return outputs 558 559 # This assumes that there are exactly 2 replicas 560 self.assertAllEqual([5, 7], run(input_iterator)) 561 562 @combinations.generate( 563 combinations.combine( 564 distribution=strategy_combinations.multidevice_strategies, 565 mode=["eager"])) 566 def testDynamicOutputsWithX64(self, distribution): 567 dataset = get_dataset_from_tensor_slices( 568 [5]).map(lambda x: math_ops.cast(x, dtypes.int64)).batch(2) 569 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 570 571 @def_function.function 572 def run(iterator): 573 574 def computation(x): 575 return math_ops.add(x, x) 576 577 inputs = next(iterator) 578 outputs = distribution.experimental_local_results( 579 distribution.run(computation, args=(inputs,))) 580 return outputs 581 582 # This assumes that there are exactly 2 replicas 583 result = run(input_iterator) 584 self.assertAllEqual([10], result[0]) 585 self.assertAllEqual([], result[1]) 586 587 @combinations.generate( 588 combinations.combine( 589 distribution=strategy_combinations.multidevice_strategies, 590 mode=["eager"] 591 )) 592 def testDynamicShapesWithGetNextOutsideFunction(self, distribution): 593 dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4) 594 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 595 596 @def_function.function 597 def run(inputs): 598 def computation(x): 599 return math_ops.reduce_mean(x) 600 outputs = distribution.experimental_local_results( 601 distribution.run(computation, args=(inputs,))) 602 return outputs 603 604 # This assumes that there are exactly 2 replicas 605 self.assertAllEqual([5.5, 7.], run(next(input_iterator))) 606 607 @combinations.generate( 608 combinations.combine( 609 distribution=strategy_combinations.multidevice_strategies, 610 mode=["eager"] 611 )) 612 def testStrategyReduceWithDynamicShapes(self, distribution): 613 dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4) 614 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 615 616 @def_function.function 617 def run(iterator): 618 inputs = next(iterator) 619 return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0) 620 621 self.assertAllEqual(6., run(input_iterator)) 622 623 @combinations.generate( 624 combinations.combine( 625 distribution=strategy_combinations.multidevice_strategies, 626 mode=["eager"] 627 )) 628 def testStrategyReduceWithDynamicShapesRank2(self, distribution): 629 dataset = get_dataset_from_tensor_slices( 630 [[1., 1.], [1., 1.], [1., 1.]]).batch(4) 631 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 632 633 @def_function.function 634 def run(iterator): 635 inputs = next(iterator) 636 return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0) 637 638 self.assertAllEqual([1., 1.], run(input_iterator)) 639 640 @combinations.generate( 641 combinations.combine( 642 distribution=strategy_combinations.multidevice_strategies, 643 mode=["eager"] 644 )) 645 def testDynamicShapesWithSizeOp(self, distribution): 646 dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4) 647 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 648 649 @def_function.function 650 def run(inputs): 651 def computation(x): 652 return array_ops.size_v2(x) 653 outputs = distribution.experimental_local_results( 654 distribution.run(computation, args=(inputs,))) 655 return outputs 656 657 # This assumes that there are exactly 2 replicas 658 self.assertAllEqual([2, 1], run(next(input_iterator))) 659 660 @combinations.generate( 661 combinations.combine( 662 distribution=strategy_combinations.multidevice_strategies, 663 mode=["eager"])) 664 def testSegmentSumWithDynamicNumberOfSegments(self, distribution): 665 666 def dataset_fn(_): 667 data = array_ops.zeros(5, dtype=dtypes.int32) 668 dataset = get_dataset_from_tensor_slices(data) 669 dataset = dataset.batch(3) 670 return dataset 671 672 input_iterator = iter( 673 distribution.distribute_datasets_from_function(dataset_fn)) 674 675 @def_function.function 676 def step_fn(example): 677 segment_ids = array_ops.zeros_like_v2(example) 678 num_segment = array_ops.shape(example)[0] 679 # If number of segments is dynamic, output should be a dynamic shape. 680 return math_ops.unsorted_segment_sum(example, segment_ids, num_segment) 681 682 # This assumes that there are exactly 2 replicas 683 outputs = distribution.experimental_local_results( 684 distribution.run(step_fn, args=(next(input_iterator),))) 685 self.assertAllEqual((3,), outputs[0].shape) 686 self.assertAllEqual((2,), outputs[1].shape) 687 688 @combinations.generate( 689 combinations.combine( 690 distribution=strategy_combinations.multidevice_strategies, 691 mode=["eager"])) 692 def testReshapeWithDynamicInputs(self, distribution): 693 694 def dataset_fn(_): 695 data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32) 696 dataset = get_dataset_from_tensor_slices(data) 697 dataset = dataset.batch(3) 698 return dataset 699 700 input_iterator = iter( 701 distribution.distribute_datasets_from_function(dataset_fn)) 702 703 @def_function.function 704 def step_fn(example): 705 # example: [<=3, 1, 2] 706 # tile: [<=3, <=3, 2] 707 tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1]) 708 # reshape1: [<=(3*3 = 9), 2] 709 reshape1 = array_ops.reshape(tile, [-1, 2]) 710 711 # reshape2: [<=3, <=3, 2] 712 reshape2 = array_ops.reshape( 713 reshape1, 714 [array_ops.shape(example)[0], 715 array_ops.shape(example)[0], 2]) 716 717 # reshape3: [<=3, -1, 2] 718 reshape3 = array_ops.reshape(reshape1, 719 [array_ops.shape(example)[0], -1, 2]) 720 # reshape4: [-1, <=3, 2] 721 reshape4 = array_ops.reshape(reshape1, 722 [-1, array_ops.shape(example)[0], 2]) 723 # Reshape1 is duplicated in order to test dynamic dimension on copies. 724 return [reshape1, reshape2, reshape3, reshape4, reshape1] 725 726 # This assumes that there are exactly 2 replicas 727 outputs = distribution.experimental_local_results( 728 distribution.run(step_fn, args=(next(input_iterator),))) 729 self.assertAllEqual((9, 2), outputs[0][0].shape) 730 self.assertAllEqual((3, 3, 2), outputs[0][1].shape) 731 self.assertAllEqual((3, 3, 2), outputs[0][2].shape) 732 self.assertAllEqual((3, 3, 2), outputs[0][3].shape) 733 self.assertAllEqual((9, 2), outputs[0][4].shape) 734 735 self.assertAllEqual((4, 2), outputs[1][0].shape) 736 self.assertAllEqual((2, 2, 2), outputs[1][1].shape) 737 self.assertAllEqual((2, 2, 2), outputs[1][2].shape) 738 self.assertAllEqual((2, 2, 2), outputs[1][3].shape) 739 self.assertAllEqual((4, 2), outputs[1][4].shape) 740 741 @combinations.generate( 742 combinations.combine( 743 distribution=strategy_combinations.multidevice_strategies, 744 mode=["eager"])) 745 def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution): 746 def dataset_fn(_): 747 dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]]) 748 dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.], 749 [1., 2., 3.]]) 750 dataset = dataset1.concatenate(dataset2) 751 dataset = dataset.batch(2, drop_remainder=True) 752 return dataset 753 754 input_iterator = iter( 755 distribution.distribute_datasets_from_function(dataset_fn)) 756 757 @def_function.function 758 def run(inputs): 759 def computation(x): 760 return math_ops.reduce_mean(x) 761 outputs = distribution.experimental_local_results( 762 distribution.run(computation, args=(inputs,))) 763 return outputs 764 765 # This assumes that there are exactly 2 replicas 766 self.assertAllEqual([1.5, 2.], run(next(input_iterator))) 767 768 @combinations.generate( 769 combinations.combine( 770 distribution=strategy_combinations.multidevice_strategies, 771 mode=["eager"])) 772 def testMapFnWithDynamicInputs(self, distribution): 773 774 def dataset_fn(_): 775 data = array_ops.zeros((20, 300, 32), dtype=dtypes.int32) 776 dataset = get_dataset_from_tensor_slices(data) 777 dataset = dataset.batch(16) 778 return dataset 779 780 input_iterator = iter( 781 distribution.distribute_datasets_from_function(dataset_fn)) 782 783 def embedding_lookup(inputs): 784 embedding_weights = array_ops.zeros((1, 128)) 785 flat_inputs = array_ops.reshape(inputs, [-1]) 786 embeddings = array_ops.gather(embedding_weights, flat_inputs) 787 embeddings = array_ops.reshape(embeddings, inputs.shape.as_list() + [128]) 788 return embeddings 789 790 @def_function.function 791 def step_fn(example): 792 return map_fn.map_fn( 793 embedding_lookup, example, fn_output_signature=dtypes.float32) 794 795 # This assumes that there are exactly 2 replicas 796 outputs = distribution.experimental_local_results( 797 distribution.run(step_fn, args=(next(input_iterator),))) 798 self.assertAllEqual((16, 300, 32, 128), outputs[0].shape) 799 self.assertAllEqual((4, 300, 32, 128), outputs[1].shape) 800 801 @combinations.generate( 802 combinations.combine( 803 distribution=strategy_combinations.all_strategies, 804 mode=["eager"] 805 )) 806 def testDatasetDistributeEvenlyDivisibleDrop(self, distribution): 807 # If the batch size is evenly divisible by the number of workers and we set 808 # drop_remainder=True on the dataset, then DistributedIterator will use a 809 # different (and more efficient) code path which avoids some control flow 810 # ops. 811 dataset = get_dataset_from_tensor_slices([5., 6.]).batch( 812 2, drop_remainder=True) 813 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 814 815 data = next(input_iterator) 816 817 expected_result = [5., 6.] 818 final_result = [] 819 actual_result = distribution.experimental_local_results(data) 820 for val in actual_result: 821 final_result.extend(val) 822 self.assertAllEqual(expected_result, final_result) 823 824 @combinations.generate( 825 combinations.combine( 826 distribution=strategy_combinations.all_strategies, 827 mode=["eager"] 828 )) 829 def testDatasetDistributeNotDivisibleDrop(self, distribution): 830 # If each batch is not evenly divisible by the number of workers, 831 # the remainder will be dropped. 832 dataset = get_dataset_from_tensor_slices([5., 6.]).batch( 833 1, drop_remainder=True) 834 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 835 836 data = next(input_iterator) 837 838 expected_result = [5.] 839 final_result = [] 840 actual_result = distribution.experimental_local_results(data) 841 for val in actual_result: 842 final_result.extend(val) 843 self.assertAllEqual(expected_result, final_result) 844 845 @combinations.generate( 846 combinations.combine( 847 distribution=strategy_combinations.all_strategies, 848 mode=["eager"] 849 )) 850 def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution): 851 # Setting drop_remainder=False on the dataset causes DistributedIterator 852 # to use get_next_as_optional(), even if the batched dataset is evenly 853 # divisible by the number of workers. 854 dataset = get_dataset_from_tensor_slices([5., 6.]).batch( 855 2, drop_remainder=False) 856 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 857 858 data = next(input_iterator) 859 860 expected_result = [5., 6.] 861 final_result = [] 862 actual_result = distribution.experimental_local_results(data) 863 for val in actual_result: 864 final_result.extend(val) 865 self.assertAllEqual(expected_result, final_result) 866 867 @combinations.generate( 868 combinations.combine( 869 distribution=strategy_combinations.all_strategies, 870 mode=["eager"] 871 )) 872 def testDatasetPartialBatchWithMixedOutputs(self, distribution): 873 # Dynamic output size with a mix of static and dynamic outputs 874 dataset = get_dataset_from_tensor_slices([5.]).batch(2) 875 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 876 877 @def_function.function 878 def run(iterator): 879 880 def computation(x): 881 # Fixed size output with a dynamic sized output. 882 return array_ops.zeros([3]), math_ops.square(x) 883 884 return distribution.run( 885 computation, args=(next(iterator),)) 886 887 results = run(input_iterator) 888 889 # First result is fixed for all replicas. 890 for replica_id in range(distribution.num_replicas_in_sync): 891 self.assertAllEqual([0., 0., 0.], 892 distribution.experimental_local_results( 893 results[0])[replica_id]) 894 # Only first replica has distributed dataset computation. 895 self.assertAllEqual([25.], 896 distribution.experimental_local_results(results[1])[0]) 897 # Other replicas have no distributed dataset computation. 898 for replica_id in range(1, distribution.num_replicas_in_sync): 899 self.assertAllEqual([], 900 distribution.experimental_local_results( 901 results[1])[replica_id]) 902 903 @combinations.generate( 904 combinations.combine( 905 distribution=strategy_combinations.all_strategies, 906 mode=["eager"] 907 )) 908 def testIterationInsideFunction(self, distribution): 909 910 def step_fn(data): 911 return math_ops.square(data) 912 913 @def_function.function 914 def train(dataset): 915 results = [] 916 iterator = iter(dataset) 917 # we iterate through the loop 2 times since we have 4 elements and a 918 # global batch of 2. 919 for _ in range(2): 920 elem = next(iterator) 921 output = distribution.experimental_local_results( 922 distribution.run(step_fn, args=(elem,))) 923 results.append(output) 924 return results 925 926 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 927 dist_dataset = distribution.experimental_distribute_dataset(dataset) 928 results = train(dist_dataset) 929 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 930 931 @combinations.generate( 932 combinations.combine( 933 distribution=strategy_combinations.all_strategies, 934 mode=["eager"] 935 )) 936 def testIterationOutsideFunction(self, distribution): 937 938 def train_step(data): 939 return math_ops.square(data) 940 941 @def_function.function 942 def f_train_step(input_data): 943 return distribution.experimental_local_results( 944 distribution.run(train_step, args=(input_data,))) 945 946 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 947 dist_dataset = distribution.experimental_distribute_dataset(dataset) 948 iterator = iter(dist_dataset) 949 results = [] 950 # we iterate through the loop 2 times since we have 4 elements and a 951 # global batch of 2. 952 for _ in range(2): 953 output = f_train_step(next(iterator)) 954 results.append(output) 955 self.assert_equal_flattened([[25., 36.], [49., 64.]], results) 956 957 @combinations.generate( 958 combinations.combine( 959 distribution=strategy_combinations.all_strategies, 960 mode=["eager"] 961 )) 962 def testMultiDeviceDataCapturedFunction(self, distribution): 963 inputs = constant_op.constant([2., 3.]) 964 dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5) 965 input_iterator = iter( 966 distribution.distribute_datasets_from_function(dataset)) 967 with distribution.scope(): 968 var = variables.Variable(1.0) 969 970 @def_function.function 971 def train_step(input_iterator): 972 973 def func(inputs): 974 return math_ops.square(inputs) + var 975 976 per_replica_outputs = distribution.run( 977 func, (next(input_iterator),)) 978 mean = distribution.reduce( 979 reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None) 980 for _ in dataset_ops.Dataset.range(1): 981 per_replica_outputs = distribution.run( 982 func, (next(input_iterator),)) 983 mean = distribution.reduce( 984 reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None) 985 return mean 986 987 with distribution.scope(): 988 if distribution.num_replicas_in_sync == 1: 989 self.assertAlmostEqual(10.0, self.evaluate(train_step(input_iterator))) 990 else: 991 self.assertAlmostEqual(7.5, self.evaluate(train_step(input_iterator))) 992 993 @combinations.generate( 994 combinations.combine( 995 distribution=strategy_combinations.all_strategies, 996 mode=["eager"] 997 )) 998 def testDatasetOutOfRange(self, distribution): 999 with distribution.scope(): 1000 a = variables.Variable( 1001 0.0, aggregation=variables.VariableAggregation.SUM) 1002 1003 def train_step(val): 1004 a.assign_add(math_ops.reduce_sum(val)) 1005 1006 @def_function.function 1007 def f_train_step(iterator): 1008 distribution.run(train_step, args=(next(iterator),)) 1009 return a 1010 1011 dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2) 1012 dist_dataset = distribution.experimental_distribute_dataset(dataset) 1013 1014 iterator = iter(dist_dataset) 1015 with self.assertRaises(errors.OutOfRangeError): 1016 for _ in range(100): 1017 f_train_step(iterator) 1018 1019 self.assertAlmostEqual(26.0, a.numpy()) 1020 1021 @combinations.generate( 1022 combinations.combine( 1023 distribution=strategy_combinations.multidevice_strategies, 1024 mode=["eager"])) 1025 def testComputeLossWithDynamicShapes(self, distribution): 1026 dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4) 1027 input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) 1028 1029 @def_function.function 1030 def run(iterator): 1031 1032 def computation(x): 1033 return losses.compute_weighted_loss(x, weights=array_ops.ones_like(x)) 1034 1035 inputs = next(iterator) 1036 outputs = distribution.experimental_local_results( 1037 distribution.run(computation, args=(inputs,))) 1038 return outputs 1039 1040 # This assumes that there are exactly 2 replicas 1041 self.assertAllEqual([5.5, 7.], run(input_iterator)) 1042 1043 1044if __name__ == "__main__": 1045 test_util.main() 1046