# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training helper that checkpoints models and creates session."""
import time

import numpy as np
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.client import session
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


def _maybe_name(obj):
  """Returns object name if it has one, or a message otherwise.

  This is useful for names that appear in error messages.

  Args:
    obj: Object to get the name of.

  Returns:
    `obj.name`, the string "None" when `obj` is None, or a "no name" message
    naming the object's type.
  """
  if obj is None:
    return "None"
  elif hasattr(obj, "name"):
    return obj.name
  else:
    return "<no name for %s>" % type(obj)


def _restore_checkpoint_and_maybe_run_saved_model_initializers(
    sess, saver, path):
  """Restores checkpoint values and SavedModel initializers if found.

  Args:
    sess: The `Session` to restore into.
    saver: A `Saver` whose `restore(sess, path)` method is called.
    path: Path of the checkpoint handed to `saver.restore`.
  """
  # NOTE: All references to SavedModel refer to SavedModels loaded from the
  # load_v2 API (which does not require the `sess` argument).

  # If the graph contains resources loaded from a SavedModel, they are not
  # restored when calling `saver.restore`. Thus, the SavedModel initializer must
  # be called with `saver.restore` to properly initialize the model.

  # The SavedModel init is stored in the "saved_model_initializers" collection.
  # This collection is part of the MetaGraph's default_init_op, so it is already
  # called by MonitoredSession as long as the saver doesn't restore any
  # checkpoints from the working dir.
  saved_model_init_ops = ops.get_collection("saved_model_initializers")
  if saved_model_init_ops:
    sess.run(saved_model_init_ops)

  # The saver must be called *after* the SavedModel init, because the SavedModel
  # init will restore the variables from the SavedModel variables directory.
  # Initializing/restoring twice is not ideal but there's no other way to do it.
  saver.restore(sess, path)


