# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.tpu import tpu
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
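  """Returns a dataset of the slices of `inp_array` (V1 when tf2 is off)."""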
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset
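
# A minimal sketch (not itself a test) of the custom-training-loop pattern
# exercised throughout this file, assuming a `strategy`, a `dataset`, and a
# per-replica `step_fn` are already defined:
#
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#
#   @def_function.function
#   def train_step(inputs):
#     return strategy.run(step_fn, args=(inputs,))
#
#   for x in dist_dataset:
#     train_step(x)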


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Depending on the number of replicas in the strategy, the output may have
    a different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
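
    Example:
      With 2 replicas, a per-step result such as
      (array([25.]), array([36.])) is flattened to [25., 36.] before it is
      compared against the corresponding entry of `expected_results`.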
63    """
64    self.assertEqual(len(expected_results), len(actual_results))
65
66    for i, expected_result in enumerate(expected_results):
67      final_result = []
68      actual_result = actual_results[i]
69      for val in actual_result:
70        final_result.extend(val.numpy())
71      self.assertAllEqual(expected_result, final_result)
72
73
74class InputIterationTest(test.TestCase, parameterized.TestCase,
75                         AssertFlattenedMixin):
76
77  @combinations.generate(
78      combinations.combine(
79          distribution=strategy_combinations.all_strategies,
80          mode=["eager"]
81      ))
82  def testConstantNumpyInput(self, distribution):
83
84    @def_function.function
85    def run(x):
86
87      def computation(x):
88        return math_ops.square(x)
89
90      outputs = distribution.experimental_local_results(
91          distribution.run(computation, args=(x,)))
92      return outputs
93
94    self.assertAllEqual(
95        constant_op.constant(4., shape=(distribution.num_replicas_in_sync)),
96        run(2.))
97
98  @combinations.generate(
99      combinations.combine(
100          distribution=strategy_combinations.all_strategies,
101          mode=["eager"]
102      ))
103  def testStatefulExperimentalRunAlwaysExecute(self, distribution):
104    with distribution.scope():
105      v = variables.Variable(
106          0.0, aggregation=variables.VariableAggregation.MEAN)
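    # Every replica applies assign_add(1.0); with MEAN aggregation the
    # synced value is 1.0 regardless of the number of replicas.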

    @def_function.function
    def train_step():

      def assign_add():
        v.assign_add(1.0)

      distribution.run(assign_add)
      return array_ops.zeros([])

    train_step()
    self.assertAllEqual(1.0, v.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]))
  def testFullEager(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptional(self, distribution):
    data = [5., 6., 7., 8.]
    dataset = get_dataset_from_tensor_slices(data).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(
          distribution.run(
              train_step, args=(iterator.get_next_as_optional().get_value(),)))

    self.assert_equal_flattened([[25., 36.]], [run(iterator)])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptionalExampleUsage(self, distribution):
    global_batch_size = 2
    steps_per_loop = 6
    dataset = dataset_ops.Dataset.range(
        8, output_type=dtypes.int32).batch(global_batch_size)
    distributed_iterator = iter(
        distribution.experimental_distribute_dataset(dataset))
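    # get_next_as_optional() returns a tf.experimental.Optional, letting the
    # traced loop below test has_value() and break at end of input instead of
    # relying on an OutOfRangeError.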

    @def_function.function
    def train_fn(distributed_iterator):

      def step_fn(x):
        return x

      for _ in math_ops.range(steps_per_loop):
        optional_data = distributed_iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        distribution.run(step_fn, args=(optional_data.get_value(),))

    train_fn(distributed_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
  def testFullEagerTPU(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
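    # TPU strategies require the step function to run inside a tf.function;
    # calling strategy.run in pure eager mode is expected to fail.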

    with self.assertRaisesRegex(NotImplementedError,
                                "does not support pure eager execution"):
      distribution.run(train_step, args=(next(input_iterator),))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.run(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so the replica id is the same as the
      # input value.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.run(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
    self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3)

    # We set the initial value of `a` to 1 and iterate through the dataset
    # 2 times (4/2, where 4 is the number of dataset elements and 2 is the
    # batch size). Hence the final result is 3.
    self.assertEqual(3.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # We iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since output will be grouped differently
    # depending on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithoutFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    self.assertAllEqual(
        distribution.experimental_local_results(input_iterator.get_next()),
        data[0:distribution.num_replicas_in_sync])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(iterator.get_next())

    local_results = run(input_iterator)
    self.assertAllEqual(local_results,
                        data[0:distribution.num_replicas_in_sync])
    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2),
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))
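    # experimental_fetch_to_device=False keeps the input on the host CPU
    # instead of prefetching it onto the TPU devices, which the
    # backing_device assertions below verify.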

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data),
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
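    # With 2 replicas, the partial batch of 3 elements is split into [5., 6.]
    # and [7.], giving per-replica means of 5.5 and 7.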

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategy, mode=["eager"]))
  def testDynamicShapesWithRunOptionsBucketizing(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_bucketizing_dynamic_shape=True)

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.reduce_mean(x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(
              computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategy, mode=["eager"]))
  def testDynamicShapesWithRunOptionsDisableDynamicPadder(self, distribution):
    dataset = get_dataset_from_tensor_slices([5, 6, 7]).batch(4)
    mask_dataset = get_dataset_from_tensor_slices([1, 0, 1]).batch(4)
    dataset = dataset_ops.DatasetV2.zip((dataset, mask_dataset))

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_xla_options=tpu.XLAOptions(
            enable_xla_dynamic_padder=False))
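    # With the XLA dynamic padder disabled, padded positions may contain
    # arbitrary values; multiplying by the mask keeps the sums correct.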

    @def_function.function
    def run(iterator):

      def computation(inputs):
        x, mask = inputs
        y = x * mask
        return math_ops.reduce_sum(y)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5, 7], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicOutputsWithX64(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [5]).map(lambda x: math_ops.cast(x, dtypes.int64)).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
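    # Only a single int64 element is available, so with 2 replicas the second
    # replica receives an empty batch and produces an empty result.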

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.add(x, x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    result = run(input_iterator)
    self.assertAllEqual([10], result[0])
    self.assertAllEqual([], result[1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros(5, dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))
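    # The 5 zeros batched by 3 yield per-replica batches of sizes 3 and 2, so
    # the segment sums below have dynamic shapes (3,) and (2,).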

    @def_function.function
    def step_fn(example):
      segment_ids = array_ops.zeros_like_v2(example)
      num_segment = array_ops.shape(example)[0]
      # If the number of segments is dynamic, the output should have a
      # dynamic shape as well.
      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((3,), outputs[0].shape)
    self.assertAllEqual((2,), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testReshapeWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      # example: [<=3, 1, 2]
      # tile: [<=3, <=3, 2]
      tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
      # reshape1: [<=(3*3 = 9), 2]
      reshape1 = array_ops.reshape(tile, [-1, 2])

      # reshape2: [<=3, <=3, 2]
      reshape2 = array_ops.reshape(
          reshape1,
          [array_ops.shape(example)[0],
           array_ops.shape(example)[0], 2])

      # reshape3: [<=3, -1, 2]
      reshape3 = array_ops.reshape(reshape1,
                                   [array_ops.shape(example)[0], -1, 2])
      # reshape4: [-1, <=3, 2]
      reshape4 = array_ops.reshape(reshape1,
                                   [-1, array_ops.shape(example)[0], 2])
      # reshape1 is returned twice in order to test dynamic dimensions on
      # copies.
      return [reshape1, reshape2, reshape3, reshape4, reshape1]

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((9, 2), outputs[0][0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][1].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][2].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][3].shape)
    self.assertAllEqual((9, 2), outputs[0][4].shape)

    self.assertAllEqual((4, 2), outputs[1][0].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][1].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][2].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][3].shape)
    self.assertAllEqual((4, 2), outputs[1][4].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testMapFnWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((20, 300, 32), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(16)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))
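    # The 20 examples batched by 16 yield per-replica batches of 16 and 4, so
    # the map_fn outputs below have a dynamic leading dimension.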

    def embedding_lookup(inputs):
      embedding_weights = array_ops.zeros((1, 128))
      flat_inputs = array_ops.reshape(inputs, [-1])
      embeddings = array_ops.gather(embedding_weights, flat_inputs)
      embeddings = array_ops.reshape(embeddings, inputs.shape.as_list() + [128])
      return embeddings

    @def_function.function
    def step_fn(example):
      return map_fn.map_fn(
          embedding_lookup, example, fn_output_signature=dtypes.float32)

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((16, 300, 32, 128), outputs[0].shape)
    self.assertAllEqual((4, 300, 32, 128), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we
    # set drop_remainder=True on the dataset, then DistributedIterator will
    # use a different (and more efficient) code path which avoids some
    # control flow ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # Tests a dynamic output size with a mix of static and dynamic outputs.
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # A fixed-size output together with a dynamically sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.run(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # The first result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only the first replica receives an element of the partial batch.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # The remaining replicas receive empty batches.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # We iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # We iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset))
    with distribution.scope():
      var = variables.Variable(1.0)
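    # With one replica the second step computes 3.**2 + var = 10.; with two
    # replicas the replicas see 2. and 3., so the mean is (5. + 10.) / 2 = 7.5.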

    @def_function.function
    def train_step(input_iterator):

      def func(inputs):
        return math_ops.square(inputs) + var

      per_replica_outputs = distribution.run(
          func, (next(input_iterator),))
      mean = distribution.reduce(
          reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      for _ in dataset_ops.Dataset.range(1):
        per_replica_outputs = distribution.run(
            func, (next(input_iterator),))
        mean = distribution.reduce(
            reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      return mean

    with distribution.scope():
      if distribution.num_replicas_in_sync == 1:
        self.assertAlmostEqual(10.0, self.evaluate(train_step(input_iterator)))
      else:
        self.assertAlmostEqual(7.5, self.evaluate(train_step(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetOutOfRange(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.SUM)

    def train_step(val):
      a.assign_add(math_ops.reduce_sum(val))

    @def_function.function
    def f_train_step(iterator):
      distribution.run(train_step, args=(next(iterator),))
      return a

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
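    # The elements sum to 5 + 6 + 7 + 8 = 26; after the dataset is exhausted,
    # the next f_train_step call raises OutOfRangeError.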

    iterator = iter(dist_dataset)
    with self.assertRaises(errors.OutOfRangeError):
      for _ in range(100):
        f_train_step(iterator)

    self.assertAlmostEqual(26.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testComputeLossWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return losses.compute_weighted_loss(x, weights=array_ops.ones_like(x))

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))


if __name__ == "__main__":
  test_util.main()