1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Fault tolerance test for parameter server training in TF2.""" 16 17import gc 18import sys 19import threading 20import time 21 22from tensorflow.python.compat import v2_compat 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.distribute import multi_process_runner 25from tensorflow.python.distribute import multi_worker_test_base 26from tensorflow.python.distribute import parameter_server_strategy_v2 27from tensorflow.python.distribute import test_util 28from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver 29from tensorflow.python.distribute.coordinator import cluster_coordinator 30from tensorflow.python.eager import context 31from tensorflow.python.eager import def_function 32from tensorflow.python.eager import test 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors 35from tensorflow.python.framework import ops 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import check_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import random_ops 40from tensorflow.python.ops import variables 41from tensorflow.python.platform import tf_logging as logging 42from tensorflow.python.training import coordinator as thread_coordinator 43from tensorflow.python.training import server_lib 44 45_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker" 46_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps" 47_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler" 48_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop" 49 50 51class Model(object): 52 53 def __init__(self, coordinator): 54 self.cluster_coord = coordinator 55 self.strategy = self.cluster_coord.strategy 56 with self.cluster_coord.strategy.scope(): 57 self.build() 58 59 def build(self): 60 self.w = variables.Variable( 61 initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32) 62 self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32) 63 # Allow external control to make the model run its train_fn in an infinite 64 # loop. This allows us to reliably test worker preemption in the middle of 65 # function execution. 66 self.do_infinite_step = variables.Variable(False) 67 68 self.rebuild_iterators() 69 70 def rebuild_iterators(self, use_dataset_fn=True): 71 72 if use_dataset_fn: 73 74 def dataset_fn(): 75 data = random_ops.random_uniform((10, 10)) 76 dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat() 77 return dataset 78 79 def distribute_dataset_fn(): 80 return self.cluster_coord.strategy.distribute_datasets_from_function( 81 lambda _: dataset_fn()) 82 83 self.iterator = iter( 84 self.cluster_coord.create_per_worker_dataset(distribute_dataset_fn)) 85 self.iterator2 = iter( 86 self.cluster_coord.create_per_worker_dataset(distribute_dataset_fn)) 87 else: 88 data = random_ops.random_uniform((10, 10)) 89 dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat() 90 91 self.iterator = iter( 92 self.cluster_coord.create_per_worker_dataset(dataset)) 93 self.iterator2 = iter( 94 self.cluster_coord.create_per_worker_dataset(dataset)) 95 96 def _train_fn_internal(self, iterator, iterator2): 97 x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w) 98 x = math_ops.matmul(array_ops.squeeze(next(iterator2)), x) 99 x = math_ops.matmul(random_ops.random_uniform((10, 10)), x) 100 self.w.assign_add(x) 101 102 @def_function.function 103 def train_fn(self, iterator, iterator2): 104 self._train_fn_internal(iterator, iterator2) 105 while self.do_infinite_step: 106 self._train_fn_internal(iterator, iterator2) 107 self.iterations.assign_add(1) 108 109 def schedule_training_functions(self, num_steps): 110 with self.strategy.scope(): 111 for _ in range(num_steps): 112 self.cluster_coord.schedule( 113 self.train_fn, args=(self.iterator, self.iterator2)) 114 115 def join_training_functions(self): 116 self.do_infinite_step.assign(False) 117 self.cluster_coord.join() 118 119 120class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring 121 122 def setUp(self, num_workers, num_ps): 123 super(BaseFaultToleranceTest, self).setUp() 124 125 self._cluster = multi_worker_test_base.create_multi_process_cluster( 126 num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc") 127 self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict() 128 self._cluster_def["chief"] = [ 129 "localhost:%d" % multi_worker_test_base.pick_unused_port() 130 ] 131 cluster_resolver = SimpleClusterResolver( 132 server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc") 133 134 # The strategy's constructor would connect to the cluster. 135 self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 136 cluster_resolver) 137 self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy) 138 139 self.thread_coord = thread_coordinator.Coordinator( 140 clean_stop_exception_types=[]) 141 self.num_workers = num_workers 142 self.num_ps = num_ps 143 144 def tearDown(self): 145 super(BaseFaultToleranceTest, self).tearDown() 146 self._cluster.stop() 147 self._cluster = None 148 149 def _restart(self, downtime_secs, job): 150 """Kills `job` (index: 0) and restarts it after `downtime_secs`. 151 152 Args: 153 downtime_secs: secs before restarting the job. 154 job: a string specifying the job to restart. 155 """ 156 self._cluster.kill_task(job, 0) 157 time.sleep(downtime_secs) 158 self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job)) 159 self._cluster.start_task(job, 0) 160 while not context.check_alive("/job:%s/replica:0/task:0" % job): 161 time.sleep(1) 162 163 def _restart_in_thread(self, downtime_secs, restart_job): 164 165 def _restart_fn(): 166 with self.thread_coord.stop_on_exception(): 167 self._restart(downtime_secs, restart_job) 168 169 restart_thread = threading.Thread(target=_restart_fn) 170 restart_thread.start() 171 return restart_thread 172 173 def _ensure_threads_closed(self): 174 """Ensures worker and preemption threads are closed.""" 175 # Worker and preemption threads should exist before releasing 176 # ClusterCoordinator. 177 running_threads = test_util.get_running_threads() 178 self.assertTrue( 179 test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads)) 180 self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads) 181 182 # Print object graph if ClusterCoordinator may leak. 183 if sys.getrefcount(self.cluster_coord) > 2: 184 try: 185 test_util.show_backref(self.cluster_coord) 186 except: # pylint: disable=bare-except 187 pass 188 189 # Wait for threads to close. 190 self.cluster_coord = None 191 self.strategy = None 192 gc.collect() 193 time.sleep(1) 194 195 # Verify thread names. 196 running_threads = test_util.get_running_threads() 197 self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads) 198 self.assertFalse( 199 test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads), 200 "Worker thread is not stopped properly.") 201 202 def _create_model_and_run_indefinitely(self): 203 model = Model(self.cluster_coord) 204 model.do_infinite_step.assign(True) 205 model.schedule_training_functions(10) 206 # Model does infinite training step, so at this moment, we expect to have 207 # `self.num_workers` infinite closures inflight, and `10-self.num_workers` 208 # closures in the queue. 209 while (self.cluster_coord._cluster.closure_queue._inflight_closure_count < 210 self.num_workers): 211 time.sleep(0.1) 212 return model 213 214 def testClusterCoordinatorDestroyed(self): 215 self._ensure_threads_closed() 216 217 def testWorkerPreemptionBetweenFunctions(self): 218 model = Model(self.cluster_coord) 219 model.schedule_training_functions(2) 220 model.join_training_functions() 221 self.assertEqual(model.iterations.numpy(), 2) 222 223 self._restart(downtime_secs=2, job="worker") 224 225 model.schedule_training_functions(2) 226 model.join_training_functions() 227 self.assertEqual(model.iterations.numpy(), 4) 228 229 def testWorkerPreemptionMidstFunction(self): 230 model = Model(self.cluster_coord) 231 model.do_infinite_step.assign(True) 232 233 model.schedule_training_functions(4) 234 # Model does infinite training step, so at this moment, we expect to have 235 # `self.num_workers` infinite closures inflight, and `4-self.num_workers` 236 # closures in the queue. 237 while (self.cluster_coord._cluster.closure_queue._inflight_closure_count < 238 self.num_workers): 239 time.sleep(0.1) 240 self.assertFalse(self.cluster_coord.done()) 241 self._restart(downtime_secs=2, job="worker") 242 model.join_training_functions() 243 self.assertGreaterEqual(model.iterations.numpy(), 4) 244 245 def testOneWorkerPreemptionWithCancellation(self): 246 247 @def_function.function 248 def normal_function(): 249 x = random_ops.random_uniform((2, 10)) 250 y = random_ops.random_uniform((10, 2)) 251 return math_ops.reduce_mean(math_ops.matmul(x, y)) 252 253 @def_function.function 254 def error_function(): 255 x = random_ops.random_uniform((2, 10)) 256 y = random_ops.random_uniform((10, 2)) 257 check_ops.assert_non_positive_v2( 258 math_ops.reduce_sum(math_ops.matmul(x, y))) 259 return x 260 261 @def_function.function 262 def long_function(): 263 x = random_ops.random_uniform((1000, 1000)) 264 for _ in math_ops.range(10000): 265 a = random_ops.random_uniform((1000, 1000)) 266 b = random_ops.random_uniform((1000, 1000)) 267 x += math_ops.matmul(a, b) 268 return x 269 270 for _ in range(3): 271 self.cluster_coord.schedule(normal_function) 272 long_function_result = self.cluster_coord.schedule(long_function) 273 self.cluster_coord.schedule(error_function) 274 275 time.sleep(1) # Let it run a couple steps. 276 self._restart(1, "worker") 277 278 with self.assertRaises(errors.InvalidArgumentError): 279 self.cluster_coord.join() 280 281 with self.assertRaises(errors.CancelledError): 282 long_function_result.fetch() 283 284 for _ in range(3): 285 self.cluster_coord.schedule(normal_function) 286 self.cluster_coord.join() 287 288 # The cluster is likely still being recovered since `join` returned early 289 # due to the error_function. 290 failure_handler = self.cluster_coord._cluster.failure_handler 291 failure_handler.stop() 292 failure_handler._preemption_handler_thread.join() 293 294 def testHandleDatasetCreationFailureWithDatasetFn(self): 295 model = Model(self.cluster_coord) 296 297 restart_thread = self._restart_in_thread(5, "worker") 298 299 model.schedule_training_functions(3) 300 model.rebuild_iterators() 301 model.schedule_training_functions(3) 302 model.rebuild_iterators() 303 model.schedule_training_functions(3) 304 305 model.join_training_functions() 306 307 self.thread_coord.join([restart_thread]) 308 self.assertGreaterEqual(model.iterations.numpy(), 3) 309 310 # TODO(yuefengz): consider using combinations when there is more code 311 # duplication. 312 def testHandleDatasetCreationFailureWithDataset(self): 313 model = Model(self.cluster_coord) 314 315 restart_thread = self._restart_in_thread(5, "worker") 316 317 model.schedule_training_functions(3) 318 model.rebuild_iterators(use_dataset_fn=False) 319 model.schedule_training_functions(3) 320 model.rebuild_iterators(use_dataset_fn=False) 321 model.schedule_training_functions(3) 322 323 model.join_training_functions() 324 325 self.thread_coord.join([restart_thread]) 326 self.assertGreaterEqual(model.iterations.numpy(), 3) 327 328 def testWorkerPreemptionErrorType(self): 329 330 @def_function.function 331 def worker_train_fn(): 332 x = random_ops.random_uniform((2, 10)) 333 y = random_ops.random_uniform((10, 2)) 334 return math_ops.reduce_mean(math_ops.matmul(x, y)) 335 336 def run_fn(): 337 with self.thread_coord.stop_on_exception(): 338 with ops.device("/job:worker/replica:0/task:0"): 339 for _ in range(3): 340 for _ in range(3): 341 worker_train_fn() 342 time.sleep(5) 343 344 run_thread = threading.Thread(target=run_fn) 345 run_thread.start() 346 time.sleep(1) # Let it run a couple steps. 347 self._restart(2, "worker") 348 349 try: 350 self.thread_coord.join([run_thread]) 351 except (errors.UnavailableError, errors.AbortedError) as e: 352 logging.info("Got exception %r, error message is %s", e, e) 353 354 self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except 355 self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) 356 357 self.assertTrue("failed to connect to all addresses" in str(e) or 358 "Unable to find a context_id" in str(e) or 359 "Socket closed" in str(e) or 360 "Connection reset by peer" in str(e) or 361 "Transport closed" in str(e)) 362 363 def testWorkerPreemptionErrorTypeWithPythonFunction(self): 364 365 def worker_train_fn(): 366 x = random_ops.random_uniform((2, 10)) 367 y = random_ops.random_uniform((10, 2)) 368 return math_ops.reduce_mean(math_ops.matmul(x, y)) 369 370 def run_fn(): 371 with self.thread_coord.stop_on_exception(): 372 with ops.device("/job:worker/replica:0/task:0"): 373 for _ in range(3): 374 for _ in range(3): 375 worker_train_fn() 376 time.sleep(5) 377 378 run_thread = threading.Thread(target=run_fn) 379 run_thread.start() 380 time.sleep(1) # Let it run a couple steps. 381 self._restart(2, "worker") 382 383 try: 384 self.thread_coord.join([run_thread]) 385 except (errors.UnavailableError, errors.AbortedError) as e: 386 logging.info("Got exception %r, error message is %s", e, e) 387 388 self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except 389 self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) 390 391 self.assertTrue("failed to connect to all addresses" in str(e) or 392 "Unable to find a context_id" in str(e) or 393 "Socket closed" in str(e) or 394 "Connection reset by peer" in str(e) or 395 "Transport closed" in str(e)) 396 397 def testPSPreemptionErrorType(self): 398 399 with ops.device("/job:ps/replica:0/task:0"): 400 v = variables.Variable( 401 initial_value=random_ops.random_uniform((2, 10)), 402 dtype=dtypes.float32) 403 404 @def_function.function 405 def worker_train_fn(): 406 y = random_ops.random_uniform((10, 2)) 407 return math_ops.reduce_mean(math_ops.matmul(v, y)) 408 409 def run_fn(): 410 with self.thread_coord.stop_on_exception(): 411 with ops.device("/job:worker/replica:0/task:0"): 412 for _ in range(3): 413 for _ in range(3): 414 worker_train_fn() 415 time.sleep(5) 416 417 run_thread = threading.Thread(target=run_fn) 418 run_thread.start() 419 time.sleep(1) # Let it run a couple steps. 420 421 # Use a short restart delay to cover the case that RPC channel is reused 422 self._restart(1, "ps") 423 424 try: 425 self.thread_coord.join([run_thread]) 426 except (errors.UnavailableError, errors.AbortedError) as e: 427 logging.info("Got exception %r, error message is %s", e, e) 428 self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except 429 430 if isinstance(e, errors.UnavailableError): 431 self.assertTrue("failed to connect to all addresses" in str(e) or 432 "Socket closed" in str(e) or 433 "Connection reset by peer" in str(e) or 434 "Transport closed" in str(e)) 435 436 if isinstance(e, errors.AbortedError): 437 self.assertTrue( 438 "RecvTensor expects a different device incarnation" in str(e) or 439 "Unable to find a context_id" in str(e)) 440 self._ensure_threads_closed() 441 442 def testTwoWorkersPreempted(self): 443 if self.num_workers < 2: 444 self.skipTest("Worker number is less than 2.") 445 model = self._create_model_and_run_indefinitely() 446 447 self.assertFalse(self.cluster_coord.done()) 448 self._cluster.kill_task("worker", 0) 449 self._cluster.kill_task("worker", 1) 450 time.sleep(2) 451 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 452 self.assertFalse(context.check_alive("/job:worker/replica:0/task:1")) 453 self._cluster.start_task("worker", 0) 454 self._cluster.start_task("worker", 1) 455 time.sleep(2) 456 self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) 457 self.assertTrue(context.check_alive("/job:worker/replica:0/task:1")) 458 459 model.join_training_functions() 460 self.assertGreaterEqual(model.iterations.numpy(), 10) 461 462 def testWorkerContinuousFailure(self): 463 model = self._create_model_and_run_indefinitely() 464 465 self.assertFalse(self.cluster_coord.done()) 466 self._cluster.kill_task("worker", 0) 467 time.sleep(2) 468 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 469 self._cluster.start_task("worker", 0) 470 time.sleep(2) 471 self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) 472 self._cluster.kill_task("worker", 0) 473 time.sleep(2) 474 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 475 self._cluster.start_task("worker", 0) 476 time.sleep(2) 477 self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) 478 479 model.join_training_functions() 480 self.assertGreaterEqual(model.iterations.numpy(), 10) 481 482 def testPSFailureWhileRecoveryFromWokerFailure(self): 483 model = self._create_model_and_run_indefinitely() 484 485 time.sleep(1) 486 self.assertFalse(self.cluster_coord.done()) 487 488 def kill(task): 489 self._cluster.kill_task(task, 0) 490 self.sleep(1) 491 self._cluster.start_task(task, 0) 492 493 kill_thread_1 = threading.Thread(target=kill, args=("worker",)) 494 kill_thread_2 = threading.Thread(target=kill, args=("ps",)) 495 kill_thread_1.start() 496 kill_thread_2.start() 497 kill_thread_1.join() 498 kill_thread_2.join() 499 500 with self.assertRaises( 501 (errors.UnavailableError, errors.InvalidArgumentError)): 502 model.join_training_functions() 503 504 def testNumpyFetchedAfterWorkerFailure(self): 505 506 with self.strategy.scope(): 507 v = variables.Variable(initial_value=0, dtype=dtypes.int32) 508 509 @def_function.function 510 def worker_fn(): 511 return v + 1, v - 1 512 513 remote_value = self.cluster_coord.schedule(worker_fn) 514 # Attempt to fetch before killing worker task should succeed. 515 self.assertEqual((1, -1), remote_value.fetch()) 516 self._cluster.kill_task("worker", 0) 517 # So should attempt to fetch after killing worker task. 518 self.assertEqual((1, -1), remote_value.fetch()) 519 520 def testTensorGotAfterWorkerFailure(self): 521 522 with self.strategy.scope(): 523 v = variables.Variable(initial_value=0, dtype=dtypes.int32) 524 525 @def_function.function 526 def worker_fn(): 527 return v + 1, v - 1 528 529 remote_value = self.cluster_coord.schedule(worker_fn) 530 531 # Attempt to fetch before killing worker task should succeed. 532 fetched = remote_value.get()[0] 533 self.assertIsInstance(fetched, ops.Tensor) 534 self.assertEqual(fetched.device, "/job:chief/replica:0/task:0/device:CPU:0") 535 self.assertEqual((1, -1), remote_value.get()) 536 remote_value.get()[0].numpy() 537 538 # As well as the remote tensors that point to worker0 or worker1. 539 values = remote_value._values[0] 540 self.assertIsInstance(values, ops.Tensor) 541 self.assertRegex(values.device, 542 "/job:worker/replica:0/task:[0-1]/device:CPU:0") 543 self.assertEqual((1, -1), remote_value._values) 544 remote_value._values[0].numpy() 545 546 # Terminate the workers and wait a little so that they are indeed killed. 547 for i in range(self.num_workers): 548 self._cluster.kill_task("worker", i) 549 time.sleep(5) 550 551 # Attempt to fetch after killing worker tasks should succeed as well. 552 remote_value.get()[0].numpy() 553 self.assertEqual((1, -1), remote_value.get()) 554 555 # Attempting to copy the tensor from worker now should fail. 556 with self.assertRaises(errors.UnavailableError) as cm: 557 remote_value._values[0].numpy() 558 self.assertIn("failed to connect to all addresses", cm.exception.message) 559 self.assertIn("/job:worker/replica:0/task:", cm.exception.message) 560 561 def testFetchFromPSAfterWorkerFailure(self): 562 # Test for flaky failures when reading from a parameter server while a 563 # worker is recovering. 564 # Place some variables on PSes using distribute_datasets_from_function, 565 # kill a worker, and continuously poll one of those variables. 566 567 model = Model(self.cluster_coord) 568 569 # kill the worker after a delay to make sure variable reading runs while 570 # worker is up, while it's down, and while it restarts 571 def kill_after_delay(): 572 time.sleep(3) 573 logging.info("Killing worker 0") 574 self._cluster.kill_task("worker", 0) 575 time.sleep(1) 576 logging.info("Restarting worker 0") 577 self._cluster.start_task("worker", 0) 578 579 kill_thread = threading.Thread(target=kill_after_delay) 580 kill_thread.start() 581 582 model.do_infinite_step.assign(True) 583 model.schedule_training_functions(1) 584 585 num_reads = 0 586 num_reads_after_restart = 0 587 read_interval_secs = 0.1 588 worker_has_stopped = False 589 # limit runtime of the test: stop after doing a few reads after worker 590 # is back up, or after a fixed maximum number of reads 591 while num_reads_after_restart <= 5 and num_reads < 200: 592 worker_up = context.check_alive("/job:worker/replica:0/task:0") 593 if not worker_up: 594 worker_has_stopped = True 595 if worker_up and worker_has_stopped: 596 num_reads_after_restart += 1 597 598 model.join_training_functions() 599 start = time.time() 600 while time.time() < start + read_interval_secs: 601 model.iterations.read_value() 602 603 num_reads += 1 604 # run another epoch 605 model.do_infinite_step.assign(True) 606 model.schedule_training_functions(1) 607 608 def testClusterStateNotDisrupted(self): 609 # This test has side effects and can disrupt other tests, even if the 610 # resource created by it will not be used in following tests. 611 # TODO(b/155209534): enable this test. 612 # self.testPSPreemptionErrorType() 613 614 self.thread_coord = thread_coordinator.Coordinator( 615 clean_stop_exception_types=[]) 616 self.testWorkerPreemptionMidstFunction() 617 618 self.thread_coord = thread_coordinator.Coordinator( 619 clean_stop_exception_types=[]) 620 self.testWorkerPreemptionErrorType() 621 622 # In previous tests, workers may fail after training is done. But the 623 # following tests start with creating resources where failure is not 624 # handled. 625 # TODO(b/153888707): enable the following two tests. 626 # self.testTwoWorkersPreempted() 627 # self.testWorkerContinuousFailure() 628 629 def testJoinRaisesUnavailableErrorAtPsFailure(self): 630 self._create_model_and_run_indefinitely() 631 self._cluster.kill_task("ps", 0) 632 while self.cluster_coord._cluster.closure_queue._error is None: 633 time.sleep(1) 634 with self.assertRaises((errors.UnavailableError, errors.NotFoundError, 635 errors.FailedPreconditionError)): 636 self.cluster_coord.join() 637 638 def testScheduleRaisesUnavailableErrorAtPsFailure(self): 639 self._create_model_and_run_indefinitely() 640 self._cluster.kill_task("ps", 0) 641 while self.cluster_coord._cluster.closure_queue._error is None: 642 time.sleep(1) 643 with self.assertRaises((errors.UnavailableError, errors.NotFoundError, 644 errors.FailedPreconditionError)): 645 self.cluster_coord.schedule(def_function.function(lambda: None)) 646 647 def testWorkerExecutionAfterPsFailureRaisesExpectedError(self): 648 model = self._create_model_and_run_indefinitely() 649 for i in range(self.num_ps): 650 self._cluster.kill_task("ps", i) 651 while self.cluster_coord._cluster.closure_queue._error is None: 652 time.sleep(1) 653 654 @def_function.function 655 def trivial_function(): 656 return model.iterations + 1 657 658 for i in range(self.num_workers): 659 try: 660 with ops.device("/job:worker/replica:0/task:{}".format(i)): 661 trivial_function() 662 except Exception as e: # pylint: disable=broad-except 663 if cluster_coordinator._is_ps_failure(e): 664 if i < self.num_workers - 1: 665 continue 666 return 667 raise AssertionError("Executing a function after PS fails, should " 668 "result in a PS failure.") 669 670 def testAsyncWaitIsNoOp(self): 671 if self.num_workers < 2: 672 self.skipTest("Worker number is less than 2.") 673 model = self._create_model_and_run_indefinitely() 674 675 self.assertFalse(self.cluster_coord.done()) 676 self._cluster.kill_task("worker", 0) 677 time.sleep(2) 678 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 679 # Should pass without exception even with failed remote workers 680 context.async_wait() 681 682 model.join_training_functions() 683 self.assertGreaterEqual(model.iterations.numpy(), 10) 684 685 self._cluster.start_task("worker", 0) 686 687 688class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase): 689 """Multi worker fault tolerance tests. 690 691 This covers the ordinary cases where multiple workers and PS are used. 692 """ 693 694 def setUp(self): 695 super(MultiWorkerFaultToleranceTest, self).setUp(2, 2) 696 697 698class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase): 699 """Single worker fault tolerance tests. 700 701 This covers the cases that ensure training can continue in a single-worker 702 cluster, even if the only worker can become unavailable at some point and 703 recovered (if there are multiple workers, it is possible that the training 704 succeeds with the workers that did not fail). Realistically single worker 705 is very rarely used, but the tests are important to ensure the correct 706 behaviors. 707 """ 708 709 def setUp(self): 710 super(SingleWorkerFaultToleranceTest, self).setUp(1, 1) 711 712 713if __name__ == "__main__": 714 v2_compat.enable_v2_behavior() 715 multi_process_runner.test_main() 716