@tf_export(v1=["train.SessionManager"])
class SessionManager:
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions to facilitate
  coordination among multiple training threads or processes.

  * Checkpointing trained variables as the training progresses.
  * Initializing variables on startup, restoring them from the most recent
    checkpoint after a crash, or wait for checkpoints to become available.

  ### Usage:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. It is typically given
  an `init_op` (used when no checkpoint can be restored) and a `saver` (used
  to restore from `checkpoint_dir`).

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.

  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None,
               local_init_feed_dict=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is run always after a new session
    was created. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready. The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op. The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a non
    empty 1D string tensor, the elements are concatenated and used to indicate
    to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready. It is used by processes to wait for a model to
    be initialized or restored. Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use. Defaults to the current
         default graph.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    self._local_init_feed_dict = local_init_feed_dict
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Validate the arguments before any side effect: previously the Session was
    # created first and then leaked (never closed) when this check failed.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      _restore_checkpoint_and_maybe_run_saved_model_initializers(
          sess, saver, checkpoint_filename_with_path)
      return sess, True

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    _restore_checkpoint_and_maybe_run_saved_model_initializers(
        sess, saver, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    If initialization fails, the session created by this call is closed
    before the error propagates, since it is never handed to the caller.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    try:
      if not is_loaded_from_checkpoint:
        if init_op is None and not init_fn and self._local_init_op is None:
          raise RuntimeError("Model is not initialized and no init_op or "
                             "init_fn or local_init_op was given")
        if init_op is not None:
          sess.run(init_op, feed_dict=init_feed_dict)
        if init_fn:
          init_fn(sess)

      local_init_success, msg = self._try_run_local_init_op(sess)
      if not local_init_success:
        raise RuntimeError(
            "Init operations did not make model ready for local_init.  "
            "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                      init_fn,
                                                      msg))

      is_ready, msg = self._model_ready(sess)
      if not is_ready:
        raise RuntimeError(
            "Init operations did not make model ready.  "
            "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
            (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    except Exception:  # pylint: disable=broad-except
      # The caller never receives `sess` on failure; close it rather than
      # leaking it until garbage collection, then propagate the error.
      self._safe_close(sess)
      raise
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.
      The session is returned (not closed) in both cases.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # Always try to run local_init_op
    local_init_success, msg = self._try_run_local_init_op(sess)

    if not is_loaded_from_checkpoint:
      # Do not need to run checks for readiness
      return sess, False

    restoring_file = checkpoint_dir or checkpoint_filename_with_path
    if not local_init_success:
      logging.info(
          "Restoring model from %s did not make model ready for local init:"
          " %s", restoring_file, msg)
      return sess, False

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   restoring_file, msg)
      return sess, False

    logging.info("Restored model from %s", restoring_file)
    return sess, is_loaded_from_checkpoint

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` is treated as "wait indefinitely".

    Returns:
      A `Session` whose model passed the readiness checks. This method never
      returns `None`; it either returns a ready session or raises.

    Raises:
      tf.DeadlineExceededError: if the session is not available after
        max_wait_secs.
      Any error raised by the local-init or readiness checks other than an
      "uninitialized" `FailedPreconditionError`; the attempt's session is
      closed before the error propagates.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      try:
        local_init_success, not_ready_local_msg = self._try_run_local_init_op(
            sess)
        if local_init_success:
          # Successful if local_init_op is None, or ready_for_local_init_op
          # passes.
          is_ready, not_ready_msg = self._model_ready(sess)
          if is_ready:
            return sess
      except Exception:  # pylint: disable=broad-except
        # Don't leak this attempt's session when a check raises.
        self._safe_close(sess)
        raise

      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is a `String` with the reason why the model was not ready to run
      local init (or `None` on success).
    """
    if self._local_init_op is None:
      return True, None
    is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
    if not is_ready_for_local_init:
      return False, msg
    logging.info("Running local_init_op.")
    sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
             options=self._local_init_run_options)
    logging.info("Done running local_init_op.")
    return True, None


def _ready(op, sess, msg):
  """Checks if the model is ready or not, as determined by op.

  Args:
    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
      readiness of the model.
    sess: A `Session`.
    msg: A message to log to warning if not ready

  Returns:
    A tuple (is_ready, msg), where is_ready is True if ready and False
    otherwise, and msg is `None` if the model is ready, a `String` with the
    reason why it is not ready otherwise.

  Raises:
    errors.FailedPreconditionError: if running `op` fails with a precondition
      error whose message does not mention "uninitialized".
  """
  if op is None:
    return True, None
  try:
    ready_value = sess.run(op)
  except errors.FailedPreconditionError as e:
    if "uninitialized" not in str(e):
      logging.warning("%s : error [%s]", msg, str(e))
      raise
    return False, str(e)

  # The model is considered ready if ready_op returns an empty 1-D tensor.
  # Also compare to `None` and dtype being int32 for backward
  # compatibility.
  if (ready_value is None or ready_value.dtype == np.int32 or
      ready_value.size == 0):
    return True, None
  # TODO(sherrym): If a custom ready_op returns other types of tensor,
  # or strings other than variable names, this message could be
  # confusing.
  # `np.ravel` tolerates 0-d results (which cannot be iterated directly), and
  # non-bytes elements from custom ready ops are stringified instead of
  # crashing on `.decode`.
  non_initialized_varnames = ", ".join(
      v.decode("utf-8") if isinstance(v, bytes) else str(v)
      for v in np.ravel(ready_value))
  return False, "Variables not initialized: " + non_initialized_varnames


class _CountDownTimer:
  """A timer that tracks a duration since creation."""

  __slots__ = ["_start_time_secs", "_duration_secs"]

  def __init__(self, duration_secs):
    # Monotonic clock: elapsed-time measurement must not jump when the
    # system wall clock is adjusted (NTP, manual changes).
    self._start_time_secs = time.monotonic()
    self._duration_secs = duration_secs

  def secs_remaining(self):
    """Returns the seconds left before the duration expires (never < 0)."""
    diff = self._duration_secs - (time.monotonic() - self._start_time_secs)
    return max(0, diff)