# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for coordinator.py."""

import collections
import contextlib
import functools
import gc
import os
import platform
import sys
import threading
import time
import traceback
from absl.testing import parameterized

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.distribute.coordinator import values as values_lib
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training.server_lib import ClusterSpec


class ClosureWithOutput(coordinator_lib.Closure):

  def __init__(self, function, cancellation_mgr=None, args=None, kwargs=None):
    super(ClosureWithOutput, self).__init__(
        function, cancellation_mgr=cancellation_mgr, args=args, kwargs=kwargs)
    self.output_remote_value = self.build_output_remote_value()


class CoordinatedClosureQueueTest(test.TestCase):

  def testBasic(self):
    queue = coordinator_lib._CoordinatedClosureQueue()
    closure1 = self._create_closure(queue._cancellation_mgr)
    queue.put(closure1)
    self.assertIs(closure1, queue.get())
    self.assertFalse(queue.done())
    queue.put_back(closure1)
    self.assertEqual(closure1, queue.get())
    queue.mark_finished()
    self.assertTrue(queue.done())
    queue.wait()

  def testProcessAtLeastOnce(self):
    closure_queue = coordinator_lib._CoordinatedClosureQueue()
    labels = ['A', 'B', 'C', 'D', 'E']
    processed_count = collections.defaultdict(int)

    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def process_queue():
      with coord.stop_on_exception():
        has_been_put_back = False
        while True:
          closure = closure_queue.get(timeout=30)
          if closure is None:
            break
          if not has_been_put_back:
            has_been_put_back = True
            closure_queue.put_back(closure)
            continue
          closure._function()
          closure_queue.mark_finished()

    def get_func(label):

      def func():
        time.sleep(3)
        processed_count[label] += 1

      return func

    cm = cancellation.CancellationManager()
    for label in labels:
      closure_queue.put(ClosureWithOutput(get_func(label), cm))
    t1 = threading.Thread(target=process_queue, daemon=True)
    t1.start()
    t2 = threading.Thread(target=process_queue, daemon=True)
    t2.start()

    # Make sure multiple wait() calls are fine.
    closure_queue.wait()
    closure_queue.wait()
    closure_queue.wait()
    closure_queue.wait()

    self.assertEqual(processed_count, collections.Counter(labels))

    coord.join([t1, t2])

  def testNotifyBeforeWait(self):
    closure_queue = coordinator_lib._CoordinatedClosureQueue()

    def func():
      logging.info('func running')

    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def process_queue():
      with coord.stop_on_exception():
        closure_queue.get()
        closure_queue.mark_finished()

    closure_queue.put(ClosureWithOutput(func, closure_queue._cancellation_mgr))
    t = threading.Thread(target=process_queue)
    t.start()
    coord.join([t])

    # This test asserts that waiting at the time the function has been processed
    # doesn't time out.
    closure_queue.wait()

  def _assert_one_unblock_the_other(self, first_fn, second_fn):
    """Asserts `second_fn` wouldn't return before `first_fn` is finished."""
    first_fn_done = threading.Event()
    second_fn_done = threading.Event()
    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def wrapped_first_fn():
      with coord.stop_on_exception():
        self.assertFalse(second_fn_done.is_set())
        first_fn()
        first_fn_done.set()

    self.assertFalse(first_fn_done.is_set())
    t = threading.Thread(target=wrapped_first_fn)
    t.start()

    second_fn()
    self.assertTrue(first_fn_done.is_set())
    second_fn_done.set()

    coord.join([t])

  def _run_two_fns_in_parallel(self, first_fn, second_fn):
    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def wrapped_first_fn():
      with coord.stop_on_exception():
        first_fn()

    t = threading.Thread(target=wrapped_first_fn)
    t.start()

    second_fn()
    coord.join([t])

  def testWaitRaiseErrorAfterMarkFailure(self):
    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
      # TODO(b/165013260): Fix this
      self.skipTest('Test is currently broken on Windows with Python 3.8')

    closure_queue = coordinator_lib._CoordinatedClosureQueue()
    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
    closure = closure_queue.get()

    wait_finish_event = threading.Event()
    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    # Using a thread to verify that closure_queue.wait() will not return until
    # all inflight closures are finished.

    def mark_finished_fn():
      try:
        raise ValueError('Some error.')
      except ValueError as e:
        closure_queue.mark_failed(e)

    def wait_fn():
      with self.assertRaises(ValueError):
        closure_queue.wait()

    self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)

    self.assertTrue(closure_queue.done())

  def _create_closure(self, cancellation_mgr):

    @def_function.function()
    def some_function():
      return 1.0

    return ClosureWithOutput(some_function, cancellation_mgr)

  def _put_two_closures_and_get_one(self):
    closure_queue = coordinator_lib._CoordinatedClosureQueue()
    closure1 = self._create_closure(closure_queue._cancellation_mgr)
    closure_queue.put(closure1)

    closure2 = self._create_closure(closure_queue._cancellation_mgr)
    closure_queue.put(closure2)

    closure_got = closure_queue.get()  # returns closure1
    self.assertIs(closure_got, closure1)
    self.assertIsNot(closure_got, closure2)
    return closure_queue, closure1, closure2

  def testPutRaiseError(self):
    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
      # TODO(b/165013260): Fix this
      self.skipTest('Test is currently broken on Windows with Python 3.8')

    closure_queue, _, closure2 = self._put_two_closures_and_get_one()

    closure_queue.mark_failed(ValueError())

    with self.assertRaises(ValueError):
      closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))

    self.assertTrue(closure_queue.done())

    with self.assertRaisesRegex(
        errors.CancelledError,
        'The corresponding function is cancelled. Please reschedule the '
        'function.'):
      closure2.output_remote_value.fetch()

    # The error is cleared.
    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))

  def testWaitRaiseError(self):
    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
      # TODO(b/165013260): Fix this
      self.skipTest('Test is currently broken on Windows with Python 3.8')

    closure_queue, _, closure2 = self._put_two_closures_and_get_one()

    closure_queue.mark_failed(ValueError())

    with self.assertRaises(ValueError):
      closure_queue.wait()
    self.assertTrue(closure_queue.done())

    with self.assertRaisesRegex(
        errors.CancelledError,
        'The corresponding function is cancelled. Please reschedule the '
        'function.'):
      closure2.output_remote_value.fetch()

    # The error is cleared.
    closure_queue.wait()

  def testDoneRaiseError(self):
    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
      # TODO(b/165013260): Fix this
      self.skipTest('Test is currently broken on Windows with Python 3.8')

    closure_queue, _, _ = self._put_two_closures_and_get_one()

    self.assertFalse(closure_queue.done())
    closure_queue.mark_failed(ValueError())
    with self.assertRaises(ValueError):
      closure_queue.done()

  def _set_error(self, closure_queue, closure, error):
    try:
      raise error
    except Exception as e:  # pylint: disable=broad-except
      closure.output_remote_value._set_error(e)
      closure_queue.mark_failed(e)

  def _test_cancel_closure_when_error(self, call_wait):
    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
      # TODO(b/165013260): Fix this
      self.skipTest('Test is currently broken on Windows with Python 3.8')

    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
    closure_queue.get()
    # At this moment, there are two inflight closures and one in the queue.
    self.assertEqual(closure_queue._inflight_closure_count, 2)

    # Hold a copy of the queue's cancellation manager at this point.
    initial_cm = closure_queue._cancellation_mgr

    # Simulate closure1 failing.
    self._set_error(closure_queue, closure1, ValueError('Some error.'))

    # At this moment, there is one inflight closure and one in the queue.
    self.assertEqual(closure_queue._queue.qsize(), 1)
    self.assertEqual(closure_queue._inflight_closure_count, 1)

    closure3 = self._create_closure(closure_queue._cancellation_mgr)

    def fake_cancellation():
      self._set_error(closure_queue, closure2,
                      ValueError('Fake cancellation error.'))

    def report_error():
      # It should not report the fake cancellation error.
      with self.assertRaisesRegex(ValueError, 'Some error.'):
        # Verifying `wait()` or `put()` raises even if one closure is in
        # flight.
        if call_wait:
          closure_queue.wait()
        else:
          closure_queue.put(closure3)

    self._assert_one_unblock_the_other(fake_cancellation, report_error)

    # The original cancellation manager of the queue has been cancelled.
    self.assertTrue(initial_cm.is_cancelled)

    # At this moment, there are no inflight closures and nothing in the queue.
    self.assertTrue(closure_queue._queue.empty())
    self.assertEqual(closure_queue._inflight_closure_count, 0)
    self.assertIsNone(closure_queue._error)

    # This asserts that closure1 has errored.
    with self.assertRaisesRegex(ValueError, 'Some error.'):
      closure1.output_remote_value.fetch()

    # The following asserts that closure3 should have been cancelled.
    if not call_wait:
      with self.assertRaisesRegex(
          errors.CancelledError,
          'The corresponding function is cancelled. Please reschedule the '
          'function.'):
        closure3.output_remote_value.fetch()

    # Closure2 was an inflight closure when it got cancelled.
    self.assertEqual(closure2.output_remote_value._status,
                     values_lib.RemoteValueStatus.READY)
    with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
      closure2.output_remote_value.fetch()

    # This asserts that the queue has a clear state.
    self.testBasic()

  def testWaitRaiseErrorAfterCancelClosure(self):
    self._test_cancel_closure_when_error(call_wait=True)

  def testPutRaiseErrorAfterCancelClosure(self):
    self._test_cancel_closure_when_error(call_wait=False)

  def testStateIsRestoredAfterJoinIsCalled(self):
    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
      # TODO(b/165013260): Fix this
      self.skipTest('Test is currently broken on Windows with Python 3.8')

    closure_queue, _, _ = self._put_two_closures_and_get_one()
    self.assertEqual(closure_queue._inflight_closure_count, 1)
    closure_queue.mark_failed(ValueError('test error'))
    with self.assertRaises(ValueError):
      closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))

    # Its error should have been cleared.
    self.assertIsNone(closure_queue._error)
    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
    self.assertIsNone(closure_queue._error)

  def testThreadSafety(self):
    thread_count = 10
    queue = coordinator_lib._CoordinatedClosureQueue()

    # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
    # `mark_finished`.
    action_count = 20

    def func():
      for i in range(action_count):
        closure = queue.get()
        if i % 2 == 0:
          queue.put_back(closure)
        else:
          queue.mark_finished()

    threads = [threading.Thread(target=func) for _ in range(thread_count)]
    for t in threads:
      t.start()

    for _ in range(thread_count * action_count // 2):
      queue.put(self._create_closure(queue._cancellation_mgr))
    queue.wait()
    self.assertTrue(queue.done())

  def testPutGetWithTag(self):
    queue = coordinator_lib._CoordinatedClosureQueue()

    closure1 = self._create_closure(queue._cancellation_mgr)
    closure2 = self._create_closure(queue._cancellation_mgr)
    closure3 = self._create_closure(queue._cancellation_mgr)

    def put_fn():
      queue.put(closure3, tag=1)
      queue.put(closure2, tag=2)
      queue.put(closure1)

    def get_fn():
      # The get should only return the closure with tag 2.
      self.assertIs(closure2, queue.get(tag=2))

    self._run_two_fns_in_parallel(put_fn, get_fn)

    self.assertFalse(queue.done())
    self.assertEqual(closure1, queue.get())
    queue.mark_finished()

    # It will not wait for closure3
    self.assertTrue(queue.done())
    queue.wait()


class ErrorReportingThread(threading.Thread):

  error = None

  def __init__(self, *args, **kwargs):
    assert 'target' in kwargs
    target = kwargs['target']

    @functools.wraps(target)
    def wrapped_target(*args, **kwargs):
      try:
        return target(*args, **kwargs)
      except Exception as e:  # pylint: disable=broad-except
        traceback.print_exception(*sys.exc_info())
        ErrorReportingThread.error = e

    kwargs['target'] = wrapped_target
    super(ErrorReportingThread, self).__init__(*args, **kwargs)


class TestCaseWithErrorReportingThread(test.TestCase):

  @classmethod
  def setUpClass(cls):
    cls._threading_thread = threading.Thread
    threading.Thread = ErrorReportingThread
    super(TestCaseWithErrorReportingThread, cls).setUpClass()

  @classmethod
  def tearDownClass(cls):
    super(TestCaseWithErrorReportingThread, cls).tearDownClass()
    threading.Thread = cls._threading_thread

  def setUp(self):
    ErrorReportingThread.error = None
    super(TestCaseWithErrorReportingThread, self).setUp()

  def tearDown(self):
    super(TestCaseWithErrorReportingThread, self).tearDown()
    if ErrorReportingThread.error:
      raise ErrorReportingThread.error  # pylint: disable=raising-bad-type


def make_coordinator(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  cluster_def['chief'] = [
      'localhost:%d' % test_util.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer='grpc')
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  return coordinator_lib.ClusterCoordinator(strategy)


class ClusterCoordinatorTest(TestCaseWithErrorReportingThread,
                             parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super(ClusterCoordinatorTest, cls).setUpClass()
    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
    cls.strategy = cls.coordinator.strategy

  def testClusterCoordinatorOnlyInitOnce(self):
    cluster = self.coordinator._cluster
    same_coordinator = coordinator_lib.ClusterCoordinator(self.strategy)
    self.assertIs(self.coordinator, same_coordinator)
    self.assertIs(cluster, same_coordinator._cluster)

  def testFnReturnNestedValues(self):
    x = constant_op.constant(1)

    @def_function.function
    def f():
      return x + 1, (x + 2, x + 3), [x + 4], {'v': x}

    got = self.coordinator.schedule(f)
    want = 2, (3, 4), [5], {'v': 1}
    self.assertEqual(got.fetch(), want)
    self.assertEqual(self.coordinator.fetch(got), want)

  def testFetchingRemoteValueStructure(self):
    self.skipTest('b/171040359: flaky test')
    x = constant_op.constant(1)

    @def_function.function
    def f():
      return x + 1, (x + 2, x + 3), [x + 4], {'v': x}

    want = 2, (3, 4), [5], {'v': 1}
    remote_value_list = [self.coordinator.schedule(f) for _ in range(5)]
    self.assertAllEqual(
        self.coordinator.fetch(remote_value_list), [want for _ in range(5)])

  def testInputFunction(self):

    def input_fn():
      return dataset_ops.DatasetV2.range(1, 2)

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @def_function.function
    def worker_fn(iterator):
      x = next(iterator)
      v.assign_add(x)
      return x

    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
    result = self.coordinator.schedule(
        worker_fn, args=(iter(distributed_dataset),))
    result = self.coordinator.fetch(result)
    self.assertEqual(result, (1,))
    result = self.coordinator.schedule(
        worker_fn, args=(iter(distributed_dataset),))
    result = self.coordinator.fetch(result)

    self.assertEqual(result, (1,))
    self.assertAlmostEqual(v.read_value(), 2, delta=1e-6)

  def testAsyncScheduleAndJoin(self):
    if test_util.is_xla_enabled():
      self.skipTest('Assign_add is not deterministic across threads in XLA')

    def input_fn():
      return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    # TODO(yuefengz): the following tf.function has a return value which is None
    # in its structured_outputs.
    @def_function.function
    def worker_fn(iterator):
      x = next(iterator)
      v.assign_add(x)

    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)

    iterator = iter(distributed_dataset)

    # Verifying that joining without any scheduling doesn't hang.
    self.coordinator.join()
    self.assertEqual(v.read_value().numpy(), 0)

    for _ in range(5):
      self.coordinator.schedule(worker_fn, args=(iterator,))
    self.coordinator.join()

    # With 5 additions it should be 2*5 = 10.
    self.assertEqual(v.read_value().numpy(), 10)

    for _ in range(5):
      self.coordinator.schedule(worker_fn, args=(iterator,))

    # Verifying that multiple joins are fine.
    self.coordinator.join()
    self.coordinator.join()
    self.coordinator.join()

    self.assertTrue(self.coordinator.done())

    # Likewise, it's now 20.
    self.assertEqual(v.read_value().numpy(), 20.)

  @parameterized.parameters(True, False)
  def testInputFunctionWithMap(self, use_input_fn):
    self._map_fn_tracing_count = 0

    def input_fn():

      def map_fn(x):
        self._map_fn_tracing_count += 1
        return x + 10

      return dataset_ops.DatasetV2.range(0, 10).map(map_fn)

    @def_function.function
    def worker_fn(iterator):
      return next(iterator)

    if use_input_fn:
      distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
    else:
      distributed_dataset = self.coordinator.create_per_worker_dataset(
          input_fn())

    result = self.coordinator.schedule(
        worker_fn, args=(iter(distributed_dataset),))
    self.assertEqual(result.fetch(), (10,))
    self.assertEqual(self._map_fn_tracing_count, 1)

  def testInputFunctionCreateVariables(self):

    def input_fn():
      v = variables.Variable(initial_value=0.0)
      return v.read_value()

    with self.assertRaises(ValueError):
      self.coordinator.create_per_worker_dataset(input_fn)

  @parameterized.parameters(True, False)
  def testDatasetsShuffledDifferently(self, use_input_fn):
    # This test requires at least two workers in the cluster.
    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)

    random_seed.set_random_seed(None)

    def input_fn():
      dataset = dataset_ops.DatasetV2.range(0, 100).shuffle(100).batch(1)
      return self.strategy.experimental_distribute_dataset(dataset)

    if use_input_fn:
      distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
    else:
      distributed_dataset = self.coordinator.create_per_worker_dataset(
          input_fn())
    distributed_iterator = iter(distributed_dataset)
    # Get elements from the first two iterators.
    iterator_1 = distributed_iterator._values[0]
    iterator_1 = iterator_1.fetch()
    elements_in_iterator_1 = [
        self.strategy.experimental_local_results(e)
        for e in iterator_1
    ]
    iterator_2 = distributed_iterator._values[1]
    iterator_2 = iterator_2.fetch()
    elements_in_iterator_2 = [
        self.strategy.experimental_local_results(e)
        for e in iterator_2
    ]

    self.assertNotAllEqual(elements_in_iterator_1, elements_in_iterator_2)

  def testPerWorkerValue(self):
    self.skipTest('b/168569314')
    var_shape = tuple()
    var_dtype = dtypes.float32
    var_name = 'var'

    def create_var():
      var = variables.Variable(
          initial_value=0.0, dtype=var_dtype, name=var_name)
      self.assertIn('worker', var.device)
      return var

    worker_local_var = self.coordinator._create_per_worker_resources(create_var)

    # The following is a workaround to allow `worker_local_var` to be passed in
    # as args to the `coordinator.schedule` method which requires tensor specs
    # to trace tf.function but _create_worker_resources' return values don't
    # have tensor specs. We can get rid of this workaround once
    # _create_worker_resources is able to infer the tensor spec of the return
    # value of the function passed in. See b/154675763.
    for var in worker_local_var._values:
      var._type_spec = tensor_spec.TensorSpec(var_shape, var_dtype, var_name)

    def worker_fn(var):
      var.assign_add(1.0)

    for _ in range(10):
      # Which slice of `worker_local_var` will be used will depend on which
      # worker the `worker_fn` gets scheduled on.
      self.coordinator.schedule(worker_fn, args=(worker_local_var,))
    self.coordinator.join()

    var_sum = sum(self.coordinator.fetch(worker_local_var._values))
    self.assertEqual(var_sum, 10.0)

  def testDisallowRemoteValueAsInput(self):

    @def_function.function
    def func_0():
      return 1.0

    @def_function.function
    def func_1(x):
      return x + 1.0

    remote_v = self.coordinator.schedule(func_0)
    with self.assertRaises(ValueError):
      self.coordinator.schedule(func_1, args=(remote_v,))

  def testPythonFunctionNotAllowedToSchedule(self):

    def func(a):
      return array_ops.identity(a)

    with self.assertRaisesRegex(
        TypeError,
        '`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` '
        'only accepts a `tf.function` or a concrete function.'):
      self.coordinator.schedule(func, args=(1,))

  def testDatasetPartiallyCreatedOnCoordinator(self):
    dataset = dataset_ops.DatasetV2.range(1, 10)

    @def_function.function
    def input_fn():
      return dataset.shuffle(9)

    @def_function.function
    def worker_fn(iterator):
      x = next(iterator)
      return x

    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
    self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))

    with self.assertRaisesRegex(
        coordinator_lib.ClosureInputError,
        'error message is Failed copying input tensor from'):
      self.coordinator.join()

  def testPassDatasetToCreatePerWorkerDataset(self):
    dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)

    @def_function.function
    def worker_fn(iterator):
      return next(iterator)

    per_worker_dataset = self.coordinator.create_per_worker_dataset(dataset)
    result = self.coordinator.schedule(
        worker_fn, args=(iter(per_worker_dataset),))
    result = result.fetch()
    expected_result = math_ops.range(1., 5.)

    self.assertAllEqual(result, (expected_result))

  def testMultipleDatasets(self):

    def input_fn1():
      return dataset_ops.DatasetV2.range(0, 5)

    def input_fn2():
      return dataset_ops.DatasetV2.range(5, 10)

    per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn1)
    per_worker_iterator1 = iter(per_worker_dataset1)
    per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn2)
    per_worker_iterator2 = iter(per_worker_dataset2)

    @def_function.function
    def worker_fn(iterator1, iterator2):
      return next(iterator1) + next(iterator2)

    result = self.coordinator.schedule(
        worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
    self.assertEqual(result.fetch(), 5.0)

    per_worker_dataset3 = self.coordinator.create_per_worker_dataset(input_fn1)
    per_worker_iterator3 = iter(per_worker_dataset3)

    result = self.coordinator.schedule(
        worker_fn, args=(per_worker_iterator3, per_worker_iterator2))
    self.assertGreaterEqual(result.fetch(), 5.0)

  def testRepeatedIteratorCreation(self):

    def input_fn():
      return dataset_ops.DatasetV2.range(1, 100)

    per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn)
    per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn)

    @def_function.function
    def worker_fn(iterator1, iterator2):
      return next(iterator1) + next(iterator2)

    for _ in range(10):
      per_worker_iterator1 = iter(per_worker_dataset1)
      per_worker_iterator2 = iter(per_worker_dataset2)
      result = self.coordinator.schedule(
          worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
      for _ in range(10):
        self.coordinator.schedule(
            worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
      self.coordinator.join()
      self.assertGreaterEqual(result.fetch(), 2.0)
    del per_worker_iterator1, per_worker_iterator2
    gc.collect()

    # There shouldn't be any live iterator objects.
    for w in self.coordinator._cluster.workers:
      for r in w._resource_remote_value_refs:
        self.assertIsNone(r())


class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
  """Test basic functionality works with explicit maximum closure queue size.

  Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
  explicit size limit for the closure queue. Note that even when the queue size
  is set to infinite, there is still a maximum practical size (depending on the
  host memory limit) that might cause queue.put operations to block when
  scheduling a large number of closures on a big cluster. These tests make sure
  that the coordinator does not run into deadlocks in such a scenario.
  """

  @classmethod
  def setUpClass(cls):
    super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
    coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
    cls.strategy = cls.coordinator.strategy


class ScheduleStartDelayTest(ClusterCoordinatorTest):
  """Test basic functionality works with worker scheduling delay.

  This is basically to make sure that setting the environment variables
  `TF_COORDINATOR_SCHEDULE_START_DELAY` and
  `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` will not cause any failure.
  """

  @classmethod
  def setUpClass(cls):
    super(ScheduleStartDelayTest, cls).setUpClass()
    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] = '2'
    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] = '4'
    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
    cls.strategy = cls.coordinator.strategy

  @classmethod
  def tearDownClass(cls):
    del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY']
    del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX']
    super(ScheduleStartDelayTest, cls).tearDownClass()


