xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/session_manager.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and creates session."""
16import time
17
18import numpy as np
19from tensorflow.python.checkpoint import checkpoint_management
20from tensorflow.python.client import session
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.framework import errors
23from tensorflow.python.framework import ops
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.util.tf_export import tf_export
26
27
28def _maybe_name(obj):
29  """Returns object name if it has one, or a message otherwise.
30
31  This is useful for names that apper in error messages.
32  Args:
33    obj: Object to get the name of.
34  Returns:
35    name, "None", or a "no name" message.
36  """
37  if obj is None:
38    return "None"
39  elif hasattr(obj, "name"):
40    return obj.name
41  else:
42    return "<no name for %s>" % type(obj)
43
44
def _restore_checkpoint_and_maybe_run_saved_model_initializers(
    sess, saver, path):
  """Runs any SavedModel initializers, then restores the checkpoint.

  "SavedModel" here means a SavedModel loaded through the load_v2 API (which
  does not take a `sess` argument). Resources that such a SavedModel added to
  the graph are not restored by `saver.restore`, so their initializers have to
  be run explicitly alongside the restore.

  Args:
    sess: The `Session` in which to run the initializers and the restore.
    saver: A `Saver` used to restore the checkpoint.
    path: Path of the checkpoint to restore.
  """
  # The initializers live in the "saved_model_initializers" collection. That
  # collection is part of the MetaGraph's default_init_op, so MonitoredSession
  # already runs it as long as the saver doesn't restore a checkpoint from the
  # working dir; when a checkpoint is restored it has to be run here instead.
  initializers = ops.get_collection("saved_model_initializers")
  if initializers:
    sess.run(initializers)

  # Order matters: the SavedModel init restores variables from the SavedModel
  # variables directory, so the checkpoint restore must come afterwards to
  # win. Initializing/restoring twice is not ideal, but there is no other way.
  saver.restore(sess, path)
67
68
@tf_export(v1=["train.SessionManager"])
class SessionManager:
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions that facilitate
  coordination among multiple training threads or processes:

  * Initializing variables on startup, or restoring them from the most recent
    checkpoint (optionally waiting for a checkpoint to become available).
  * Waiting for another thread or process to make the model ready.

  ### Usage:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that restores the model from `checkpoint_dir`
    # when possible and runs `init_op` otherwise.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. To restore, it needs a
  `saver` and a checkpoint location; to initialize, it needs an `init_op`,
  an `init_fn` or a `local_init_op`.

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.

  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None,
               local_init_feed_dict=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is run after a new session
    was created, provided `ready_for_local_init_op` (if any) reports that the
    model is ready for it. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready.  The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op.  The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a non
    empty 1D string tensor, the elements are concatenated and used to indicate
    to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready.  It is used by processes to wait for a model to
    be initialized or restored.  Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use. Defaults to the current
        default graph.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None
    """
    # Reject inconsistent arguments before storing any state.
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    # Set by the session-creating methods to the `master` they were given.
    self._target = None
    self._local_init_run_options = local_init_run_options
    self._local_init_feed_dict = local_init_feed_dict

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Validate the arguments before doing any stateful work, so that an invalid
    # call neither initializes the TPU system nor leaves behind a `Session`
    # that nobody holds a reference to.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)

    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      _restore_checkpoint_and_maybe_run_saved_model_initializers(
          sess, saver, checkpoint_filename_with_path)
      return sess, True

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if not wait_for_checkpoint or wait_time >= max_wait_secs:
        # Either the caller did not ask to wait, or the wait budget is used up.
        return sess, False
      logging.info("Waiting for checkpoint to be available.")
      time.sleep(self._recovery_wait_secs)
      wait_time += self._recovery_wait_secs
      ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)

    # Loads the checkpoint.
    _restore_checkpoint_and_maybe_run_saved_model_initializers(
        sess, saver, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered. The
        session created by this call is closed before the error is raised.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        # The caller never receives `sess` on this path, so release it here.
        self._safe_close(sess)
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)

    local_init_success, msg = self._try_run_local_init_op(sess)
    if not local_init_success:
      self._safe_close(sess)
      raise RuntimeError(
          "Init operations did not make model ready for local_init.  "
          "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                   init_fn,
                                                   msg))

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      self._safe_close(sess)
      raise RuntimeError(
          "Init operations did not make model ready.  "
          "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
          (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Unlike `prepare_session()`, this method returns the session even when
    recovery did not succeed; the caller owns it in every case.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # Always try to run local_init_op
    local_init_success, msg = self._try_run_local_init_op(sess)

    if not is_loaded_from_checkpoint:
      # Do not need to run checks for readiness
      return sess, False

    restoring_file = checkpoint_dir or checkpoint_filename_with_path
    if not local_init_success:
      logging.info(
          "Restoring model from %s did not make model ready for local init:"
          " %s", restoring_file, msg)
      return sess, False

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   restoring_file, msg)
      return sess, False

    logging.info("Restored model from %s", restoring_file)
    # Reaching this point implies the checkpoint was loaded and the model
    # passed both readiness checks.
    return sess, True

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` is treated as no limit.

    Returns:
      A `Session` on which the local init op (if any) has been run and for
      which the ready op (if any) reported the model as ready.

    Raises:
      tf.DeadlineExceededError: if the session is not available after
        max_wait_secs.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      # A fresh session is created for every attempt and closed again if the
      # model turns out not to be ready yet.
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess

      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is `None` on success, or a `String` with the reason why the model
      was not ready to run local init.
    """
    if self._local_init_op is None:
      # Nothing to run counts as success.
      return True, None

    is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
    if not is_ready_for_local_init:
      return False, msg

    logging.info("Running local_init_op.")
    sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
             options=self._local_init_run_options)
    logging.info("Done running local_init_op.")
    return True, None
534
535
536def _ready(op, sess, msg):
537  """Checks if the model is ready or not, as determined by op.
538
539  Args:
540    op: An op, either _ready_op or _ready_for_local_init_op, which defines the
541      readiness of the model.
542    sess: A `Session`.
543    msg: A message to log to warning if not ready
544
545  Returns:
546    A tuple (is_ready, msg), where is_ready is True if ready and False
547    otherwise, and msg is `None` if the model is ready, a `String` with the
548    reason why it is not ready otherwise.
549  """
550  if op is None:
551    return True, None
552  else:
553    try:
554      ready_value = sess.run(op)
555      # The model is considered ready if ready_op returns an empty 1-D tensor.
556      # Also compare to `None` and dtype being int32 for backward
557      # compatibility.
558      if (ready_value is None or ready_value.dtype == np.int32 or
559          ready_value.size == 0):
560        return True, None
561      else:
562        # TODO(sherrym): If a custom ready_op returns other types of tensor,
563        # or strings other than variable names, this message could be
564        # confusing.
565        non_initialized_varnames = ", ".join(
566            [i.decode("utf-8") for i in ready_value])
567        return False, "Variables not initialized: " + non_initialized_varnames
568    except errors.FailedPreconditionError as e:
569      if "uninitialized" not in str(e):
570        logging.warning("%s : error [%s]", msg, str(e))
571        raise e
572      return False, str(e)
573
574
575class _CountDownTimer:
576  """A timer that tracks a duration since creation."""
577
578  __slots__ = ["_start_time_secs", "_duration_secs"]
579
580  def __init__(self, duration_secs):
581    self._start_time_secs = time.time()
582    self._duration_secs = duration_secs
583
584  def secs_remaining(self):
585    diff = self._duration_secs - (time.time() - self._start_time_secs)
586    return max(0, diff)
587