# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with name-based checkpoints.

While some of these symbols also work with the TF2 object-based checkpoints,
they are not recommended for TF2. Please check `tensorflow/python/checkpoint`
for newer utilities built to work with TF2 checkpoints.
"""

from collections import abc
import os
import time

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "load_checkpoint", "load_variable", "list_variables",
    "checkpoints_iterator", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints,
  a reader for the latest checkpoint is returned.

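  Example usage (a minimal sketch; the checkpoint path below is hypothetical):

  ```python
  reader = tf.train.load_checkpoint("/tmp/model_dir")
  shape_map = reader.get_variable_to_shape_map()   # checkpoint key -> shape
  dtype_map = reader.get_variable_to_dtype_map()   # checkpoint key -> dtype
  value = reader.get_tensor(next(iter(shape_map)))  # numpy ndarray
  ```
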
  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return py_checkpoint_reader.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

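  Example usage (a minimal sketch; the checkpoint path and variable name below
  are hypothetical):

  ```python
  # Valid names can be listed with `tf.train.list_variables`.
  value = tf.train.load_variable("/tmp/model_dir", "dense/kernel")
  print(value.shape)  # `value` is a numpy ndarray.
  ```
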
  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.

  Example usage:

  ```python
  import tensorflow as tf
  import os
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(key, shape)`.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

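  Example usage (a minimal polling-loop sketch; the directory and the
  `evaluate` helper below are hypothetical):

  ```python
  ckpt = None
  while True:
    ckpt = wait_for_new_checkpoint("/tmp/train_dir", ckpt, timeout=300)
    if ckpt is None:
      break  # Timed out; assume no more checkpoints will be written.
    evaluate(ckpt)
  ```
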
  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or `None` if the timeout was reached.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info("Found new checkpoint at %s", checkpoint_path)
      return checkpoint_path


@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  returned to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint.  It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints.  At this point the timeout logic applies again.

  This behavior gives callers control over what to do if checkpoints do not
  come fast enough or stop being generated.  For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped.  If they know that the training is still going on
  they return `False` instead.

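  Example usage (a minimal evaluation-loop sketch; the directory and the
  `evaluate`/`training_finished` helpers below are hypothetical):

  ```python
  for ckpt_path in tf.train.checkpoints_iterator(
      "/tmp/train_dir", min_interval_secs=60, timeout=600,
      timeout_fn=training_finished):
    # Runs once per newly written checkpoint; the loop ends when `timeout`
    # expires and `timeout_fn` reports that training has finished.
    evaluate(ckpt_path)
  ```
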
  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout.  If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit.  The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # timed out
        logging.info("Timed-out waiting for a checkpoint.")
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      else:
        # The timeout_fn indicated that more checkpoints may come.
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  @compatibility(TF2)
  `tf.compat.v1.train.init_from_checkpoint` is not recommended for restoring
  variable values in TF2.

  To restore checkpoints in TF2, please use
  `tf.keras.Model.load_weights` or `tf.train.Checkpoint.restore`. These APIs
  use an [object-based method of checkpointing]
  (https://www.tensorflow.org/guide/checkpoint#loading_mechanics), while
  `tf.compat.v1.init_from_checkpoint` relies on a more fragile, variable-name
  based method of checkpointing. There is no object-based equivalent of
  `init_from_checkpoint` in TF2.

  Please re-write your checkpoints immediately using the object-based APIs;
  see the [migration guide]
  (https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) for more
  details.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, e.g.:
  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```
  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint using `tf.train.Checkpoint.save`
  or `tf.keras.Model.save_weights`.

  @end_compatibility

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Assignment map can be a dict, or a list of pairs.  The latter is
  necessary to initialize multiple variables in the current graph from
  the same variable in the checkpoint.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                           initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                           initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.compat.v1.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, or a list of key-value pairs, where keys are names
      of the variables in the checkpoint and values are current variables or
      names of current variables (in default graph).

  Raises:
    ValueError: If variables in `assignment_map` are missing from the current
      graph, or if the checkpoint or tensors in the checkpoint cannot be found.

  """
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  if isinstance(assignment_map, abc.Mapping):
    assignment_map = assignment_map.items()

  # We only want to sort by tensor names.
  sort_key = lambda pair: pair[0]

  for tensor_name_in_ckpt, current_var_or_name in sorted(
      assignment_map, key=sort_key):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if isinstance(ckpt_dir_or_file, os.PathLike):
    ckpt_dir_or_file = os.fspath(ckpt_dir_or_file)
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets the variable's initializer to an assign op that initializes the variable
  from the tensor's value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
  # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: If the objects in `variable_or_list` are not all partitions of
      the same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None
