xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/coordinator.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Coordinator to help multiple threads stop when requested."""
16import contextlib
17import sys
18import threading
19import time
20
21from tensorflow.python.framework import errors
22from tensorflow.python.platform import tf_logging as logging
23from tensorflow.python.util import compat
24from tensorflow.python.util.tf_export import tf_export
25
26
@tf_export("train.Coordinator")
class Coordinator:
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context handler `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context handler the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    # Stored as a tuple so it can be handed directly to isinstance().
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    # For an exc_info tuple the exception instance is the second element.
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: If `join()` has already been called and `ex` is not one of
        the clean-stop exception types, the reported exception is raised
        immediately from this call instead of being recorded.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          _, ex_instance, _ = ex
          raise ex_instance
        elif ex is not None:
          # NOTE(touts): This is bogus if request_stop() is not called
          # from the exception handler that raised ex.
          _, ex_instance, _ = sys.exc_info()
          if ex_instance is None:
            # No exception is currently being handled, so sys.exc_info() has
            # nothing to offer.  Raise the reported exception itself rather
            # than executing `raise None`, which would surface as an unrelated
            # TypeError.
            raise ex
          raise ex_instance
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # The tuple is wrapped in a 1-tuple so that it is formatted as a
              # single value.  Passing it directly to `%` would spread its
              # elements over the format arguments and fail with a TypeError
              # that the handler below does not catch.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              # NOTE(touts): As above, this is bogus if request_stop() is not
              # called from the exception handler that raised ex.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.  The
    joined flag and any recorded exception are reset as well.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context handler simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    Yields:
      nothing.
    """
    try:
      yield
    # The bare except is deliberate: everything raised by the body is handed
    # to the coordinator, matching the equivalent code shown in the docstring.
    except:  # pylint: disable=bare-except
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.

    Returns:
      True if the Coordinator is told stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of threads
    is the union of the threads passed in the `threads` argument and the list
    of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
      Exception: The exception recorded by `request_stop()`, if any.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      if threads is None:
        threads = self._registered_threads
      else:
        threads = self._registered_threads.union(set(threads))
      # Copy the set into a list to avoid race conditions where a new thread
      # is added while we are waiting.
      threads = list(threads)

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      # A recorded exception takes precedence over the straggler report.
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True once `join()` has run; reset to False by `clear_stop()`."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
406
407
408# Threads for the standard services.
@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
  """Daemon thread that keeps invoking a piece of code until told to stop.

  A `LooperThread` is meant to be driven by a `Coordinator`.  The code it runs
  is either the `target` callable (with `args` / `kwargs`) supplied at
  construction time, or an overridden `run_loop()` method.

  The coordinator's stop flag is consulted before every iteration; once a stop
  has been requested the loop ends.

  An exception escaping the looped code is reported to the coordinator through
  `Coordinator.stop_on_exception()` and ends the thread.

  You typically pass looper threads to the supervisor `Join()` method.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Build a LooperThread and register it with its coordinator.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number of seconds between successive calls to
        `run_loop()`, or None to call it back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If `coord` is not a `Coordinator`, or if `args` / `kwargs`
        are given without a `target`.
    """
    if not isinstance(coord, Coordinator):
      raise ValueError("'coord' argument must be a Coordinator: %s" % coord)
    super().__init__()
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if not self._target:
      # Call arguments are meaningless when there is nothing to call.
      if args or kwargs:
        raise ValueError("'args' and 'kwargs' argument require that you also "
                         "pass 'target'")
    else:
      self._args = args or ()
      self._kwargs = kwargs or {}
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Create and start a LooperThread that calls `target` repeatedly.

    With `timer_interval_secs` set to None, `target(*args, **kwargs)` is
    called back to back.  Otherwise it is called every `timer_interval_secs`
    seconds.  The thread ends when a stop of the coordinator is requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    thread = LooperThread(
        coord, timer_interval_secs, target=target, args=args, kwargs=kwargs)
    thread.start()
    return thread

  def run(self):
    """Thread body: `start_loop()`, repeated `run_loop()`, then `stop_loop()`.

    Everything runs inside the coordinator's `stop_on_exception()` context.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is not None:
        # Deadline of the upcoming run_loop() call; the first one is 'now'.
        deadline = time.time()
        while True:
          # Sleep until the deadline, waking early if a stop is requested.
          if self._coord.wait_for_stop(deadline - time.time()):
            break
          deadline += self._timer_interval_secs
          self.run_loop()
      else:
        # No timer: iterate as fast as possible until asked to stop.
        while True:
          if self._coord.should_stop():
            break
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Hook invoked once from `run()` before the first iteration."""

  def stop_loop(self):
    """Hook invoked once from `run()` after the loop has ended."""

  def run_loop(self):
    """One loop iteration: calls `target` if one was supplied."""
    if not self._target:
      return
    self._target(*self._args, **self._kwargs)
507      self._target(*self._args, **self._kwargs)
508