# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Fault tolerance test for parameter server training in TF2."""
16
17import gc
18import sys
19import threading
20import time
21
22from tensorflow.python.compat import v2_compat
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.distribute import multi_process_runner
25from tensorflow.python.distribute import multi_worker_test_base
26from tensorflow.python.distribute import parameter_server_strategy_v2
27from tensorflow.python.distribute import test_util
28from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
29from tensorflow.python.distribute.coordinator import cluster_coordinator
30from tensorflow.python.eager import context
31from tensorflow.python.eager import def_function
32from tensorflow.python.eager import test
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.framework import ops
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import check_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import random_ops
40from tensorflow.python.ops import variables
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.training import coordinator as thread_coordinator
43from tensorflow.python.training import server_lib
44
45_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
46_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
47_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
48_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop"
49
50
class Model(object):
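  """A simple model whose train steps are scheduled by a ClusterCoordinator."""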

  def __init__(self, coordinator):
    self.cluster_coord = coordinator
    self.strategy = self.cluster_coord.strategy
    with self.cluster_coord.strategy.scope():
      self.build()

  def build(self):
    self.w = variables.Variable(
        initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32)
    self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
    # Allow external control to make the model run its train_fn in an infinite
    # loop. This allows us to reliably test worker preemption in the middle of
    # function execution.
    self.do_infinite_step = variables.Variable(False)

    self.rebuild_iterators()

  def rebuild_iterators(self, use_dataset_fn=True):
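    """Recreates per-worker iterators from a dataset_fn or a dataset."""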

    if use_dataset_fn:

      def dataset_fn():
        data = random_ops.random_uniform((10, 10))
        dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
        return dataset

      def distribute_dataset_fn():
        return self.cluster_coord.strategy.distribute_datasets_from_function(
            lambda _: dataset_fn())

      self.iterator = iter(
          self.cluster_coord.create_per_worker_dataset(distribute_dataset_fn))
      self.iterator2 = iter(
          self.cluster_coord.create_per_worker_dataset(distribute_dataset_fn))
    else:
      data = random_ops.random_uniform((10, 10))
      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()

      self.iterator = iter(
          self.cluster_coord.create_per_worker_dataset(dataset))
      self.iterator2 = iter(
          self.cluster_coord.create_per_worker_dataset(dataset))

  def _train_fn_internal(self, iterator, iterator2):
    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
    x = math_ops.matmul(array_ops.squeeze(next(iterator2)), x)
    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
    self.w.assign_add(x)

  @def_function.function
  def train_fn(self, iterator, iterator2):
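    """Runs a train step, looping while `do_infinite_step` is set."""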
    self._train_fn_internal(iterator, iterator2)
    while self.do_infinite_step:
      self._train_fn_internal(iterator, iterator2)
    self.iterations.assign_add(1)

  def schedule_training_functions(self, num_steps):
    with self.strategy.scope():
      for _ in range(num_steps):
        self.cluster_coord.schedule(
            self.train_fn, args=(self.iterator, self.iterator2))

  def join_training_functions(self):
    self.do_infinite_step.assign(False)
    self.cluster_coord.join()


class BaseFaultToleranceTest(object):  # pylint: disable=missing-docstring

  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
    self.num_ps = num_ps

  def tearDown(self):
    super(BaseFaultToleranceTest, self).tearDown()
    self._cluster.stop()
    self._cluster = None

  def _restart(self, downtime_secs, job):
    """Kills `job` (index: 0) and restarts it after `downtime_secs`.

    Args:
      downtime_secs: seconds to wait before restarting the job.
      job: a string specifying the job to restart.
    """
    self._cluster.kill_task(job, 0)
    time.sleep(downtime_secs)
    self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job))
    self._cluster.start_task(job, 0)
    while not context.check_alive("/job:%s/replica:0/task:0" % job):
      time.sleep(1)

  def _restart_in_thread(self, downtime_secs, restart_job):
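    """Kills and restarts task 0 of `restart_job` in a background thread."""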

    def _restart_fn():
      with self.thread_coord.stop_on_exception():
        self._restart(downtime_secs, restart_job)

    restart_thread = threading.Thread(target=_restart_fn)
    restart_thread.start()
    return restart_thread

  def _ensure_threads_closed(self):
    """Ensures worker and preemption threads are closed."""
    # Worker and preemption threads should exist before releasing
    # ClusterCoordinator.
    running_threads = test_util.get_running_threads()
    self.assertTrue(
        test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads))
    self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)

    # Print the object graph if the ClusterCoordinator may have leaked.
    if sys.getrefcount(self.cluster_coord) > 2:
      try:
        test_util.show_backref(self.cluster_coord)
      except:  # pylint: disable=bare-except
        pass

    # Wait for threads to close.
    self.cluster_coord = None
    self.strategy = None
    gc.collect()
    time.sleep(1)

    # Verify thread names.
    running_threads = test_util.get_running_threads()
    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
    self.assertFalse(
        test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads),
        "Worker thread is not stopped properly.")

  def _create_model_and_run_indefinitely(self):
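    """Creates a Model and waits until all workers have inflight closures."""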
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(10)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures to be inflight and
    # `10-self.num_workers` closures in the queue.
    while (self.cluster_coord._cluster.closure_queue._inflight_closure_count <
           self.num_workers):
      time.sleep(0.1)
    return model

  def testClusterCoordinatorDestroyed(self):
    self._ensure_threads_closed()

  def testWorkerPreemptionBetweenFunctions(self):
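    """Training resumes after a worker restart between scheduled functions."""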
    model = Model(self.cluster_coord)
    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 2)

    self._restart(downtime_secs=2, job="worker")

    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 4)

  def testWorkerPreemptionMidstFunction(self):
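    """Training resumes after a worker is preempted mid-function."""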
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)

    model.schedule_training_functions(4)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures to be inflight and
    # `4-self.num_workers` closures in the queue.
    while (self.cluster_coord._cluster.closure_queue._inflight_closure_count <
           self.num_workers):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())
    self._restart(downtime_secs=2, job="worker")
    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 4)

  def testOneWorkerPreemptionWithCancellation(self):
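    """A function error during worker preemption cancels pending closures."""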

    @def_function.function
    def normal_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    @def_function.function
    def error_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      check_ops.assert_non_positive_v2(
          math_ops.reduce_sum(math_ops.matmul(x, y)))
      return x

    @def_function.function
    def long_function():
      x = random_ops.random_uniform((1000, 1000))
      for _ in math_ops.range(10000):
        a = random_ops.random_uniform((1000, 1000))
        b = random_ops.random_uniform((1000, 1000))
        x += math_ops.matmul(a, b)
      return x

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    long_function_result = self.cluster_coord.schedule(long_function)
    self.cluster_coord.schedule(error_function)

    time.sleep(1)  # Let it run a couple of steps.
    self._restart(1, "worker")

    with self.assertRaises(errors.InvalidArgumentError):
      self.cluster_coord.join()

    with self.assertRaises(errors.CancelledError):
      long_function_result.fetch()

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    self.cluster_coord.join()

    # The cluster is likely still being recovered since `join` returned early
    # due to the error_function.
    failure_handler = self.cluster_coord._cluster.failure_handler
    failure_handler.stop()
    failure_handler._preemption_handler_thread.join()

  def testHandleDatasetCreationFailureWithDatasetFn(self):
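    """Per-worker dataset recreation from a dataset_fn survives a restart."""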
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")

    model.schedule_training_functions(3)
    model.rebuild_iterators()
    model.schedule_training_functions(3)
    model.rebuild_iterators()
    model.schedule_training_functions(3)

    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)

  # TODO(yuefengz): consider using combinations when there is more code
  # duplication.
  def testHandleDatasetCreationFailureWithDataset(self):
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")

    model.schedule_training_functions(3)
    model.rebuild_iterators(use_dataset_fn=False)
    model.schedule_training_functions(3)
    model.rebuild_iterators(use_dataset_fn=False)
    model.schedule_training_functions(3)

    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)

  def testWorkerPreemptionErrorType(self):
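    """Checks the error types raised when a worker is preempted."""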

    @def_function.function
    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple of steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testWorkerPreemptionErrorTypeWithPythonFunction(self):
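    """Same as the test above, but with a plain (non-tf.function) step."""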

    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple of steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testPSPreemptionErrorType(self):
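    """Checks the error types raised when a PS is preempted."""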

    with ops.device("/job:ps/replica:0/task:0"):
      v = variables.Variable(
          initial_value=random_ops.random_uniform((2, 10)),
          dtype=dtypes.float32)

    @def_function.function
    def worker_train_fn():
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(v, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple of steps.

    # Use a short restart delay to cover the case where the RPC channel is
    # reused.
    self._restart(1, "ps")

    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_PS, str(e))  # pylint: disable=g-assert-in-except

      if isinstance(e, errors.UnavailableError):
        self.assertTrue("failed to connect to all addresses" in str(e) or
                        "Socket closed" in str(e) or
                        "Connection reset by peer" in str(e) or
                        "Transport closed" in str(e))

      if isinstance(e, errors.AbortedError):
        self.assertTrue(
            "RecvTensor expects a different device incarnation" in str(e) or
            "Unable to find a context_id" in str(e))
      self._ensure_threads_closed()

  def testTwoWorkersPreempted(self):
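    """Training completes after two workers are preempted and restarted."""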
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    self._cluster.kill_task("worker", 1)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
    self._cluster.start_task("worker", 0)
    self._cluster.start_task("worker", 1)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testWorkerContinuousFailure(self):
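    """Training completes after the same worker fails and restarts twice."""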
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testPSFailureWhileRecoveryFromWorkerFailure(self):
    model = self._create_model_and_run_indefinitely()

    time.sleep(1)
    self.assertFalse(self.cluster_coord.done())

    def kill(task):
      self._cluster.kill_task(task, 0)
      time.sleep(1)
      self._cluster.start_task(task, 0)

    kill_thread_1 = threading.Thread(target=kill, args=("worker",))
    kill_thread_2 = threading.Thread(target=kill, args=("ps",))
    kill_thread_1.start()
    kill_thread_2.start()
    kill_thread_1.join()
    kill_thread_2.join()

    with self.assertRaises(
        (errors.UnavailableError, errors.InvalidArgumentError)):
      model.join_training_functions()

  def testNumpyFetchedAfterWorkerFailure(self):
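    """RemoteValue.fetch still works after the worker is killed."""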

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)
    # Fetching before the worker task is killed should succeed.
    self.assertEqual((1, -1), remote_value.fetch())
    self._cluster.kill_task("worker", 0)
    # Fetching after the worker task is killed should also succeed.
    self.assertEqual((1, -1), remote_value.fetch())

  def testTensorGotAfterWorkerFailure(self):
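    """Coordinator-side tensors remain usable after workers are killed."""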

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)

    # Getting the value before the worker task is killed should succeed.
    fetched = remote_value.get()[0]
    self.assertIsInstance(fetched, ops.Tensor)
    self.assertEqual(fetched.device, "/job:chief/replica:0/task:0/device:CPU:0")
    self.assertEqual((1, -1), remote_value.get())
    remote_value.get()[0].numpy()

    # As should reading the remote tensors that live on worker0 or worker1.
    values = remote_value._values[0]
    self.assertIsInstance(values, ops.Tensor)
    self.assertRegex(values.device,
                     "/job:worker/replica:0/task:[0-1]/device:CPU:0")
    self.assertEqual((1, -1), remote_value._values)
    remote_value._values[0].numpy()

    # Terminate the workers and wait a little so that they are indeed killed.
    for i in range(self.num_workers):
      self._cluster.kill_task("worker", i)
    time.sleep(5)

    # Getting the value after the worker tasks are killed should succeed too.
    remote_value.get()[0].numpy()
    self.assertEqual((1, -1), remote_value.get())

    # Copying the tensor from the worker should now fail.
    with self.assertRaises(errors.UnavailableError) as cm:
      remote_value._values[0].numpy()
    self.assertIn("failed to connect to all addresses", cm.exception.message)
    self.assertIn("/job:worker/replica:0/task:", cm.exception.message)

  def testFetchFromPSAfterWorkerFailure(self):
    # Test for flaky failures when reading from a parameter server while a
    # worker is recovering.
    # Place some variables on PSes using distribute_datasets_from_function,
    # kill a worker, and continuously poll one of those variables.

    model = Model(self.cluster_coord)

    # Kill the worker after a delay to make sure variable reads happen while
    # the worker is up, while it is down, and while it restarts.
    def kill_after_delay():
      time.sleep(3)
      logging.info("Killing worker 0")
      self._cluster.kill_task("worker", 0)
      time.sleep(1)
      logging.info("Restarting worker 0")
      self._cluster.start_task("worker", 0)

    kill_thread = threading.Thread(target=kill_after_delay)
    kill_thread.start()

    model.do_infinite_step.assign(True)
    model.schedule_training_functions(1)

    num_reads = 0
    num_reads_after_restart = 0
    read_interval_secs = 0.1
    worker_has_stopped = False
    # Limit the runtime of the test: stop after a few reads once the worker is
    # back up, or after a fixed maximum number of reads.
    while num_reads_after_restart <= 5 and num_reads < 200:
      worker_up = context.check_alive("/job:worker/replica:0/task:0")
      if not worker_up:
        worker_has_stopped = True
      if worker_up and worker_has_stopped:
        num_reads_after_restart += 1

      model.join_training_functions()
      start = time.time()
      while time.time() < start + read_interval_secs:
        model.iterations.read_value()

      num_reads += 1
      # Run another epoch.
      model.do_infinite_step.assign(True)
      model.schedule_training_functions(1)

  def testClusterStateNotDisrupted(self):
    # This test has side effects and can disrupt other tests, even if the
    # resource created by it will not be used in following tests.
    # TODO(b/155209534): enable this test.
    # self.testPSPreemptionErrorType()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionMidstFunction()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionErrorType()

    # In previous tests, workers may fail after training is done. But the
    # following tests start with creating resources where failure is not
    # handled.
    # TODO(b/153888707): enable the following two tests.
    # self.testTwoWorkersPreempted()
    # self.testWorkerContinuousFailure()

  def testJoinRaisesUnavailableErrorAtPsFailure(self):
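    """`join` surfaces an error once a PS task has been killed."""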
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster.closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.join()

  def testScheduleRaisesUnavailableErrorAtPsFailure(self):
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster.closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.schedule(def_function.function(lambda: None))

  def testWorkerExecutionAfterPsFailureRaisesExpectedError(self):
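    """Running a function on a worker after PS failure raises a PS error."""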
    model = self._create_model_and_run_indefinitely()
    for i in range(self.num_ps):
      self._cluster.kill_task("ps", i)
    while self.cluster_coord._cluster.closure_queue._error is None:
      time.sleep(1)

    @def_function.function
    def trivial_function():
      return model.iterations + 1

    for i in range(self.num_workers):
      try:
        with ops.device("/job:worker/replica:0/task:{}".format(i)):
          trivial_function()
      except Exception as e:  # pylint: disable=broad-except
        if cluster_coordinator._is_ps_failure(e):
          if i < self.num_workers - 1:
            continue
          return
      raise AssertionError("Executing a function after the PS fails should "
                           "result in a PS failure error.")

  def testAsyncWaitIsNoOp(self):
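    """`async_wait` does not raise when a remote worker has failed."""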
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    # Should pass without exception even with failed remote workers
    context.async_wait()

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

    self._cluster.start_task("worker", 0)


class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Multi-worker fault tolerance tests.

  This covers the ordinary cases where multiple workers and parameter servers
  are used.
  """

  def setUp(self):
    super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)


class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Single-worker fault tolerance tests.

  This covers cases that ensure training can continue in a single-worker
  cluster even if the only worker becomes unavailable at some point and later
  recovers (with multiple workers, training may still succeed on the workers
  that did not fail). Realistically, a single-worker setup is rarely used, but
  these tests are important for ensuring correct behavior.
  """

  def setUp(self):
    super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()