# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Coordinator to help multiple threads stop when requested."""
import contextlib
import sys
import threading
import time

from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.Coordinator")
class Coordinator:
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context manager `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context manager, the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop; this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.  For a bare exception (as opposed to an `exc_info`
    tuple) the coordinator records `sys.exc_info()`; if that does not hold an
    exception, a `ValueError` describing the misuse is recorded instead.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: If the coordinator has already been joined and `ex` is not
        None and not one of the clean-stop exception types, that exception is
        raised immediately, since `join()` can no longer report it.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          _, ex_instance, _ = ex
          raise ex_instance
        elif ex is not None:
          # Raise `ex` itself: sys.exc_info()[1] is None when this method is
          # not called from the handler that caught `ex`, and `raise None`
          # would fail with an unrelated TypeError.
          raise ex
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # Wrap the recorded tuple in a 1-tuple: formatting "%s" with a
              # bare 3-tuple treats it as three arguments and raises a
              # TypeError instead of this ValueError.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.  The
    recorded exception and the joined flag are reset as well.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context manager simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    Yields:
      nothing.
    """
    try:
      yield
    except:  # pylint: disable=bare-except
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.

    Returns:
      True if the Coordinator is told to stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Threads registered after a `join()` call has started are not joined by
    that call.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of threads
    is the union of the threads passed in the `threads` argument and the list
    of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.  If `True`, the live
        threads are only logged.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
      Exception: The exception recorded by `request_stop()`, if any.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      if threads is None:
        threads = self._registered_threads
      else:
        threads = self._registered_threads.union(set(threads))
      # Copy the set into a list to avoid race conditions where a new thread
      # is added while we are waiting.
      threads = list(threads)

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True once `join()` has run, until `clear_stop()` resets it."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance


# Threads for the standard services.
@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
  """A thread that runs code repeatedly, optionally on a timer.

  This thread class is intended to be used with a `Coordinator`.  It repeatedly
  runs code specified either as `target` and `args` or by the `run_loop()`
  method.

  Before each run the thread checks if the coordinator has requested stop.  In
  that case the looper thread terminates immediately.

  If the code being run raises an exception, that exception is reported to the
  coordinator and the thread terminates.  The coordinator will then request all
  the other threads it coordinates to stop.

  Each looper thread registers itself with its coordinator on construction, so
  `coord.join()` waits for it.  You typically pass looper threads to the
  supervisor `Join()` method.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Create a LooperThread.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Time boundaries at which to call `run_loop()`, or
        None if it should be called back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if not isinstance(coord, Coordinator):
      raise ValueError("'coord' argument must be a Coordinator: %s" % coord)
    super(LooperThread, self).__init__()
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if self._target:
      self._args = args or ()
      self._kwargs = kwargs or {}
    elif args or kwargs:
      raise ValueError("'args' and 'kwargs' argument require that you also "
                       "pass 'target'")
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls
    `target(*args, **kwargs)` repeatedly.  Otherwise it is called every
    `timer_interval_secs` seconds.  The thread terminates when a stop of the
    coordinator is requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
                          kwargs=kwargs)
    looper.start()
    return looper

  def run(self):
    """Thread body: start_loop(), repeated run_loop(), then stop_loop().

    Everything runs under `coord.stop_on_exception()`, so an exception is
    reported to the coordinator and skips the remaining calls, including
    `stop_loop()`.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is None:
        # Call back-to-back.
        while not self._coord.should_stop():
          self.run_loop()
      else:
        # Next time at which to call run_loop(), starts as 'now'.
        next_timer_time = time.time()
        while not self._coord.wait_for_stop(next_timer_time - time.time()):
          next_timer_time += self._timer_interval_secs
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Called when the thread starts."""
    pass

  def stop_loop(self):
    """Called when the thread stops."""
    pass

  def run_loop(self):
    """Called at 'timer_interval_secs' boundaries."""
    if self._target:
      self._target(*self._args, **self._kwargs)