class ErrorReportingTest(TestCaseWithErrorReportingThread):

  @classmethod
  def setUpClass(cls):
    super(ErrorReportingTest, cls).setUpClass()
    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
    cls.strategy = cls.coordinator.strategy

    with cls.strategy.scope():
      cls.iteration = variables.Variable(initial_value=0.0)

  @def_function.function
  def _normal_function(self):
    x = random_ops.random_uniform((2, 10))
    y = random_ops.random_uniform((10, 2))
    self.iteration.assign_add(1.0)
    return math_ops.reduce_mean(math_ops.matmul(x, y))

  @def_function.function
  def _error_function(self):
    x = random_ops.random_uniform((2, 10))
    y = random_ops.random_uniform((10, 2))
    check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y)))
    self.iteration.assign_add(1.0)
    return self.iteration

  @def_function.function
  def _long_function(self):
    x = random_ops.random_uniform((1000, 1000))
    for _ in math_ops.range(10000):
      a = random_ops.random_uniform((1000, 1000))
      b = random_ops.random_uniform((1000, 1000))
      x += math_ops.matmul(a, b)
    return x

  def testJoinRaiseError(self):
    for _ in range(3):
      self.coordinator.schedule(self._normal_function)
    self.coordinator.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      self.coordinator.join()

  def testScheduleRaiseError(self):
    for _ in range(3):
      self.coordinator.schedule(self._normal_function)
    self.coordinator.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      while True:
        self.coordinator.schedule(self._normal_function)

  def testScheduleRaiseErrorWithMultipleFailure(self):
    for _ in range(3):
      self.coordinator.schedule(self._normal_function)
    self.coordinator.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      while True:
        self.coordinator.schedule(self._error_function)
    self.coordinator.join()

  def testErrorWillbeCleared(self):
    self.coordinator.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      self.coordinator.join()

    for _ in range(3):
      self.coordinator.schedule(self._normal_function)
    self.coordinator.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      self.coordinator.join()

  def testRemoteValueReturnError(self):
    self.skipTest('TODO(b/211502459): Fix this in OSS test.')

    result = self.coordinator.schedule(self._error_function)

    with self.assertRaises(errors.InvalidArgumentError):
      result.fetch()

    # Clear the error.
    with self.assertRaises(errors.InvalidArgumentError):
      self.coordinator.join()

  def testInputError(self):

    worker_local_val = self.coordinator._create_per_worker_resources(
        self._error_function)

    @def_function.function
    def func(x):
      return x + 1

    result = self.coordinator.schedule(func, args=(worker_local_val,))
    with self.assertRaises(coordinator_lib.ClosureInputError):
      self.coordinator.join()

    with self.assertRaises(coordinator_lib.ClosureInputError):
      result.fetch()

  def testErroredInputNotUsed(self):
    input_0 = self.coordinator._create_per_worker_resources(
        self._normal_function)

    self.coordinator._create_per_worker_resources(
        self._error_function)

    @def_function.function
    def func(x):
      return x + 1

    result = self.coordinator.schedule(func, args=(input_0,))

    # It should not raise.
    self.coordinator.join()
    result.fetch()

  def testCancellation(self):
    for _ in range(3):
      self.coordinator.schedule(self._normal_function)
    long_function = self.coordinator.schedule(self._long_function)
    self.coordinator.schedule(self._error_function)

    with self.assertRaises(errors.InvalidArgumentError):
      self.coordinator.join()

    with self.assertRaises(errors.CancelledError):
      long_function.fetch()

    for _ in range(3):
      self.coordinator.schedule(self._normal_function)
    self.coordinator.join()

  def testResourceCanStillbeUsedAfterCancellation(self):

    def input_fn():
      return dataset_ops.DatasetV2.range(0, 5)

    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
    per_worker_iterator = iter(per_worker_dataset)

    @def_function.function
    def worker_fn(iterator):
      return next(iterator)

    self.coordinator.schedule(worker_fn, args=(per_worker_iterator,))
    self.coordinator.schedule(self._error_function)

    with self.assertRaises(errors.InvalidArgumentError):
      self.coordinator.join()

    self.coordinator.schedule(worker_fn, args=(per_worker_iterator,))
    self.coordinator.join()


