1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for coordinator.py."""
16
17import collections
18import contextlib
19import functools
20import gc
21import os
22import platform
23import sys
24import threading
25import time
26import traceback
27from absl.testing import parameterized
28
29from tensorflow.python.compat import v2_compat
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.distribute import distribute_utils
32from tensorflow.python.distribute import distribution_strategy_context
33from tensorflow.python.distribute import input_lib
34from tensorflow.python.distribute import multi_worker_test_base
35from tensorflow.python.distribute import parameter_server_strategy_v2
36from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
37from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
38from tensorflow.python.distribute.coordinator import values as values_lib
39from tensorflow.python.eager import cancellation
40from tensorflow.python.eager import def_function
41from tensorflow.python.eager import test
42from tensorflow.python.framework import constant_op
43from tensorflow.python.framework import dtypes
44from tensorflow.python.framework import errors
45from tensorflow.python.framework import random_seed
46from tensorflow.python.framework import tensor_spec
47from tensorflow.python.framework import test_util
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import check_ops
50from tensorflow.python.ops import math_ops
51from tensorflow.python.ops import random_ops
52from tensorflow.python.ops import variable_scope
53from tensorflow.python.ops import variables
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.training import coordinator
56from tensorflow.python.training.server_lib import ClusterSpec
57
58
59class ClosureWithOutput(coordinator_lib.Closure):
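  """A `Closure` that eagerly builds its output `RemoteValue` for tests."""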
60
61  def __init__(self, function, cancellation_mgr=None, args=None, kwargs=None):
62    super(ClosureWithOutput, self).__init__(
63        function, cancellation_mgr=cancellation_mgr, args=args, kwargs=kwargs)
64    self.output_remote_value = self.build_output_remote_value()
65
66
67class CoordinatedClosureQueueTest(test.TestCase):
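  """Tests for the internal `_CoordinatedClosureQueue`."""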
68
69  def testBasic(self):
70    queue = coordinator_lib._CoordinatedClosureQueue()
71    closure1 = self._create_closure(queue._cancellation_mgr)
72    queue.put(closure1)
73    self.assertIs(closure1, queue.get())
74    self.assertFalse(queue.done())
75    queue.put_back(closure1)
76    self.assertEqual(closure1, queue.get())
77    queue.mark_finished()
78    self.assertTrue(queue.done())
79    queue.wait()
80
  def testProcessAtLeastOnce(self):
82    closure_queue = coordinator_lib._CoordinatedClosureQueue()
83    labels = ['A', 'B', 'C', 'D', 'E']
84    processed_count = collections.defaultdict(int)
85
86    coord = coordinator.Coordinator(clean_stop_exception_types=[])
87
88    def process_queue():
89      with coord.stop_on_exception():
90        has_been_put_back = False
91        while True:
92          closure = closure_queue.get(timeout=30)
93          if closure is None:
94            break
95          if not has_been_put_back:
96            has_been_put_back = True
97            closure_queue.put_back(closure)
98            continue
99          closure._function()
100          closure_queue.mark_finished()
101
102    def get_func(label):
103
104      def func():
105        time.sleep(3)
106        processed_count[label] += 1
107
108      return func
109
110    cm = cancellation.CancellationManager()
111    for label in labels:
112      closure_queue.put(ClosureWithOutput(get_func(label), cm))
113    t1 = threading.Thread(target=process_queue, daemon=True)
114    t1.start()
115    t2 = threading.Thread(target=process_queue, daemon=True)
116    t2.start()
117
118    # Make sure multiple wait() calls are fine.
119    closure_queue.wait()
120    closure_queue.wait()
121    closure_queue.wait()
122    closure_queue.wait()
123
124    self.assertEqual(processed_count, collections.Counter(labels))
125
126    coord.join([t1, t2])
127
128  def testNotifyBeforeWait(self):
129    closure_queue = coordinator_lib._CoordinatedClosureQueue()
130
131    def func():
132      logging.info('func running')
133
134    coord = coordinator.Coordinator(clean_stop_exception_types=[])
135
136    def process_queue():
137      with coord.stop_on_exception():
138        closure_queue.get()
139        closure_queue.mark_finished()
140
141    closure_queue.put(ClosureWithOutput(func, closure_queue._cancellation_mgr))
142    t = threading.Thread(target=process_queue)
143    t.start()
144    coord.join([t])
145
    # This test asserts that a `wait()` call made after the function has
    # already been processed does not time out.
148    closure_queue.wait()
149
150  def _assert_one_unblock_the_other(self, first_fn, second_fn):
151    """Asserts `second_fn` wouldn't return before `first_fn` is finished."""
152    first_fn_done = threading.Event()
153    second_fn_done = threading.Event()
154    coord = coordinator.Coordinator(clean_stop_exception_types=[])
155
156    def wrapped_first_fn():
157      with coord.stop_on_exception():
158        self.assertFalse(second_fn_done.is_set())
159        first_fn()
160        first_fn_done.set()
161
162    self.assertFalse(first_fn_done.is_set())
163    t = threading.Thread(target=wrapped_first_fn)
164    t.start()
165
166    second_fn()
167    self.assertTrue(first_fn_done.is_set())
168    second_fn_done.set()
169
170    coord.join([t])
171
172  def _run_two_fns_in_parallel(self, first_fn, second_fn):
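    """Runs `first_fn` in a new thread and `second_fn` in the current thread."""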
173    coord = coordinator.Coordinator(clean_stop_exception_types=[])
174
175    def wrapped_first_fn():
176      with coord.stop_on_exception():
177        first_fn()
178
179    t = threading.Thread(target=wrapped_first_fn)
180    t.start()
181
182    second_fn()
183    coord.join([t])
184
185  def testWaitRaiseErrorAfterMarkFailure(self):
186    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
187      # TODO(b/165013260): Fix this
188      self.skipTest('Test is currently broken on Windows with Python 3.8')
189
190    closure_queue = coordinator_lib._CoordinatedClosureQueue()
191    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
192    closure = closure_queue.get()
193
194    wait_finish_event = threading.Event()
195    coord = coordinator.Coordinator(clean_stop_exception_types=[])
196
197    # Using a thread to verify that closure_queue.wait() will not return until
198    # all inflight closures are finished.
199
200    def mark_finished_fn():
201      try:
202        raise ValueError('Some error.')
203      except ValueError as e:
204        closure_queue.mark_failed(e)
205
206    def wait_fn():
207      with self.assertRaises(ValueError):
208        closure_queue.wait()
209
210    self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)
211
212    self.assertTrue(closure_queue.done())
213
214  def _create_closure(self, cancellation_mgr):
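    """Returns a `ClosureWithOutput` wrapping a trivial `tf.function`."""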
215
216    @def_function.function()
217    def some_function():
218      return 1.0
219
220    return ClosureWithOutput(some_function, cancellation_mgr)
221
222  def _put_two_closures_and_get_one(self):
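    """Puts two closures into a fresh queue and gets the first one.

    Returns:
      A tuple of (closure_queue, closure1, closure2).
    """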
223    closure_queue = coordinator_lib._CoordinatedClosureQueue()
224    closure1 = self._create_closure(closure_queue._cancellation_mgr)
225    closure_queue.put(closure1)
226
227    closure2 = self._create_closure(closure_queue._cancellation_mgr)
228    closure_queue.put(closure2)
229
230    closure_got = closure_queue.get()  # returns closure1
231    self.assertIs(closure_got, closure1)
232    self.assertIsNot(closure_got, closure2)
233    return closure_queue, closure1, closure2
234
235  def testPutRaiseError(self):
236    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
237      # TODO(b/165013260): Fix this
238      self.skipTest('Test is currently broken on Windows with Python 3.8')
239
240    closure_queue, _, closure2 = self._put_two_closures_and_get_one()
241
242    closure_queue.mark_failed(ValueError())
243
244    with self.assertRaises(ValueError):
245      closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
246
247    self.assertTrue(closure_queue.done())
248
249    with self.assertRaisesRegex(
250        errors.CancelledError,
251        'The corresponding function is cancelled. Please reschedule the '
252        'function.'):
253      closure2.output_remote_value.fetch()
254
255    # The error is cleared.
256    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
257
258  def testWaitRaiseError(self):
259    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
260      # TODO(b/165013260): Fix this
261      self.skipTest('Test is currently broken on Windows with Python 3.8')
262
263    closure_queue, _, closure2 = self._put_two_closures_and_get_one()
264
265    closure_queue.mark_failed(ValueError())
266
267    with self.assertRaises(ValueError):
268      closure_queue.wait()
269    self.assertTrue(closure_queue.done())
270
271    with self.assertRaisesRegex(
272        errors.CancelledError,
273        'The corresponding function is cancelled. Please reschedule the '
274        'function.'):
275      closure2.output_remote_value.fetch()
276
277    # The error is cleared.
278    closure_queue.wait()
279
280  def testDoneRaiseError(self):
281    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
282      # TODO(b/165013260): Fix this
283      self.skipTest('Test is currently broken on Windows with Python 3.8')
284
285    closure_queue, _, _ = self._put_two_closures_and_get_one()
286
287    self.assertFalse(closure_queue.done())
288    closure_queue.mark_failed(ValueError())
289    with self.assertRaises(ValueError):
290      closure_queue.done()
291
292  def _set_error(self, closure_queue, closure, error):
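    """Sets `error` on the closure's output and marks the queue as failed."""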
293    try:
294      raise error
295    except Exception as e:  # pylint: disable=broad-except
296      closure.output_remote_value._set_error(e)
297      closure_queue.mark_failed(e)
298
299  def _test_cancel_closure_when_error(self, call_wait):
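    """Tests that an error cancels pending closures in the queue."""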
300    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
301      # TODO(b/165013260): Fix this
302      self.skipTest('Test is currently broken on Windows with Python 3.8')
303
304    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
305    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
306    closure_queue.get()
    # At this moment, there are two inflight closures and one in the queue.
308    self.assertEqual(closure_queue._inflight_closure_count, 2)
309
    # Hold a copy of the queue's cancellation manager at this point.
311    initial_cm = closure_queue._cancellation_mgr
312
    # Simulate closure1 failing.
314    self._set_error(closure_queue, closure1, ValueError('Some error.'))
315
    # At this moment, there is one inflight closure and one in the queue.
317    self.assertEqual(closure_queue._queue.qsize(), 1)
318    self.assertEqual(closure_queue._inflight_closure_count, 1)
319
320    closure3 = self._create_closure(closure_queue._cancellation_mgr)
321
322    def fake_cancellation():
323      self._set_error(closure_queue, closure2,
324                      ValueError('Fake cancellation error.'))
325
326    def report_error():
327      # It should not report the fake cancellation error.
328      with self.assertRaisesRegex(ValueError, 'Some error.'):
329        # Verifying `wait()` or `put()` raises even if one closure is in
330        # flight.
331        if call_wait:
332          closure_queue.wait()
333        else:
334          closure_queue.put(closure3)
335
336    self._assert_one_unblock_the_other(fake_cancellation, report_error)
337
338    # The original cancellation manager of the queue has been cancelled.
339    self.assertTrue(initial_cm.is_cancelled)
340
    # At this moment, there are no inflight closures and nothing in the queue.
342    self.assertTrue(closure_queue._queue.empty())
343    self.assertEqual(closure_queue._inflight_closure_count, 0)
344    self.assertIsNone(closure_queue._error)
345
346    # This asserts that closure1 has errored.
347    with self.assertRaisesRegex(ValueError, 'Some error.'):
348      closure1.output_remote_value.fetch()
349
350    # The following asserts that closure3 should have been cancelled.
351    if not call_wait:
352      with self.assertRaisesRegex(
353          errors.CancelledError,
354          'The corresponding function is cancelled. Please reschedule the '
355          'function.'):
356        closure3.output_remote_value.fetch()
357
358    # Closure2 was an inflight closure when it got cancelled.
359    self.assertEqual(closure2.output_remote_value._status,
360                     values_lib.RemoteValueStatus.READY)
361    with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
362      closure2.output_remote_value.fetch()
363
    # This asserts that the queue is back to a clean state.
365    self.testBasic()
366
367  def testWaitRaiseErrorAfterCancelClosure(self):
368    self._test_cancel_closure_when_error(call_wait=True)
369
370  def testPutRaiseErrorAfterCancelClosure(self):
371    self._test_cancel_closure_when_error(call_wait=False)
372
373  def testStateIsRestoredAfterJoinIsCalled(self):
374    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
375      # TODO(b/165013260): Fix this
376      self.skipTest('Test is currently broken on Windows with Python 3.8')
377
378    closure_queue, _, _ = self._put_two_closures_and_get_one()
379    self.assertEqual(closure_queue._inflight_closure_count, 1)
380    closure_queue.mark_failed(ValueError('test error'))
381    with self.assertRaises(ValueError):
382      closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
383
384    # Its error should have been cleared.
385    self.assertIsNone(closure_queue._error)
386    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
387    self.assertIsNone(closure_queue._error)
388
  def testThreadSafety(self):
390    thread_count = 10
391    queue = coordinator_lib._CoordinatedClosureQueue()
392
393    # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
394    # `mark_finished`.
395    action_count = 20
396
397    def func():
398      for i in range(action_count):
399        closure = queue.get()
400        if i % 2 == 0:
401          queue.put_back(closure)
402        else:
403          queue.mark_finished()
404
    threads = [threading.Thread(target=func) for _ in range(thread_count)]
406    for t in threads:
407      t.start()
408
409    for _ in range(thread_count * action_count // 2):
410      queue.put(self._create_closure(queue._cancellation_mgr))
411    queue.wait()
412    self.assertTrue(queue.done())
413
414  def testPutGetWithTag(self):
415    queue = coordinator_lib._CoordinatedClosureQueue()
416
417    closure1 = self._create_closure(queue._cancellation_mgr)
418    closure2 = self._create_closure(queue._cancellation_mgr)
419    closure3 = self._create_closure(queue._cancellation_mgr)
420
421    def put_fn():
422      queue.put(closure3, tag=1)
423      queue.put(closure2, tag=2)
424      queue.put(closure1)
425
426    def get_fn():
427      # The get should only return the closure with tag 2.
428      self.assertIs(closure2, queue.get(tag=2))
429
430    self._run_two_fns_in_parallel(put_fn, get_fn)
431
432    self.assertFalse(queue.done())
433    self.assertEqual(closure1, queue.get())
434    queue.mark_finished()
435
436    # It will not wait for closure3
437    self.assertTrue(queue.done())
438    queue.wait()
439
440
441class ErrorReportingThread(threading.Thread):
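  """A `threading.Thread` that records exceptions raised in its target."""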
442
443  error = None
444
445  def __init__(self, *args, **kwargs):
446    assert 'target' in kwargs
447    target = kwargs['target']
448
449    @functools.wraps(target)
450    def wrapped_target(*args, **kwargs):
451      try:
452        return target(*args, **kwargs)
453      except Exception as e:  # pylint: disable=broad-except
454        traceback.print_exception(*sys.exc_info())
455        ErrorReportingThread.error = e
456
457    kwargs['target'] = wrapped_target
458    super(ErrorReportingThread, self).__init__(*args, **kwargs)
459
460
461class TestCaseWithErrorReportingThread(test.TestCase):
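  """A test case that surfaces errors raised in background threads.

  `threading.Thread` is replaced with `ErrorReportingThread` for the duration
  of the test class, and any error recorded by a thread is re-raised in
  `tearDown`.
  """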
462
463  @classmethod
464  def setUpClass(cls):
465    cls._threading_thread = threading.Thread
466    threading.Thread = ErrorReportingThread
467    super(TestCaseWithErrorReportingThread, cls).setUpClass()
468
469  @classmethod
470  def tearDownClass(cls):
471    super(TestCaseWithErrorReportingThread, cls).tearDownClass()
472    threading.Thread = cls._threading_thread
473
474  def setUp(self):
475    ErrorReportingThread.error = None
476    super(TestCaseWithErrorReportingThread, self).setUp()
477
478  def tearDown(self):
479    super(TestCaseWithErrorReportingThread, self).tearDown()
480    if ErrorReportingThread.error:
481      raise ErrorReportingThread.error  # pylint: disable=raising-bad-type
482
483
484def make_coordinator(num_workers, num_ps):
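  """Creates a `ClusterCoordinator` backed by an in-process test cluster."""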
485  # TODO(rchao): Test the internal rpc_layer version.
486  cluster_def = multi_worker_test_base.create_in_process_cluster(
487      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
488  cluster_def['chief'] = [
489      'localhost:%d' % test_util.pick_unused_port()
490  ]
491  cluster_resolver = SimpleClusterResolver(
492      ClusterSpec(cluster_def), rpc_layer='grpc')
493  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
494      cluster_resolver)
495  return coordinator_lib.ClusterCoordinator(strategy)
496
497
498class ClusterCoordinatorTest(TestCaseWithErrorReportingThread,
499                             parameterized.TestCase):
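  """End-to-end tests for `ClusterCoordinator` against an in-process cluster."""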
500
501  @classmethod
502  def setUpClass(cls):
503    super(ClusterCoordinatorTest, cls).setUpClass()
504    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
505    cls.strategy = cls.coordinator.strategy
506
507  def testClusterCoordinatorOnlyInitOnce(self):
508    cluster = self.coordinator._cluster
509    same_coordinator = coordinator_lib.ClusterCoordinator(self.strategy)
510    self.assertIs(self.coordinator, same_coordinator)
511    self.assertIs(cluster, same_coordinator._cluster)
512
513  def testFnReturnNestedValues(self):
514    x = constant_op.constant(1)
515
516    @def_function.function
517    def f():
518      return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
519
520    got = self.coordinator.schedule(f)
521    want = 2, (3, 4), [5], {'v': 1}
522    self.assertEqual(got.fetch(), want)
523    self.assertEqual(self.coordinator.fetch(got), want)
524
525  def testFetchingRemoteValueStructure(self):
526    self.skipTest('b/171040359: flaky test')
527    x = constant_op.constant(1)
528
529    @def_function.function
530    def f():
531      return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
532
533    want = 2, (3, 4), [5], {'v': 1}
534    remote_value_list = [self.coordinator.schedule(f) for _ in range(5)]
535    self.assertAllEqual(
536        self.coordinator.fetch(remote_value_list), [want for _ in range(5)])
537
538  def testInputFunction(self):
539
540    def input_fn():
541      return dataset_ops.DatasetV2.range(1, 2)
542
543    with self.strategy.scope():
544      v = variables.Variable(initial_value=0, dtype=dtypes.int64)
545
546    @def_function.function
547    def worker_fn(iterator):
548      x = next(iterator)
549      v.assign_add(x)
550      return x
551
552    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
553    result = self.coordinator.schedule(
554        worker_fn, args=(iter(distributed_dataset),))
555    result = self.coordinator.fetch(result)
556    self.assertEqual(result, (1,))
557    result = self.coordinator.schedule(
558        worker_fn, args=(iter(distributed_dataset),))
559    result = self.coordinator.fetch(result)
560
561    self.assertEqual(result, (1,))
562    self.assertAlmostEqual(v.read_value(), 2, delta=1e-6)
563
564  def testAsyncScheduleAndJoin(self):
565    if test_util.is_xla_enabled():
566      self.skipTest('Assign_add is not deterministic across threads in XLA')
567
568    def input_fn():
569      return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)
570
571    with self.strategy.scope():
572      v = variables.Variable(initial_value=0, dtype=dtypes.int32)
573
574    # TODO(yuefengz): the following tf.function has a return value which is None
575    # in its structured_outputs.
576    @def_function.function
577    def worker_fn(iterator):
578      x = next(iterator)
579      v.assign_add(x)
580
581    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
582
583    iterator = iter(distributed_dataset)
584
    # Verify that joining without any scheduling doesn't hang.
586    self.coordinator.join()
587    self.assertEqual(v.read_value().numpy(), 0)
588
589    for _ in range(5):
590      self.coordinator.schedule(worker_fn, args=(iterator,))
591    self.coordinator.join()
592
    # With 5 additions it should be 2*5 = 10.
594    self.assertEqual(v.read_value().numpy(), 10)
595
596    for _ in range(5):
597      self.coordinator.schedule(worker_fn, args=(iterator,))
598
    # Verify that multiple joins are fine.
600    self.coordinator.join()
601    self.coordinator.join()
602    self.coordinator.join()
603
604    self.assertTrue(self.coordinator.done())
605
606    # Likewise, it's now 20.
607    self.assertEqual(v.read_value().numpy(), 20.)
608
609  @parameterized.parameters(True, False)
610  def testInputFunctionWithMap(self, use_input_fn):
611    self._map_fn_tracing_count = 0
612
613    def input_fn():
614
615      def map_fn(x):
616        self._map_fn_tracing_count += 1
617        return x + 10
618
619      return dataset_ops.DatasetV2.range(0, 10).map(map_fn)
620
621    @def_function.function
622    def worker_fn(iterator):
623      return next(iterator)
624
625    if use_input_fn:
626      distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
627    else:
628      distributed_dataset = self.coordinator.create_per_worker_dataset(
629          input_fn())
630
631    result = self.coordinator.schedule(
632        worker_fn, args=(iter(distributed_dataset),))
633    self.assertEqual(result.fetch(), (10,))
634    self.assertEqual(self._map_fn_tracing_count, 1)
635
636  def testInputFunctionCreateVariables(self):
637
638    def input_fn():
639      v = variables.Variable(initial_value=0.0)
640      return v.read_value()
641
642    with self.assertRaises(ValueError):
643      self.coordinator.create_per_worker_dataset(input_fn)
644
645  @parameterized.parameters(True, False)
646  def testDatasetsShuffledDifferently(self, use_input_fn):
647    # This test requires at least two workers in the cluster.
648    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)
649
650    random_seed.set_random_seed(None)
651
652    def input_fn():
653      dataset = dataset_ops.DatasetV2.range(0, 100).shuffle(100).batch(1)
654      return self.strategy.experimental_distribute_dataset(dataset)
655
656    if use_input_fn:
657      distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
658    else:
659      distributed_dataset = self.coordinator.create_per_worker_dataset(
660          input_fn())
661    distributed_iterator = iter(distributed_dataset)
662    # Get elements from the first two iterators.
663    iterator_1 = distributed_iterator._values[0]
664    iterator_1 = iterator_1.fetch()
665    elements_in_iterator_1 = [
666        self.strategy.experimental_local_results(e)
667        for e in iterator_1
668    ]
669    iterator_2 = distributed_iterator._values[1]
670    iterator_2 = iterator_2.fetch()
671    elements_in_iterator_2 = [
672        self.strategy.experimental_local_results(e)
673        for e in iterator_2
674    ]
675
676    self.assertNotAllEqual(elements_in_iterator_1, elements_in_iterator_2)
677
678  def testPerWorkerValue(self):
679    self.skipTest('b/168569314')
680    var_shape = tuple()
681    var_dtype = dtypes.float32
682    var_name = 'var'
683
684    def create_var():
685      var = variables.Variable(
686          initial_value=0.0, dtype=var_dtype, name=var_name)
687      self.assertIn('worker', var.device)
688      return var
689
690    worker_local_var = self.coordinator._create_per_worker_resources(create_var)
691
    # The following is a workaround to allow `worker_local_var` to be passed in
    # as args to the `coordinator.schedule` method, which requires tensor specs
    # to trace the tf.function, but the return values of
    # `_create_per_worker_resources` don't have tensor specs. We can get rid of
    # this workaround once `_create_per_worker_resources` is able to infer the
    # tensor spec of the return value of the function passed in. See
    # b/154675763.
698    for var in worker_local_var._values:
699      var._type_spec = tensor_spec.TensorSpec(var_shape, var_dtype, var_name)
700
701    def worker_fn(var):
702      var.assign_add(1.0)
703
704    for _ in range(10):
      # Which slice of `worker_local_var` is used depends on which worker
      # `worker_fn` gets scheduled on.
707      self.coordinator.schedule(worker_fn, args=(worker_local_var,))
708    self.coordinator.join()
709
710    var_sum = sum(self.coordinator.fetch(worker_local_var._values))
711    self.assertEqual(var_sum, 10.0)
712
713  def testDisallowRemoteValueAsInput(self):
714
715    @def_function.function
716    def func_0():
717      return 1.0
718
719    @def_function.function
720    def func_1(x):
721      return x + 1.0
722
723    remote_v = self.coordinator.schedule(func_0)
724    with self.assertRaises(ValueError):
725      self.coordinator.schedule(func_1, args=(remote_v,))
726
727  def testPythonFunctionNotAllowedToSchedule(self):
728
729    def func(a):
730      return array_ops.identity(a)
731
    with self.assertRaisesRegex(
733        TypeError,
734        '`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` '
735        'only accepts a `tf.function` or a concrete function.'):
736      self.coordinator.schedule(func, args=(1,))
737
738  def testDatasetPartiallyCreatedOnCoordinator(self):
739    dataset = dataset_ops.DatasetV2.range(1, 10)
740
741    @def_function.function
742    def input_fn():
743      return dataset.shuffle(9)
744
745    @def_function.function
746    def worker_fn(iterator):
747      x = next(iterator)
748      return x
749
750    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
751    self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
752
    with self.assertRaisesRegex(
754        coordinator_lib.ClosureInputError,
755        'error message is Failed copying input tensor from'):
756      self.coordinator.join()
757
758  def testPassDatasetToCreatePerWorkerDataset(self):
759    dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
760
761    @def_function.function
762    def worker_fn(iterator):
763      return next(iterator)
764
765    per_worker_dataset = self.coordinator.create_per_worker_dataset(dataset)
766    result = self.coordinator.schedule(
767        worker_fn, args=(iter(per_worker_dataset),))
768    result = result.fetch()
769    expected_result = math_ops.range(1., 5.)
770
    self.assertAllEqual(result, expected_result)
772
773  def testMultipleDatasets(self):
774
775    def input_fn1():
776      return dataset_ops.DatasetV2.range(0, 5)
777
778    def input_fn2():
779      return dataset_ops.DatasetV2.range(5, 10)
780
781    per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn1)
782    per_worker_iterator1 = iter(per_worker_dataset1)
783    per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn2)
784    per_worker_iterator2 = iter(per_worker_dataset2)
785
786    @def_function.function
787    def worker_fn(iterator1, iterator2):
788      return next(iterator1) + next(iterator2)
789
790    result = self.coordinator.schedule(
791        worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
792    self.assertEqual(result.fetch(), 5.0)
793
794    per_worker_dataset3 = self.coordinator.create_per_worker_dataset(input_fn1)
795    per_worker_iterator3 = iter(per_worker_dataset3)
796
797    result = self.coordinator.schedule(
798        worker_fn, args=(per_worker_iterator3, per_worker_iterator2))
799    self.assertGreaterEqual(result.fetch(), 5.0)
800
801  def testRepeatedIteratorCreation(self):
802
803    def input_fn():
804      return dataset_ops.DatasetV2.range(1, 100)
805
806    per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn)
807    per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn)
808
809    @def_function.function
810    def worker_fn(iterator1, iterator2):
811      return next(iterator1) + next(iterator2)
812
813    for _ in range(10):
814      per_worker_iterator1 = iter(per_worker_dataset1)
815      per_worker_iterator2 = iter(per_worker_dataset2)
816      result = self.coordinator.schedule(
817          worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
818      for _ in range(10):
819        self.coordinator.schedule(
820            worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
821      self.coordinator.join()
822      self.assertGreaterEqual(result.fetch(), 2.0)
823    del per_worker_iterator1, per_worker_iterator2
824    gc.collect()
825
826    # There shouldn't be any live iterator objects.
827    for w in self.coordinator._cluster.workers:
828      for r in w._resource_remote_value_refs:
829        self.assertIsNone(r())
830
831
832class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
833  """Test basic functionality works with explicit maximum closure queue size.
834
835  Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
836  explicit size limit for the closure queue. Note that even when the queue size
837  is set to infinite, there is still a maximum practical size (depends on host
838  memory limit) that might cause the queue.put operations to be blocking when
839  scheduling a large number of closures on a big cluster. These tests make sure
840  that the coordinator does not run into deadlocks in such scenario.
841  """
842
843  @classmethod
844  def setUpClass(cls):
845    super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
846    coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
847    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
848    cls.strategy = cls.coordinator.strategy
849
850
851class ScheduleStartDelayTest(ClusterCoordinatorTest):
852  """Test basic functionality works with worker scheduling delay.
853
854  This is basically to make sure that setting environment variables
855  `TF_COORDINATOR_SCHEDULE_START_DELAY` and
856  `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` will cause any failure.
857  """
858
859  @classmethod
860  def setUpClass(cls):
861    super(ScheduleStartDelayTest, cls).setUpClass()
862    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] = '2'
863    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] = '4'
864    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
865    cls.strategy = cls.coordinator.strategy
866
867  @classmethod
868  def tearDownClass(cls):
869    del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY']
870    del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX']
871    super(ScheduleStartDelayTest, cls).tearDownClass()
872
873
874class ErrorReportingTest(TestCaseWithErrorReportingThread):
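  """Tests that errors in scheduled functions are properly surfaced."""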
875
876  @classmethod
877  def setUpClass(cls):
878    super(ErrorReportingTest, cls).setUpClass()
879    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
880    cls.strategy = cls.coordinator.strategy
881
882    with cls.strategy.scope():
883      cls.iteration = variables.Variable(initial_value=0.0)
884
885  @def_function.function
886  def _normal_function(self):
887    x = random_ops.random_uniform((2, 10))
888    y = random_ops.random_uniform((10, 2))
889    self.iteration.assign_add(1.0)
890    return math_ops.reduce_mean(math_ops.matmul(x, y))
891
892  @def_function.function
893  def _error_function(self):
894    x = random_ops.random_uniform((2, 10))
895    y = random_ops.random_uniform((10, 2))
896    check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y)))
897    self.iteration.assign_add(1.0)
898    return self.iteration
899
900  @def_function.function
901  def _long_function(self):
902    x = random_ops.random_uniform((1000, 1000))
903    for _ in math_ops.range(10000):
904      a = random_ops.random_uniform((1000, 1000))
905      b = random_ops.random_uniform((1000, 1000))
906      x += math_ops.matmul(a, b)
907    return x
908
909  def testJoinRaiseError(self):
910    for _ in range(3):
911      self.coordinator.schedule(self._normal_function)
912    self.coordinator.schedule(self._error_function)
913    with self.assertRaises(errors.InvalidArgumentError):
914      self.coordinator.join()
915
916  def testScheduleRaiseError(self):
917    for _ in range(3):
918      self.coordinator.schedule(self._normal_function)
919    self.coordinator.schedule(self._error_function)
920    with self.assertRaises(errors.InvalidArgumentError):
921      while True:
922        self.coordinator.schedule(self._normal_function)
923
924  def testScheduleRaiseErrorWithMultipleFailure(self):
925    for _ in range(3):
926      self.coordinator.schedule(self._normal_function)
927    self.coordinator.schedule(self._error_function)
928    with self.assertRaises(errors.InvalidArgumentError):
929      while True:
930        self.coordinator.schedule(self._error_function)
931    self.coordinator.join()
932
  def testErrorWillBeCleared(self):
934    self.coordinator.schedule(self._error_function)
935    with self.assertRaises(errors.InvalidArgumentError):
936      self.coordinator.join()
937
938    for _ in range(3):
939      self.coordinator.schedule(self._normal_function)
940    self.coordinator.schedule(self._error_function)
941    with self.assertRaises(errors.InvalidArgumentError):
942      self.coordinator.join()
943
944  def testRemoteValueReturnError(self):
945    self.skipTest('TODO(b/211502459): Fix this in OSS test.')
946
947    result = self.coordinator.schedule(self._error_function)
948
949    with self.assertRaises(errors.InvalidArgumentError):
950      result.fetch()
951
952    # Clear the error.
953    with self.assertRaises(errors.InvalidArgumentError):
954      self.coordinator.join()
955
956  def testInputError(self):
957
958    worker_local_val = self.coordinator._create_per_worker_resources(
959        self._error_function)
960
961    @def_function.function
962    def func(x):
963      return x + 1
964
965    result = self.coordinator.schedule(func, args=(worker_local_val,))
966    with self.assertRaises(coordinator_lib.ClosureInputError):
967      self.coordinator.join()
968
969    with self.assertRaises(coordinator_lib.ClosureInputError):
970      result.fetch()
971
972  def testErroredInputNotUsed(self):
973    input_0 = self.coordinator._create_per_worker_resources(
974        self._normal_function)
975
976    self.coordinator._create_per_worker_resources(
977        self._error_function)
978
979    @def_function.function
980    def func(x):
981      return x + 1
982
983    result = self.coordinator.schedule(func, args=(input_0,))
984
985    # It should not raise.
986    self.coordinator.join()
987    result.fetch()
988
989  def testCancellation(self):
990    for _ in range(3):
991      self.coordinator.schedule(self._normal_function)
992    long_function = self.coordinator.schedule(self._long_function)
993    self.coordinator.schedule(self._error_function)
994
995    with self.assertRaises(errors.InvalidArgumentError):
996      self.coordinator.join()
997
998    with self.assertRaises(errors.CancelledError):
999      long_function.fetch()
1000
1001    for _ in range(3):
1002      self.coordinator.schedule(self._normal_function)
1003    self.coordinator.join()
1004
  def testResourceCanStillBeUsedAfterCancellation(self):
1006
1007    def input_fn():
1008      return dataset_ops.DatasetV2.range(0, 5)
1009
1010    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
1011    per_worker_iterator = iter(per_worker_dataset)
1012
1013    @def_function.function
1014    def worker_fn(iterator):
1015      return next(iterator)
1016
1017    self.coordinator.schedule(worker_fn, args=(per_worker_iterator,))
1018    self.coordinator.schedule(self._error_function)
1019
1020    with self.assertRaises(errors.InvalidArgumentError):
1021      self.coordinator.join()
1022
1023    self.coordinator.schedule(worker_fn, args=(per_worker_iterator,))
1024    self.coordinator.join()
1025
1026
1027class LimitedClosureQueueErrorTest(ErrorReportingTest):
1028  """Test error reporting works with explicit maximum closure queue size.
1029
1030  Execute the same set of test cases as in ErrorReportingTest, with an explicit
1031  size limit for the closure queue.
1032  """
1033
1034  @classmethod
1035  def setUpClass(cls):
1036    super(LimitedClosureQueueErrorTest, cls).setUpClass()
1037    coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
1038    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
1039    cls.strategy = cls.coordinator.strategy
1040
1041    with cls.coordinator.strategy.scope():
1042      cls.iteration = variables.Variable(initial_value=0.0)
1043
1044
1045class StrategyIntegrationTest(test.TestCase, parameterized.TestCase):
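  """Tests `ParameterServerStrategyV2` APIs together with the coordinator."""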
1046
1047  @classmethod
1048  def setUpClass(cls):
1049    super(StrategyIntegrationTest, cls).setUpClass()
1050    cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
1051    cls.strategy = cls.coordinator.strategy
1052
1053  def testRunNotUsedWithClusterCoordinatorSchedule(self):
1054
1055    @def_function.function
1056    def input_fn():
1057      return dataset_ops.DatasetV2.range(1, 3)
1058
1059    with self.strategy.scope():
1060      v = variables.Variable(initial_value=1, dtype=dtypes.int64)
1061
1062      def replica_fn(input_tensor):
1063        return input_tensor + v, input_tensor - v
1064
1065      @def_function.function
1066      def worker_fn(iterator):
1067        return self.strategy.run(replica_fn, args=(next(iterator),))
1068
1069    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
1070
1071    @contextlib.contextmanager
1072    def _assert_logs_usage_warning():
1073      with self.assertLogs(level='WARNING') as logs:
1074        yield
1075
1076      self.assertIn(
1077          'A `tf.distribute.experimental.ParameterServerStrategy` method is '
1078          'invoked without using `ClusterCoordinator.schedule`. If you are not '
1079          'tracing a tf.function, this method is possibly executed on the '
1080          'coordinator, which can be slow. To properly dispatch functions to '
1081          'run on workers, methods like `run` or `reduce` should be used '
1082          'within a function passed to `tf.distribute.experimental.coordinator.'
1083          'ClusterCoordinator.schedule`.', logs.output[0])
1084
1085    with _assert_logs_usage_warning():
1086      # Invoking `run` without `coordinator.schedule` should result in a
1087      # warning.
1088      self.strategy.run(
1089          replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),))
1090
1091    # A proper `schedule` should succeed.
1092    rv = self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
1093
1094    with _assert_logs_usage_warning():
1095      # Invoking `run` without `coordinator.schedule` again should result in a
1096      # warning.
1097      self.strategy.run(
1098          replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),))
1099
1100    all_results = [(2, 0)] * self.strategy.num_replicas_in_sync
1101    expected_result = []
1102    for i in range(self.strategy.num_replicas_in_sync):
1103      expected_result.append(all_results[i])
1104
1105    self.assertAllEqual(
1106        tuple(expected_result),
1107        self.strategy.experimental_local_results(rv.fetch()))
1108
1109  def testBasicVariableAssignment(self):
1110    self.strategy.extended._variable_count = 0
1111    with self.strategy.scope():
1112      v1 = variables.Variable(initial_value=0.0)
1113      v2 = variables.Variable(initial_value=1.0)
1114    self.assertEqual(self.strategy.extended._variable_count, 2)
1115
1116    @def_function.function
1117    def worker_fn():
1118      v1.assign_add(0.1)
1119      v2.assign_sub(0.2)
1120      return v1.read_value() / v2.read_value()
1121
1122    results = self.coordinator.schedule(worker_fn)
    logging.info('Result of scheduled worker_fn: %f',
                 self.coordinator.fetch(results))
1125
1126    self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
1127    self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
1128
1129  def testRunAndReduce(self):
1130    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1131    with self.strategy.scope():
1132      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1133      v = variables.Variable(initial_value=1.)
1134
1135      expected_result = (4. * self.strategy.num_replicas_in_sync,
1136                         2. * self.strategy.num_replicas_in_sync)
1137
1138      @def_function.function
1139      def worker_fn(input_tensor):
1140
1141        def replica_fn(input_tensor):
1142          # Within `replica_fn`, it has to be in a replica context.
1143          self.assertFalse(
1144              distribution_strategy_context.in_cross_replica_context())
1145          return input_tensor + v, input_tensor - v
1146
1147        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
1148        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
1149        check_ops.assert_equal_v2(reduced_result, expected_result)
1150        return reduced_result
1151
      # Assert that scheduling inside the strategy scope behaves as expected.
1153      result = self.coordinator.schedule(
1154          worker_fn, args=(constant_op.constant(3.),))
1155      self.assertIsInstance(result, coordinator_lib.RemoteValue)
1156      self.assertEqual(result.fetch(), expected_result)
1157
    # Assert that scheduling outside the strategy scope behaves as expected.
1159    result = self.coordinator.schedule(
1160        worker_fn, args=(constant_op.constant(3.),))
1161    self.assertEqual(result.fetch(), expected_result)
1162
1163  def testRunAndReduceWithAssignAdd(self):
1164    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1165    with self.strategy.scope():
1166      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1167      v = variables.Variable(initial_value=1.)
1168      v1 = variables.Variable(
1169          initial_value=0.,
1170          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)
1171
1172      expected_result = (4. * self.strategy.num_replicas_in_sync,
1173                         2. * self.strategy.num_replicas_in_sync)
1174
1175      @def_function.function
1176      def worker_fn(input_tensor):
1177
1178        def replica_fn(input_tensor):
1179          # Within `replica_fn`, it has to be in a replica context.
1180          self.assertFalse(
1181              distribution_strategy_context.in_cross_replica_context())
1182
1183          v1.assign_add(input_tensor)
1184          return input_tensor + v, input_tensor - v
1185
1186        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
1187        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
1188        check_ops.assert_equal_v2(reduced_result, expected_result)
1189        return reduced_result
1190
      # Assert that scheduling inside the strategy scope behaves as expected.
1192      result = self.coordinator.schedule(
1193          worker_fn, args=(constant_op.constant(3.),))
1194      self.assertIsInstance(result, coordinator_lib.RemoteValue)
1195      self.assertEqual(result.fetch(), expected_result)
1196
    # Assert that scheduling outside the strategy scope behaves as expected.
1198    result = self.coordinator.schedule(
1199        worker_fn, args=(constant_op.constant(3.),))
1200    self.assertEqual(result.fetch(), expected_result)
1201    self.assertEqual(v1, 6.)
1202
1203  def testVariableAggregation(self):
1204    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1205    with self.strategy.scope():
1206      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1207      v = variables.Variable(
1208          initial_value=1.,
1209          aggregation=variable_scope.VariableAggregation.SUM)
1210
1211      @def_function.function
1212      def worker_fn():
1213
1214        def replica_fn():
1215          value = math_ops.cast(
1216              distribution_strategy_context.get_replica_context()
1217              .replica_id_in_sync_group + 1, v.dtype)
1218          v.assign(value)
1219
1220        self.strategy.run(replica_fn)
1221
1222      self.coordinator.schedule(worker_fn)
1223      self.coordinator.join()
1224      expected_result = 0.
1225      for i in range(self.strategy.num_replicas_in_sync):
1226        expected_result = expected_result + i + 1
1227      self.assertEqual(v, expected_result)
1228
1229  def testVariableCaching(self):
1230    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1231    with self.strategy.scope():
1232      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1233      v = variables.Variable(
1234          initial_value=1.,
1235          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)
1236
      # Test reading the variable value inside the caching scope.
1238      with distribute_utils.cache_variable_reads():
1239        v.read_value()  # Reads value 1.0
1240        v.assign(constant_op.constant(5.0))  # v changes to 5.0
1241        self.assertEqual(v.read_value(), 1.0)  # should be cached 1.0 value.
1242
1243      # Reset v to 2.0
1244      v.assign(2.0)
1245
      # Test converting to a tensor value inside the caching scope.
1247      with distribute_utils.cache_variable_reads():
1248        t = v * 3.0
1249        self.assertEqual(t, 6.0)
1250        v.assign(3.0)
1251        t1 = v * 3.0
1252        self.assertEqual(t1, 6.0)  # should be cached 2.0 * 3.0 value.
1253
1254      # Reset v to 1.0
1255      v.assign(1.0)
1256
1257      # Verify caching scope inside tf.function
1258      @def_function.function
1259      def worker_fn():
1260        with distribute_utils.cache_variable_reads():
1261          def replica_fn():
1262            t = v.read_value()  # Reads value 1.0
1263            v.assign(constant_op.constant(5.0))  # v changes to 5.0
1264            t = v.read_value()  # should return 1.0
1265            return t  # Should be 1.0 instead of 5.0
1266
1267          return self.strategy.run(replica_fn)
1268
1269      result = self.coordinator.schedule(worker_fn)
1270      result = result.fetch()
1271      expected_result = 1.
1272      self.assertEqual(result, expected_result)
1273
1274      # Verify that v.read_value works as expected outside of scope.
1275      v.assign(4.0)
1276      self.assertEqual(v.read_value(), 4.0)
1277
1278      v.assign(constant_op.constant(2.0))  # v changes to 2.0
      # Check with the caching scope outside of the tf.function, and check
      # that the cache is reset.
1280      @def_function.function
1281      def worker_fn1():
1282        def replica_fn():
1283          t = v.read_value()  # Reads value 2.0 ==> Should be cached
1284          v.assign(constant_op.constant(5.0))  # v changes to 5.0
1285          t = v.read_value()  # should return cached value 2.0
1286          return t  # Should be 2.0 instead of 5.0
1287
1288        return self.strategy.run(replica_fn)
1289
1290      with distribute_utils.cache_variable_reads():
1291        result = self.coordinator.schedule(worker_fn1)
1292      result = result.fetch()
1293      expected_result = 2.
1294      self.assertEqual(result, expected_result)
1295
1296    # Verify scope nesting is not permitted.
1297    with self.assertRaises(ValueError):
1298      with distribute_utils.cache_variable_reads():
1299        with distribute_utils.cache_variable_reads():
1300          v.read_value()
1301
1302  @parameterized.parameters(True, False)
1303  def testDistributedDatasetInsidePerWorkerDatasetFn(self, from_function):
1304    if from_function:
1305
1306      def per_worker_dataset_fn():
1307        dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4)
1308        return self.strategy.distribute_datasets_from_function(dataset_fn)
1309    else:
1310
1311      def per_worker_dataset_fn():
1312        dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
1313        return self.strategy.experimental_distribute_dataset(dataset)
1314
1315    @def_function.function
1316    def worker_fn(iterator):
1317      return self.strategy.experimental_local_results(next(iterator))
1318
1319    per_worker_dataset = self.coordinator.create_per_worker_dataset(
1320        per_worker_dataset_fn)
1321    result = self.coordinator.schedule(
1322        worker_fn, args=(iter(per_worker_dataset),))
1323    result = result.fetch()
1324    expected_result = array_ops.split(
1325        math_ops.range(1., 5.),
1326        num_or_size_splits=self.strategy.num_replicas_in_sync,
1327        axis=0)
1328
    self.assertAllEqual(result, expected_result)
1330
1331  @parameterized.parameters(True, False)
1332  def testPassDistributedDatasetToCreatePerWorkerDataset(self, from_function):
1333    if from_function:
1334      dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4)
1335      distributed_dataset = self.strategy.distribute_datasets_from_function(
1336          dataset_fn)
1337    else:
1338      dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
1339      distributed_dataset = self.strategy.experimental_distribute_dataset(
1340          dataset)
1341
1342    @def_function.function
1343    def worker_fn(iterator):
1344      return self.strategy.experimental_local_results(next(iterator))
1345
1346    per_worker_dataset = self.coordinator.create_per_worker_dataset(
1347        distributed_dataset)
1348    result = self.coordinator.schedule(
1349        worker_fn, args=(iter(per_worker_dataset),))
1350    result = result.fetch()
1351    expected_result = array_ops.split(
1352        math_ops.range(1., 5.),
1353        num_or_size_splits=self.strategy.num_replicas_in_sync,
1354        axis=0)
1355
    self.assertAllEqual(result, expected_result)
1357
1358  def testDistributeDatasetsFromFunction(self):
1359
1360    def per_worker_dataset_fn():
1361
1362      def input_worker_device_fn(input_context):
1363        self.assertIsNotNone(input_context)
1364        return dataset_ops.DatasetV2.range(1, 11).batch(1)
1365
1366      return self.strategy.distribute_datasets_from_function(
1367          input_worker_device_fn)
1368
1369    @def_function.function
1370    def worker_fn(iterator):
1371      result = self.strategy.experimental_local_results(next(iterator))
1372      return result
1373
1374    distributed_dataset = self.coordinator.create_per_worker_dataset(
1375        per_worker_dataset_fn)
1376    result = self.coordinator.schedule(
1377        worker_fn, args=(iter(distributed_dataset),))
1378    result = result.fetch()
1379    expected_result = []
1380    for i in range(self.strategy.num_replicas_in_sync):
1381      expected_result.append([1 + i])
1382    self.assertAllEqual(result, expected_result)
1383
1384  def testAsyncScheduleWithDistributedDataset(self):
1385
1386    def input_fn():
1387      dataset = dataset_ops.DatasetV2.from_tensor_slices([2.]).repeat().batch(
1388          self.strategy.num_replicas_in_sync)
1389      return self.strategy.experimental_distribute_dataset(dataset)
1390
1391    with self.strategy.scope():
1392      v = variables.Variable(initial_value=[0], dtype=dtypes.float32)
1393
1394    # TODO(yuefengz): the following tf.function has a return value which is None
1395    # in its structured_outputs.
1396    @def_function.function
1397    def worker_fn(iterator):
1398      x = next(iterator)
      # Reduce to convert PerReplica values to a single value.
1400      reduced_value = self.strategy.reduce('MEAN', x, axis=None)
1401      v.assign_add(reduced_value)
1402
1403    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
1404
1405    iterator = iter(distributed_dataset)
1406
    # Verify that joining without any scheduling doesn't hang.
1408    self.coordinator.join()
1409    self.assertAllEqual(v.read_value(), (0,))
1410
1411    for _ in range(5):
1412      self.coordinator.schedule(worker_fn, args=(iterator,))
1413    self.coordinator.join()
1414
    # With 5 additions it should be 2*5 = 10.
1416    self.assertAllEqual(
1417        self.strategy.experimental_local_results(v.read_value()), ([[10]]))
1418
1419    for _ in range(5):
1420      self.coordinator.schedule(worker_fn, args=(iterator,))
1421
    # Verify that multiple joins are fine.
1423    self.coordinator.join()
1424    self.coordinator.join()
1425    self.coordinator.join()
1426
1427    self.assertTrue(self.coordinator.done())
1428
1429    # Likewise, it's now 20.
1430    self.assertAllEqual(
1431        self.strategy.experimental_local_results(v.read_value()), ([[20]]))
1432
1433  def testInputFunctionWithMapWithDistributedDataset(self):
1434    self._map_fn_tracing_count = 0
1435
1436    def input_fn():
1437
1438      def map_fn(x):
1439        self._map_fn_tracing_count += 1
1440        return x + 10
1441
1442      dataset = dataset_ops.DatasetV2.range(0, 10).batch(
1443          self.strategy.num_replicas_in_sync).map(map_fn)
1444      return self.strategy.experimental_distribute_dataset(dataset)
1445
1446    @def_function.function
1447    def worker_fn(iterator):
1448      return next(iterator)
1449
1450    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
1451    result = self.coordinator.schedule(
1452        worker_fn, args=(iter(distributed_dataset),))
1453
1454    expected_result = array_ops.split(
1455        math_ops.range(10., 10. + self.strategy.num_replicas_in_sync),
1456        num_or_size_splits=self.strategy.num_replicas_in_sync,
1457        axis=0)
1458
1459    self.assertAllEqual(
1460        self.strategy.experimental_local_results(result.fetch()),
1461        tuple(expected_result))
1462    self.assertEqual(self._map_fn_tracing_count, 1)
1463
1464  def testPerWorkerDistributeDatasetsElementSpec(self):
1465
1466    def per_worker_dataset_fn():
1467      return self.strategy.distribute_datasets_from_function(
1468          lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 2]))
1469
1470    dataset = dataset_ops.DatasetV2.from_tensor_slices([1, 2])
1471    per_worker_distribute_dataset = self.coordinator.create_per_worker_dataset(
1472        per_worker_dataset_fn)
1473
1474    self.assertAllEqual(
        # Converts to PerReplicaSpec when num_replicas_in_sync is > 1.
1476        input_lib._create_distributed_tensor_spec(self.strategy,
1477                                                  dataset.element_spec),
1478        per_worker_distribute_dataset.element_spec)
1479
1480  def testPerWorkerDistributedIteratorTypeSpec(self):
1481    self._tracing_count = 0
1482
1483    def per_worker_dataset_fn():
1484      self._tracing_count += 1
1485      return self.strategy.distribute_datasets_from_function(
1486          lambda _: dataset_ops.DatasetV2.range(1, 2))
1487
1488    @def_function.function
1489    def worker_fn(iterator):
1490      return next(iterator)
1491
1492    distributed_iterator = iter(
1493        self.coordinator.create_per_worker_dataset(per_worker_dataset_fn))
1494    worker_fn.get_concrete_function(distributed_iterator)
1495
1496    self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
1497    self.assertEqual(self._tracing_count, 1)
1498
1499
1500if __name__ == '__main__':
1501  v2_compat.enable_v2_behavior()
1502  test.main()
1503