class LimitedClosureQueueErrorTest(ErrorReportingTest):
  """Test error reporting works with explicit maximum closure queue size.

  Execute the same set of test cases as in ErrorReportingTest, with an explicit
  size limit for the closure queue.
  """

  @classmethod
  def setUpClass(cls):
    super(LimitedClosureQueueErrorTest, cls).setUpClass()
    coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
    cls.strategy = cls.coordinator.strategy

    with cls.coordinator.strategy.scope():
      cls.iteration = variables.Variable(initial_value=0.0)


class StrategyIntegrationTest(test.TestCase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super(StrategyIntegrationTest, cls).setUpClass()
    cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
    cls.strategy = cls.coordinator.strategy

  def testRunNotUsedWithClusterCoordinatorSchedule(self):

    @def_function.function
    def input_fn():
      return dataset_ops.DatasetV2.range(1, 3)

    with self.strategy.scope():
      v = variables.Variable(initial_value=1, dtype=dtypes.int64)

    def replica_fn(input_tensor):
      return input_tensor + v, input_tensor - v

    @def_function.function
    def worker_fn(iterator):
      return self.strategy.run(replica_fn, args=(next(iterator),))

    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)

    @contextlib.contextmanager
    def _assert_logs_usage_warning():
      with self.assertLogs(level='WARNING') as logs:
        yield

      self.assertIn(
          'A `tf.distribute.experimental.ParameterServerStrategy` method is '
          'invoked without using `ClusterCoordinator.schedule`. If you are not '
          'tracing a tf.function, this method is possibly executed on the '
          'coordinator, which can be slow. To properly dispatch functions to '
          'run on workers, methods like `run` or `reduce` should be used '
          'within a function passed to `tf.distribute.experimental.coordinator.'
          'ClusterCoordinator.schedule`.', logs.output[0])

    with _assert_logs_usage_warning():
      # Invoking `run` without `coordinator.schedule` should result in a
      # warning.
      self.strategy.run(
          replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),))

    # A proper `schedule` should succeed.
    rv = self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))

    with _assert_logs_usage_warning():
      # Invoking `run` without `coordinator.schedule` again should result in a
      # warning.
      self.strategy.run(
          replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),))

    all_results = [(2, 0)] * self.strategy.num_replicas_in_sync
    expected_result = []
    for i in range(self.strategy.num_replicas_in_sync):
      expected_result.append(all_results[i])

    self.assertAllEqual(
        tuple(expected_result),
        self.strategy.experimental_local_results(rv.fetch()))

  def testBasicVariableAssignment(self):
    self.strategy.extended._variable_count = 0
    with self.strategy.scope():
      v1 = variables.Variable(initial_value=0.0)
      v2 = variables.Variable(initial_value=1.0)
    self.assertEqual(self.strategy.extended._variable_count, 2)

    @def_function.function
    def worker_fn():
      v1.assign_add(0.1)
      v2.assign_sub(0.2)
      return v1.read_value() / v2.read_value()

    results = self.coordinator.schedule(worker_fn)
    logging.info('Results of experimental_run_v2: %f',
                 self.coordinator.fetch(results))

    self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
    self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)

  def testRunAndReduce(self):
    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
    with self.strategy.scope():
      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
      v = variables.Variable(initial_value=1.)

      expected_result = (4. * self.strategy.num_replicas_in_sync,
                         2. * self.strategy.num_replicas_in_sync)

      @def_function.function
      def worker_fn(input_tensor):

        def replica_fn(input_tensor):
          # Within `replica_fn`, it has to be in a replica context.
          self.assertFalse(
              distribution_strategy_context.in_cross_replica_context())
          return input_tensor + v, input_tensor - v

        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
        check_ops.assert_equal_v2(reduced_result, expected_result)
        return reduced_result

      # Asserting scheduling in scope has the expected behavior.
      result = self.coordinator.schedule(
          worker_fn, args=(constant_op.constant(3.),))
      self.assertIsInstance(result, coordinator_lib.RemoteValue)
      self.assertEqual(result.fetch(), expected_result)

    # Asserting scheduling out of scope has the expected behavior.
    result = self.coordinator.schedule(
        worker_fn, args=(constant_op.constant(3.),))
    self.assertEqual(result.fetch(), expected_result)

  def testRunAndReduceWithAssignAdd(self):
    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
    with self.strategy.scope():
      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
      v = variables.Variable(initial_value=1.)
      v1 = variables.Variable(
          initial_value=0.,
          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

      expected_result = (4. * self.strategy.num_replicas_in_sync,
                         2. * self.strategy.num_replicas_in_sync)

      @def_function.function
      def worker_fn(input_tensor):

        def replica_fn(input_tensor):
          # Within `replica_fn`, it has to be in a replica context.
          self.assertFalse(
              distribution_strategy_context.in_cross_replica_context())

          v1.assign_add(input_tensor)
          return input_tensor + v, input_tensor - v

        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
        check_ops.assert_equal_v2(reduced_result, expected_result)
        return reduced_result

      # Asserting scheduling in scope has the expected behavior.
      result = self.coordinator.schedule(
          worker_fn, args=(constant_op.constant(3.),))
      self.assertIsInstance(result, coordinator_lib.RemoteValue)
      self.assertEqual(result.fetch(), expected_result)

    # Asserting scheduling out of scope has the expected behavior.
    result = self.coordinator.schedule(
        worker_fn, args=(constant_op.constant(3.),))
    self.assertEqual(result.fetch(), expected_result)
    self.assertEqual(v1, 6.)

  def testVariableAggregation(self):
    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
    with self.strategy.scope():
      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
      v = variables.Variable(
          initial_value=1.,
          aggregation=variable_scope.VariableAggregation.SUM)

      @def_function.function
      def worker_fn():

        def replica_fn():
          value = math_ops.cast(
              distribution_strategy_context.get_replica_context()
              .replica_id_in_sync_group + 1, v.dtype)
          v.assign(value)

        self.strategy.run(replica_fn)

      self.coordinator.schedule(worker_fn)
      self.coordinator.join()
      expected_result = 0.
      for i in range(self.strategy.num_replicas_in_sync):
        expected_result = expected_result + i + 1
      self.assertEqual(v, expected_result)

  def testVariableCaching(self):
    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
    with self.strategy.scope():
      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
      v = variables.Variable(
          initial_value=1.,
          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

      # Test read value inside caching scope
      with distribute_utils.cache_variable_reads():
        v.read_value()  # Reads value 1.0
        v.assign(constant_op.constant(5.0))  # v changes to 5.0
        self.assertEqual(v.read_value(), 1.0)  # should be cached 1.0 value.

      # Reset v to 2.0
      v.assign(2.0)

      # Test convert to tensor value inside caching scope
      with distribute_utils.cache_variable_reads():
        t = v * 3.0
        self.assertEqual(t, 6.0)
        v.assign(3.0)
        t1 = v * 3.0
        self.assertEqual(t1, 6.0)  # should be cached 2.0 * 3.0 value.

      # Reset v to 1.0
      v.assign(1.0)

      # Verify caching scope inside tf.function
      @def_function.function
      def worker_fn():
        with distribute_utils.cache_variable_reads():
          def replica_fn():
            t = v.read_value()  # Reads value 1.0
            v.assign(constant_op.constant(5.0))  # v changes to 5.0
            t = v.read_value()  # should return 1.0
            return t  # Should be 1.0 instead of 5.0

          return self.strategy.run(replica_fn)

      result = self.coordinator.schedule(worker_fn)
      result = result.fetch()
      expected_result = 1.
      self.assertEqual(result, expected_result)

      # Verify that v.read_value works as expected outside of scope.
      v.assign(4.0)
      self.assertEqual(v.read_value(), 4.0)

      v.assign(constant_op.constant(2.0))  # v changes to 2.0
      # Check with scope outside of tf function and check that cache is reset
      @def_function.function
      def worker_fn1():
        def replica_fn():
          t = v.read_value()  # Reads value 2.0 ==> Should be cached
          v.assign(constant_op.constant(5.0))  # v changes to 5.0
          t = v.read_value()  # should return cached value 2.0
          return t  # Should be 2.0 instead of 5.0

        return self.strategy.run(replica_fn)

      with distribute_utils.cache_variable_reads():
        result = self.coordinator.schedule(worker_fn1)
      result = result.fetch()
      expected_result = 2.
      self.assertEqual(result, expected_result)

      # Verify scope nesting is not permitted.
      with self.assertRaises(ValueError):
        with distribute_utils.cache_variable_reads():
          with distribute_utils.cache_variable_reads():
            v.read_value()

  @parameterized.parameters(True, False)
  def testDistributedDatasetInsidePerWorkerDatasetFn(self, from_function):
    if from_function:

      def per_worker_dataset_fn():
        dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4)
        return self.strategy.distribute_datasets_from_function(dataset_fn)
    else:

      def per_worker_dataset_fn():
        dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
        return self.strategy.experimental_distribute_dataset(dataset)

    @def_function.function
    def worker_fn(iterator):
      return self.strategy.experimental_local_results(next(iterator))

    per_worker_dataset = self.coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    result = self.coordinator.schedule(
        worker_fn, args=(iter(per_worker_dataset),))
    result = result.fetch()
    expected_result = array_ops.split(
        math_ops.range(1., 5.),
        num_or_size_splits=self.strategy.num_replicas_in_sync,
        axis=0)

    self.assertAllEqual(result, (expected_result))

  @parameterized.parameters(True, False)
  def testPassDistributedDatasetToCreatePerWorkerDataset(self, from_function):
    if from_function:
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4)
      distributed_dataset = self.strategy.distribute_datasets_from_function(
          dataset_fn)
    else:
      dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
      distributed_dataset = self.strategy.experimental_distribute_dataset(
          dataset)

    @def_function.function
    def worker_fn(iterator):
      return self.strategy.experimental_local_results(next(iterator))

    per_worker_dataset = self.coordinator.create_per_worker_dataset(
        distributed_dataset)
    result = self.coordinator.schedule(
        worker_fn, args=(iter(per_worker_dataset),))
    result = result.fetch()
    expected_result = array_ops.split(
        math_ops.range(1., 5.),
        num_or_size_splits=self.strategy.num_replicas_in_sync,
        axis=0)

    self.assertAllEqual(result, (expected_result))

  def testDistributeDatasetsFromFunction(self):

    def per_worker_dataset_fn():

      def input_worker_device_fn(input_context):
        self.assertIsNotNone(input_context)
        return dataset_ops.DatasetV2.range(1, 11).batch(1)

      return self.strategy.distribute_datasets_from_function(
          input_worker_device_fn)

    @def_function.function
    def worker_fn(iterator):
      result = self.strategy.experimental_local_results(next(iterator))
      return result

    distributed_dataset = self.coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    result = self.coordinator.schedule(
        worker_fn, args=(iter(distributed_dataset),))
    result = result.fetch()
    expected_result = []
    for i in range(self.strategy.num_replicas_in_sync):
      expected_result.append([1 + i])
    self.assertAllEqual(result, expected_result)

  def testAsyncScheduleWithDistributedDataset(self):

    def input_fn():
      dataset = dataset_ops.DatasetV2.from_tensor_slices([2.]).repeat().batch(
          self.strategy.num_replicas_in_sync)
      return self.strategy.experimental_distribute_dataset(dataset)

    with self.strategy.scope():
      v = variables.Variable(initial_value=[0], dtype=dtypes.float32)

    # TODO(yuefengz): the following tf.function has a return value which is None
    # in its structured_outputs.
    @def_function.function
    def worker_fn(iterator):
      x = next(iterator)
      # Reduce to convert PerReplica values to a single value.
      reduced_value = self.strategy.reduce('MEAN', x, axis=None)
      v.assign_add(reduced_value)

    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)

    iterator = iter(distributed_dataset)

    # Verifying that joining without any scheduling doesn't hang.
    self.coordinator.join()
    self.assertAllEqual(v.read_value(), (0,))

    for _ in range(5):
      self.coordinator.schedule(worker_fn, args=(iterator,))
    self.coordinator.join()

    # With 5 additions it should be 2*5 = 10.
    self.assertAllEqual(
        self.strategy.experimental_local_results(v.read_value()), ([[10]]))

    for _ in range(5):
      self.coordinator.schedule(worker_fn, args=(iterator,))

    # Verifying that multiple joins are fine.
    self.coordinator.join()
    self.coordinator.join()
    self.coordinator.join()

    self.assertTrue(self.coordinator.done())

    # Likewise, it's now 20.
    self.assertAllEqual(
        self.strategy.experimental_local_results(v.read_value()), ([[20]]))

  def testInputFunctionWithMapWithDistributedDataset(self):
    self._map_fn_tracing_count = 0

    def input_fn():

      def map_fn(x):
        self._map_fn_tracing_count += 1
        return x + 10

      dataset = dataset_ops.DatasetV2.range(0, 10).batch(
          self.strategy.num_replicas_in_sync).map(map_fn)
      return self.strategy.experimental_distribute_dataset(dataset)

    @def_function.function
    def worker_fn(iterator):
      return next(iterator)

    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
    result = self.coordinator.schedule(
        worker_fn, args=(iter(distributed_dataset),))

    expected_result = array_ops.split(
        math_ops.range(10., 10. + self.strategy.num_replicas_in_sync),
        num_or_size_splits=self.strategy.num_replicas_in_sync,
        axis=0)

    self.assertAllEqual(
        self.strategy.experimental_local_results(result.fetch()),
        tuple(expected_result))
    self.assertEqual(self._map_fn_tracing_count, 1)

  def testPerWorkerDistributeDatasetsElementSpec(self):

    def per_worker_dataset_fn():
      return self.strategy.distribute_datasets_from_function(
          lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 2]))

    dataset = dataset_ops.DatasetV2.from_tensor_slices([1, 2])
    per_worker_distribute_dataset = self.coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)

    self.assertAllEqual(
        # Converts to PerReplicaSpec when num_replicas_in_sync is > 1.
        input_lib._create_distributed_tensor_spec(self.strategy,
                                                  dataset.element_spec),
        per_worker_distribute_dataset.element_spec)

  def testPerWorkerDistributedIteratorTypeSpec(self):
    self._tracing_count = 0

    def per_worker_dataset_fn():
      self._tracing_count += 1
      return self.strategy.distribute_datasets_from_function(
          lambda _: dataset_ops.DatasetV2.range(1, 2))

    @def_function.function
    def worker_fn(iterator):
      return next(iterator)

    distributed_iterator = iter(
        self.coordinator.create_per_worker_dataset(per_worker_dataset_fn))
    worker_fn.get_concrete_function(distributed_iterator)

    self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
    self.assertEqual(self._tracing_count, 1)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()