xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/saver.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Save and restore variables.
18
19Symbols in this file are deprecated. See replacements in
20tensorflow/python/training/trackable and tensorflow/python/training/saving.
21"""
22import collections
23import glob
24import os.path
25import threading
26import time
27
28import numpy as np
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.core.protobuf import saver_pb2
31from tensorflow.core.protobuf import trackable_object_graph_pb2
32from tensorflow.python.checkpoint import checkpoint_management
33from tensorflow.python.client import session
34from tensorflow.python.eager import context
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import device as pydev
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import meta_graph
39from tensorflow.python.framework import ops
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import control_flow_ops
42from tensorflow.python.ops import gen_io_ops
43from tensorflow.python.ops import io_ops
44from tensorflow.python.ops import string_ops
45from tensorflow.python.ops import variables
46from tensorflow.python.platform import gfile
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.saved_model.pywrap_saved_model import metrics
49from tensorflow.python.trackable import base as trackable
50from tensorflow.python.training import py_checkpoint_reader
51from tensorflow.python.training import training_util
52from tensorflow.python.training.saving import saveable_object
53from tensorflow.python.training.saving import saveable_object_util
54from tensorflow.python.util import compat
55from tensorflow.python.util.tf_export import tf_export
56
57# TODO(allenl): Remove these aliases once all users are migrated off.
58get_checkpoint_state = checkpoint_management.get_checkpoint_state
59update_checkpoint_state = checkpoint_management.update_checkpoint_state
60generate_checkpoint_state_proto = (
61    checkpoint_management.generate_checkpoint_state_proto)
62latest_checkpoint = checkpoint_management.latest_checkpoint
63checkpoint_exists = checkpoint_management.checkpoint_exists
64get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
65remove_checkpoint = checkpoint_management.remove_checkpoint
66
67# Captures the timestamp of the first Saver object instantiation or end of a
68# save operation. Can be accessed by multiple Saver instances.
69_END_TIME_OF_LAST_WRITE = None
70_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()
71
72# API label for cell name used in checkpoint metrics.
73_SAVER_LABEL = "saver_v1"
74
75
76def _get_duration_microseconds(start_time_seconds, end_time_seconds):
77  if end_time_seconds < start_time_seconds:
78    # Avoid returning negative value in case of clock skew.
79    return 0
80  return round((end_time_seconds - start_time_seconds) * 1000000)
81
82
83def _get_checkpoint_size(prefix):
84  """Calculates filesize of checkpoint based on prefix."""
85  size = 0
86  # Gather all files beginning with prefix (.index plus sharded data files).
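  # For example, a V2 checkpoint saved with prefix "/tmp/model.ckpt-100"
  # typically matches "/tmp/model.ckpt-100.index" and
  # "/tmp/model.ckpt-100.data-00000-of-00001" (illustrative names; the number
  # of data files depends on sharding).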
87  files = glob.glob("{}*".format(prefix))
88  for file in files:
89    # Use TensorFlow's C++ FileSystem API.
90    size += metrics.CalculateFileSize(file)
91  return size
92
93
94class BaseSaverBuilder:
95  """Base class for Savers.
96
97  Can be extended to create different Ops.
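
  For illustration, a minimal sketch of a subclass that customizes `save_op`
  (the class name and the logging it performs are hypothetical, not part of
  this module):

  ```python
  class LoggingSaverBuilder(BaseSaverBuilder):

    def save_op(self, filename_tensor, saveables):
      # Log what is about to be saved, then defer to the default save ops.
      for saveable in saveables:
        tf.compat.v1.logging.info("Saving %s", saveable.name)
      return super().save_op(filename_tensor, saveables)
  ```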
98  """
99
100  SaveSpec = saveable_object.SaveSpec
101  SaveableObject = saveable_object.SaveableObject
102
103  # Aliases for code which was moved but still has lots of users.
104  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
105  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
106
107  def __init__(self, write_version=saver_pb2.SaverDef.V2):
108    self._write_version = write_version
109
110  def save_op(self, filename_tensor, saveables):
111    """Create an Op to save 'saveables'.
112
113    This is intended to be overridden by subclasses that want to generate
114    different Ops.
115
116    Args:
117      filename_tensor: String Tensor.
118      saveables: A list of BaseSaverBuilder.SaveableObject objects.
119
120    Returns:
      An Operation that saves the variables.
122
123    Raises:
124      RuntimeError: (implementation detail) if "self._write_version" is an
125        unexpected value.
126    """
127    # pylint: disable=protected-access
128    tensor_names = []
129    tensors = []
130    tensor_slices = []
131    for saveable in saveables:
132      for spec in saveable.specs:
133        tensor_names.append(spec.name)
134        tensors.append(spec.tensor)
135        tensor_slices.append(spec.slice_spec)
136    if self._write_version == saver_pb2.SaverDef.V1:
137      return io_ops._save(
138          filename=filename_tensor,
139          tensor_names=tensor_names,
140          tensors=tensors,
141          tensor_slices=tensor_slices)
142    elif self._write_version == saver_pb2.SaverDef.V2:
143      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
144      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
145      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
146                            tensors)
147    else:
      raise RuntimeError("Unexpected write_version: " + str(self._write_version))
149
150  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
151                   restore_sequentially):
152    """Restore all tensors contained in saveables.
153
154    By default, this issues separate calls to `restore_op` for each saveable.
155    Subclasses may override to load multiple saveables in a single call.
156
157    Args:
158      filename_tensor: String Tensor.
159      saveables: List of BaseSaverBuilder.SaveableObject objects.
160      preferred_shard: Int.  Shard to open first when loading a sharded file.
161      restore_sequentially: Unused.  Bool.  If true, each restore is sequential.
162
163    Returns:
      A list of Tensors resulting from reading the saveables from
        'filename_tensor'.
166
167    """
168    del restore_sequentially
169    all_tensors = []
170    for saveable in saveables:
171      if saveable.device:
172        device = saveable_object_util.set_cpu0(saveable.device)
173      else:
174        device = None
175      with ops.device(device):
176        all_tensors.extend(
177            self.restore_op(filename_tensor, saveable, preferred_shard))
178    return all_tensors
179
180  # pylint: disable=unused-argument
181  def restore_op(self, filename_tensor, saveable, preferred_shard):
182    """Create ops to restore 'saveable'.
183
184    This is intended to be overridden by subclasses that want to generate
185    different Ops.
186
187    Args:
188      filename_tensor: String Tensor.
189      saveable: A BaseSaverBuilder.SaveableObject object.
190      preferred_shard: Int.  Shard to open first when loading a sharded file.
191
192    Returns:
193      A list of Tensors resulting from reading 'saveable' from
194        'filename'.
195    """
196    # pylint: disable=protected-access
197    tensors = []
198    for spec in saveable.specs:
199      tensors.append(
200          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
201                            [spec.dtype])[0])
202
203    return tensors
204
205  # pylint: enable=unused-argument
206
207  def sharded_filename(self, filename_tensor, shard, num_shards):
208    """Append sharding information to a filename.
209
210    Args:
211      filename_tensor: A string tensor.
212      shard: Integer.  The shard for the filename.
213      num_shards: An int Tensor for the number of shards.
214
215    Returns:
216      A string tensor.
217    """
218    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
219
220  def _AddSaveOps(self, filename_tensor, saveables):
221    """Add ops to save variables that are on the same shard.
222
223    Args:
224      filename_tensor: String Tensor.
225      saveables: A list of SaveableObject objects.
226
227    Returns:
228      A tensor with the filename used to save.
229    """
230    save = self.save_op(filename_tensor, saveables)
231    return control_flow_ops.with_dependencies([save], filename_tensor)
232
233  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
234    """Add ops to save the params per shard, for the V2 format.
235
236    Note that the sharded save procedure for the V2 format is different from
237    V1: there is a special "merge" step that merges the small metadata produced
238    from each device.
239
240    Args:
241      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A FILENAME*,
242        but as a prefix of a V2 checkpoint;
243      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
244        returned by _GroupByDevices().
245
246    Returns:
247      An op to save the variables, which, when evaluated, returns the prefix
248        "<user-fed prefix>" only and does not include the sharded spec suffix.
249    """
250    # IMPLEMENTATION DETAILS: most clients should skip.
251    #
252    # Suffix for any well-formed "checkpoint_prefix", when sharded.
253    # Transformations:
254    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
255    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is an S3 bucket path, ".part" is appended to it.
    # * Otherwise "_temp/part" is appended, which is normalized relative to
    #   the OS.
258    # Example:
259    #   During runtime, a temporary directory is first created, which contains
260    #   files
261    #
262    #     <train dir>/myckpt_temp/
263    #        part-?????-of-?????{.index, .data-00000-of-00001}
264    #
265    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
266    #
267    #     <train dir>/
268    #        myckpt{.index, .data-?????-of-?????}
269    #
    #   Filesystems with eventual consistency (such as S3) don't need a
271    #   temporary location. Using a temporary directory in those cases might
272    #   cause situations where files are not available during copy.
273    #
274    # Users only need to interact with the user-specified prefix, which is
275    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
276    # prefix directly, instead of any physical pathname.  (On failure and
277    # subsequent restore, an outdated and orphaned temporary directory can be
278    # safely removed.)
279    with ops.device("CPU"):
280      _SHARDED_SUFFIX = array_ops.where(
281          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
282          constant_op.constant(".part"),
283          constant_op.constant(os.path.normpath("_temp/part")))
284      tmp_checkpoint_prefix = string_ops.string_join(
285          [checkpoint_prefix, _SHARDED_SUFFIX])
286
287    num_shards = len(per_device)
288    sharded_saves = []
289    sharded_prefixes = []
290    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
291    last_device = None
292    for shard, (device, saveables) in enumerate(per_device):
293      last_device = device
294      with ops.device(saveable_object_util.set_cpu0(device)):
295        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
296                                                 num_shards_tensor)
297        sharded_prefixes.append(sharded_filename)
298        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
299
300    with ops.control_dependencies([x.op for x in sharded_saves]):
301      # Co-locates the merge step with the last device.
302      with ops.device(saveable_object_util.set_cpu0(last_device)):
303        # V2 format write path consists of a metadata merge step.  Once merged,
304        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
305        merge_step = gen_io_ops.merge_v2_checkpoints(
306            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
307        with ops.control_dependencies([merge_step]):
308          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
309          # sharded spec suffix.
310          return array_ops.identity(checkpoint_prefix)
311
312  def _AddShardedSaveOps(self, filename_tensor, per_device):
313    """Add ops to save the params per shard.
314
315    Args:
316      filename_tensor: a scalar String Tensor.
317      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
318        returned by _GroupByDevices().
319
320    Returns:
321      An op to save the variables.
322    """
323    if self._write_version == saver_pb2.SaverDef.V2:
324      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
325
326    num_shards = len(per_device)
327    sharded_saves = []
328    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
329    for shard, (device, saveables) in enumerate(per_device):
330      with ops.device(device):
331        sharded_filename = self.sharded_filename(filename_tensor, shard,
332                                                 num_shards_tensor)
333        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
334    # Return the sharded name for the save path.
335    with ops.control_dependencies([x.op for x in sharded_saves]):
336      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)
337
338  def _AddRestoreOps(self,
339                     filename_tensor,
340                     saveables,
341                     restore_sequentially,
342                     reshape,
343                     preferred_shard=-1,
344                     name="restore_all"):
345    """Add operations to restore saveables.
346
347    Args:
348      filename_tensor: Tensor for the path of the file to load.
349      saveables: A list of SaveableObject objects.
350      restore_sequentially: True if we want to restore variables sequentially
351        within a shard.
352      reshape: True if we want to reshape loaded tensors to the shape of the
353        corresponding variable.
354      preferred_shard: Shard to open first when loading a sharded file.
355      name: Name for the returned op.
356
357    Returns:
358      An Operation that restores the variables.
359    """
360    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
361                                    restore_sequentially)
362
363    assign_ops = []
364    idx = 0
365    # Load and optionally reshape on the CPU, as string tensors are not
366    # available on the GPU.
367    # TODO(touts): Re-enable restore on GPU when we can support annotating
368    # string tensors as "HostMemory" inputs.
369    for saveable in saveables:
370      shapes = None
371      if reshape:
372        # Compute the shapes, let the restore op decide if and how to do
373        # the reshape.
374        shapes = []
375        for spec in saveable.specs:
376          v = spec.tensor
377          shape = v.get_shape()
378          if not shape.is_fully_defined():
379            shape = array_ops.shape(v)
380          shapes.append(shape)
381      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
382      idx += len(saveable.specs)
383      assign_ops.append(saveable.restore(saveable_tensors, shapes))
384
385    # Create a Noop that has control dependencies from all the updates.
386    return control_flow_ops.group(*assign_ops, name=name)
387
388  def _AddShardedRestoreOps(self, filename_tensor, per_device,
389                            restore_sequentially, reshape):
390    """Add Ops to restore variables from multiple devices.
391
392    Args:
393      filename_tensor: Tensor for the path of the file to load.
394      per_device: A list of (device, SaveableObject) pairs, as returned by
395        _GroupByDevices().
396      restore_sequentially: True if we want to restore variables sequentially
397        within a shard.
398      reshape: True if we want to reshape loaded tensors to the shape of the
399        corresponding variable.
400
401    Returns:
402      An Operation that restores the variables.
403    """
404    sharded_restores = []
405    for shard, (device, saveables) in enumerate(per_device):
406      with ops.device(device):
407        sharded_restores.append(
408            self._AddRestoreOps(
409                filename_tensor,
410                saveables,
411                restore_sequentially,
412                reshape,
413                preferred_shard=shard,
414                name="restore_shard"))
415    return control_flow_ops.group(*sharded_restores, name="restore_all")
416
417  def _GroupByDevices(self, saveables):
418    """Group Variable tensor slices per device.
419
420    TODO(touts): Make sure that all the devices found are on different
421    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
422    It can happen if the devices are unspecified.
423
424    Args:
425      saveables: A list of BaseSaverBuilder.SaveableObject objects.
426
427    Returns:
428      A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples.
429      The list is sorted by ascending device_name.
430
431    Raises:
432      ValueError: If the tensors of a saveable are on different devices.
433    """
434    per_device = collections.defaultdict(lambda: [])
435    for saveable in saveables:
436      canonical_device = set(
437          pydev.canonical_name(spec.device) for spec in saveable.specs)
438      if len(canonical_device) != 1:
439        raise ValueError("All tensors of a saveable object must be "
440                         "on the same device: %s" % saveable.name)
441      per_device[canonical_device.pop()].append(saveable)
442    return sorted(per_device.items(), key=lambda t: t[0])
443
444  def build(self,
445            names_to_saveables,
446            reshape=False,
447            sharded=False,
448            max_to_keep=5,
449            keep_checkpoint_every_n_hours=10000.0,
450            name=None,
451            restore_sequentially=False,
452            filename="model"):
453    """Builds save/restore graph nodes or runs save/restore in eager mode.
454
455    Args:
456      names_to_saveables: A dictionary mapping name to a Variable or
457        SaveableObject. Each name will be associated with the corresponding
458        variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint where the
        parameters have a different shape.  This is only needed when you try
        to restore from a Dist-Belief checkpoint, and only sometimes.
462      sharded: If True, shard the checkpoints, one per device that has Variable
463        nodes.
464      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
465        are created, old ones are deleted.  If None or 0, no checkpoints are
466        deleted from the filesystem but only the last one is kept in the
467        `checkpoint` file.  Presently the number is only roughly enforced.  For
468        example in case of restarts more than max_to_keep checkpoints may be
469        kept.
470      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
471        Defaults to 10,000 hours.
472      name: String.  Optional name to use as a prefix when adding operations.
473      restore_sequentially: A Bool, which if true, causes restore of different
474        variables to happen sequentially within each device.
475      filename: If known at graph construction time, filename used for variable
476        loading/saving. If None, then the default name "model" will be used.
477
478    Returns:
479      A SaverDef proto.
480
481    Raises:
482      TypeError: If 'names_to_saveables' is not a dictionary mapping string
483        keys to variable Tensors.
484      ValueError: If any of the keys or values in 'names_to_saveables' is not
485        unique.
486    """
487    return self._build_internal(
488        names_to_saveables=names_to_saveables,
489        reshape=reshape,
490        sharded=sharded,
491        max_to_keep=max_to_keep,
492        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
493        name=name,
494        restore_sequentially=restore_sequentially,
495        filename=filename)
496
497  def _build_internal(self,
498                      names_to_saveables,
499                      reshape=False,
500                      sharded=False,
501                      max_to_keep=5,
502                      keep_checkpoint_every_n_hours=10000.0,
503                      name=None,
504                      restore_sequentially=False,
505                      filename="model",
506                      build_save=True,
507                      build_restore=True):
508    """build() with option to only perform save and restore."""
509    if not context.executing_eagerly() and (not build_save or
510                                            not build_restore):
511      raise ValueError("save and restore operations need to be built together "
512                       " when eager execution is not enabled.")
513
514    saveables = saveable_object_util.validate_and_slice_inputs(
515        names_to_saveables)
516    if max_to_keep is None:
517      max_to_keep = 0
518
519    with ops.name_scope(name, "save",
520                        [saveable.op for saveable in saveables]) as name:
521      # Add a placeholder string tensor for the filename.
522      filename_tensor = array_ops.placeholder_with_default(
523          filename or "model", shape=(), name="filename")
524      # Keep the name "Const" for backwards compatibility.
525      filename_tensor = array_ops.placeholder_with_default(
526          filename_tensor, shape=(), name="Const")
527
528      # Add the save ops.
529      if sharded:
530        per_device = self._GroupByDevices(saveables)
531        if build_save:
532          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
533        if build_restore:
534          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
535                                                  restore_sequentially, reshape)
536      else:
537        if build_save:
538          save_tensor = self._AddSaveOps(filename_tensor, saveables)
539        if build_restore:
540          restore_op = self._AddRestoreOps(filename_tensor, saveables,
541                                           restore_sequentially, reshape)
542
543    # In the following use case, it's possible to have restore_ops be called
544    # something else:
545    # - Build inference graph and export a meta_graph.
546    # - Import the inference meta_graph
547    # - Extend the inference graph to a train graph.
548    # - Export a new meta_graph.
549    # Now the second restore_op will be called "restore_all_1".
550    # As such, comment out the assert for now until we know whether supporting
551    # such usage model makes sense.
552    #
553    # assert restore_op.name.endswith("restore_all"), restore_op.name
554    if context.executing_eagerly():
555      # Store the tensor values to the tensor_names.
556      save_tensor_name = save_tensor.numpy() if build_save else ""
557      return saver_pb2.SaverDef(
558          filename_tensor_name=filename_tensor.numpy(),
559          save_tensor_name=save_tensor_name,
560          restore_op_name="",
561          max_to_keep=max_to_keep,
562          sharded=sharded,
563          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
564          version=self._write_version)
565    else:
566      graph = ops.get_default_graph()
567      # Do some sanity checking on collections containing
568      # PartitionedVariables. If a saved collection has a PartitionedVariable,
569      # the GraphDef needs to include concat ops to get the value (or there'll
570      # be a lookup error on load).
571      check_collection_list = graph.get_all_collection_keys()
572      for collection_type in check_collection_list:
573        for element in graph.get_collection(collection_type):
574          if isinstance(element, variables.PartitionedVariable):
575            try:
576              graph.get_operation_by_name(element.name)
577            except KeyError:
578              # Create a concat op for this PartitionedVariable. The user may
579              # not need it, but we'll try looking it up on MetaGraph restore
580              # since it's in a collection.
581              element.as_tensor()
582      return saver_pb2.SaverDef(
583          filename_tensor_name=filename_tensor.name,
584          save_tensor_name=save_tensor.name,
585          restore_op_name=restore_op.name,
586          max_to_keep=max_to_keep,
587          sharded=sharded,
588          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
589          version=self._write_version)
590
591
592class BulkSaverBuilder(BaseSaverBuilder):
593  """SaverBuilder with support for bulk restoring multiple saveables."""
594
595  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
596                   restore_sequentially):
597
598    # Ignored: bulk restore is internally sequential.
599    del restore_sequentially
600    restore_specs = []
601    for saveable in saveables:
602      for spec in saveable.specs:
603        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
604
605    names, slices, dtypes = zip(*restore_specs)
606    # Load all tensors onto CPU 0 for compatibility with existing code.
607    with ops.device("cpu:0"):
608      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
609
610
611def _get_saver_or_default():
612  """Returns the saver from SAVERS collection, or creates a default one.
613
614  This method is used by other members of the training module, such as
615  `Scaffold`, or `CheckpointSaverHook`.
616
617  Returns:
618    `Saver`.
619
620  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
622  """
623  collection_key = ops.GraphKeys.SAVERS
624  savers = ops.get_collection(collection_key)
625  if savers:
626    if len(savers) > 1:
627      raise RuntimeError(
628          "More than one item in collection {}. "
629          "Please indicate which one to use by passing it to the constructor."
630          .format(collection_key))
631    return savers[0]
632  saver = Saver(sharded=True, allow_empty=True)
633  if saver is not None:
634    ops.add_to_collection(collection_key, saver)
635  return saver
636
637
638@tf_export(v1=["train.Saver"])
639class Saver:
640  # pylint: disable=line-too-long
641  """Saves and restores variables.
642
643  @compatibility(TF2)
644  `tf.compat.v1.train.Saver` is not supported for saving and restoring
645  checkpoints in TF2. Please switch to `tf.train.Checkpoint` or
646  `tf.keras.Model.save_weights`, which perform a more robust [object-based
647  saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics).
648
649  ### How to Rewrite Checkpoints
650
651  Please rewrite your checkpoints immediately using the object-based checkpoint
652  APIs.
653
654  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
655  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
656  you may have to change the names of the variables in your model to match the
657  variable names in the name-based checkpoint, which can be viewed with
658  `tf.train.list_variables(path)`.
659
660  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, e.g.:
662  ```
663  {
664      'sequential/dense/bias': model.variables[0],
665      'sequential/dense/kernel': model.variables[1]
666  }
667  ```
668  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
669  restore the name-based checkpoint.
670
671  After restoring, re-encode your checkpoint
672  using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`.
673
674  See the [Checkpoint compatibility](
675  https://www.tensorflow.org/guide/migrate#checkpoint_compatibility)
676  section of the migration guide for more details.
677
678
679  ### Checkpoint Management in TF2
680
681  Use `tf.train.CheckpointManager` to manage checkpoints in TF2.
682  `tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours`
683  and `max_to_keep` parameters.
684
685  To recover the latest checkpoint,
686
687  ```
  checkpoint = tf.train.Checkpoint(model)
  # `directory` and `max_to_keep` are required; the values here are
  # illustrative.
  manager = tf.train.CheckpointManager(
      checkpoint, directory='/tmp/ckpts', max_to_keep=5)
  status = checkpoint.restore(manager.latest_checkpoint)
691  ```
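
  New checkpoints can then be written through the same manager, for example
  once every 1,000 steps (the cadence and loop bound are illustrative):

  ```
  for step in range(num_training_steps):
    ...
    if step % 1000 == 0:
      manager.save()
  ```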
692
693  `tf.train.CheckpointManager` also writes a [`CheckpointState` proto]
694  (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto)
695  which contains the timestamp when each checkpoint was created.
696
697  ### Writing `MetaGraphDef`s in TF2
698
  To replace `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use
700  `tf.saved_model.save` to write the `MetaGraphDef` (which is contained in
701  `saved_model.pb`).
702
703  @end_compatibility
704
705  See [Variables](https://tensorflow.org/guide/variables)
706  for an overview of variables, saving and restoring.
707
708  The `Saver` class adds ops to save and restore variables to and from
709  *checkpoints*.  It also provides convenience methods to run these ops.
710
711  Checkpoints are binary files in a proprietary format which map variable names
712  to tensor values.  The best way to examine the contents of a checkpoint is to
713  load it using a `Saver`.
714
715  Savers can automatically number checkpoint filenames with a provided counter.
716  This lets you keep multiple checkpoints at different steps while training a
717  model.  For example you can number the checkpoint filenames with the training
718  step number.  To avoid filling up disks, savers manage checkpoint files
719  automatically. For example, they can keep only the N most recent files, or
720  one checkpoint for every N hours of training.
721
722  You number checkpoint filenames by passing a value to the optional
723  `global_step` argument to `save()`:
724
725  ```python
726  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
727  ...
728  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
729  ```
730
731  Additionally, optional arguments to the `Saver()` constructor let you control
732  the proliferation of checkpoint files on disk:
733
734  * `max_to_keep` indicates the maximum number of recent checkpoint files to
    keep.  As new files are created, older files are deleted.  If None or 0,
736    no checkpoints are deleted from the filesystem but only the last one is
737    kept in the `checkpoint` file.  Defaults to 5 (that is, the 5 most recent
738    checkpoint files are kept.)
739
740  * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
741    `max_to_keep` checkpoint files, you might want to keep one checkpoint file
742    for every N hours of training.  This can be useful if you want to later
743    analyze how a model progressed during a long training session.  For
744    example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
745    one checkpoint file for every 2 hours of training.  The default value of
746    10,000 hours effectively disables the feature.
747
748  Note that you still have to call the `save()` method to save the model.
749  Passing these arguments to the constructor will not save variables
750  automatically for you.
751
752  A training program that saves regularly looks like:
753
754  ```python
755  ...
756  # Create a saver.
757  saver = tf.compat.v1.train.Saver(...variables...)
758  # Launch the graph and train, saving the model every 1,000 steps.
759  sess = tf.compat.v1.Session()
760  for step in range(1000000):
761      sess.run(..training_op..)
762      if step % 1000 == 0:
763          # Append the step number to the checkpoint name:
764          saver.save(sess, 'my-model', global_step=step)
765  ```
766
767  In addition to checkpoint files, savers keep a protocol buffer on disk with
768  the list of recent checkpoints. This is used to manage numbered checkpoint
769  files and by `latest_checkpoint()`, which makes it easy to discover the path
770  to the most recent checkpoint. That protocol buffer is stored in a file named
771  'checkpoint' next to the checkpoint files.
772
773  If you create several savers, you can specify a different filename for the
774  protocol buffer file in the call to `save()`.
775  """
776
777  # pylint: enable=line-too-long
778
779  def __init__(self,
780               var_list=None,
781               reshape=False,
782               sharded=False,
783               max_to_keep=5,
784               keep_checkpoint_every_n_hours=10000.0,
785               name=None,
786               restore_sequentially=False,
787               saver_def=None,
788               builder=None,
789               defer_build=False,
790               allow_empty=False,
791               write_version=saver_pb2.SaverDef.V2,
792               pad_step_number=False,
793               save_relative_paths=False,
794               filename=None):
795    """Creates a `Saver`.
796
797    The constructor adds ops to save and restore variables.
798
799    `var_list` specifies the variables that will be saved and restored. It can
800    be passed as a `dict` or a list:
801
802    * A `dict` of names to variables: The keys are the names that will be
803      used to save or restore the variables in the checkpoint files.
804    * A list of variables: The variables will be keyed with their op name in
805      the checkpoint files.
806
807    For example:
808
809    ```python
810    v1 = tf.Variable(..., name='v1')
811    v2 = tf.Variable(..., name='v2')
812
813    # Pass the variables as a dict:
814    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
815
816    # Or pass them as a list.
817    saver = tf.compat.v1.train.Saver([v1, v2])
818    # Passing a list is equivalent to passing a dict with the variable op names
819    # as keys:
820    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
821    ```
822
823    Note: the newer `AutoTrackable` API is not supported by `Saver`. In this
824    case, the `tf.train.Checkpoint` class should be used.
825
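    A minimal object-based sketch of that alternative (assuming `model` is a
    trackable object such as a `tf.keras.Model`; the paths are illustrative):

    ```python
    ckpt = tf.train.Checkpoint(model=model)
    ckpt.save('/tmp/ckpt/model')  # Writes object-based checkpoint files.
    ckpt.restore(tf.train.latest_checkpoint('/tmp/ckpt'))
    ```
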
826    The optional `reshape` argument, if `True`, allows restoring a variable from
827    a save file where the variable had a different shape, but the same number
828    of elements and type.  This is useful if you have reshaped a variable and
829    want to reload it from an older checkpoint.
830
831    The optional `sharded` argument, if `True`, instructs the saver to shard
832    checkpoints per device.
833
834    Args:
835      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
836        names to `SaveableObject`s. If `None`, defaults to the list of all
837        saveable objects.
838      reshape: If `True`, allows restoring parameters from a checkpoint where
839        the variables have a different shape.
840      sharded: If `True`, shard the checkpoints, one per device.
841      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
842      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
843        10,000 hours.
844      name: String.  Optional name to use as a prefix when adding operations.
845      restore_sequentially: A `Bool`, which if true, causes restore of different
846        variables to happen sequentially within each device.  This can lower
847        memory usage when restoring very large models.
848      saver_def: Optional `SaverDef` proto to use instead of running the
849        builder. This is only useful for specialty code that wants to recreate a
850        `Saver` object for a previously built `Graph` that had a `Saver`. The
851        `saver_def` proto should be the one returned by the `as_saver_def()`
852        call of the `Saver` that was created for that `Graph`.
853      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
854        Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops until the
        `build()` call. In that case `build()` should be called before
857        finalizing the graph or using the saver.
858      allow_empty: If `False` (default) raise an error if there are no variables
859        in the graph. Otherwise, construct the saver anyway and make it a no-op.
860      write_version: controls what format to use when saving checkpoints.  It
861        also affects certain filepath matching logic.  The V2 format is the
862        recommended choice: it is much more optimized than V1 in terms of memory
863        required and latency incurred during restore.  Regardless of this flag,
864        the Saver is able to restore from both V2 and V1 checkpoints.
865      pad_step_number: if True, pads the global step number in the checkpoint
866        filepaths to some fixed width (8 by default).  This is turned off by
867        default.
868      save_relative_paths: If `True`, will write relative paths to the
869        checkpoint state file. This is needed if the user wants to copy the
870        checkpoint directory and reload from the copied directory.
871      filename: If known at graph construction time, filename used for variable
872        loading/saving.
873
874    Raises:
875      TypeError: If `var_list` is invalid.
876      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.
879
880    @compatibility(eager)
881    When eager execution is enabled, `var_list` must specify a `list` or `dict`
882    of variables to save. Otherwise, a `RuntimeError` will be raised.
883
884    Although Saver works in some cases when executing eagerly, it is
885    fragile. Please switch to `tf.train.Checkpoint` or
886    `tf.keras.Model.save_weights`, which perform a more robust object-based
887    saving. These APIs will load checkpoints written by `Saver`.
888    @end_compatibility
889    """
890    global _END_TIME_OF_LAST_WRITE
891    with _END_TIME_OF_LAST_WRITE_LOCK:
892      if _END_TIME_OF_LAST_WRITE is None:
893        _END_TIME_OF_LAST_WRITE = time.time()
894
895    if defer_build and var_list:
896      raise ValueError(
897          "If `var_list` is provided then build cannot be deferred. "
898          "Either set defer_build=False or var_list=None.")
899    if context.executing_eagerly():
900      logging.warning(
901          "Saver is deprecated, please switch to tf.train.Checkpoint or "
902          "tf.keras.Model.save_weights for training checkpoints. When "
903          "executing eagerly variables do not necessarily have unique names, "
904          "and so the variable.name-based lookups Saver performs are "
905          "error-prone.")
906      if var_list is None:
907        raise RuntimeError(
908            "When eager execution is enabled, `var_list` must specify a list "
909            "or dict of variables to save")
910    self._var_list = var_list
911    self._reshape = reshape
912    self._sharded = sharded
913    self._max_to_keep = max_to_keep
914    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
915    self._name = name
916    self._restore_sequentially = restore_sequentially
917    self.saver_def = saver_def
918    self._builder = builder
919    self._is_built = False
920    self._allow_empty = allow_empty
921    self._is_empty = None
922    self._write_version = write_version
923    self._pad_step_number = pad_step_number
924    self._filename = filename
925    self._last_checkpoints = []
926    self._checkpoints_to_be_deleted = []
927    if context.executing_eagerly():
928      self._next_checkpoint_time = (
929          time.time() + self._keep_checkpoint_every_n_hours * 3600)
930    elif not defer_build:
931      self.build()
932    if self.saver_def:
933      self._check_saver_def()
934      self._write_version = self.saver_def.version
935    self._save_relative_paths = save_relative_paths
936    # For compatibility with object-based checkpoints, we may build a second
937    # Saver to read the renamed keys.
938    self._object_restore_saver = None
939
940  def build(self):
941    if context.executing_eagerly():
942      raise RuntimeError("Use save/restore instead of build in eager mode.")
943    self._build(self._filename, build_save=True, build_restore=True)
944
945  def _build_eager(self, checkpoint_path, build_save, build_restore):
946    self._build(
947        checkpoint_path, build_save=build_save, build_restore=build_restore)
948
949  def _build(self, checkpoint_path, build_save, build_restore):
950    """Builds saver_def."""
951    if not context.executing_eagerly():
952      if self._is_built:
953        return
954      self._is_built = True
955
956    if not self.saver_def or context.executing_eagerly():
957      if self._builder is None:
958        self._builder = BulkSaverBuilder(self._write_version)
959
960      if self._var_list is None:
961        # pylint: disable=protected-access
962        self._var_list = variables._all_saveable_objects()
963      if not self._var_list:
964        if self._allow_empty:
965          self._is_empty = True
966          return
967        else:
968          raise ValueError("No variables to save")
969      self._is_empty = False
970
971      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
972          self._var_list,
973          reshape=self._reshape,
974          sharded=self._sharded,
975          max_to_keep=self._max_to_keep,
976          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
977          name=self._name,
978          restore_sequentially=self._restore_sequentially,
979          filename=checkpoint_path,
980          build_save=build_save,
981          build_restore=build_restore)
982    elif self.saver_def and self._name:
983      # Since self._name is used as a name_scope by builder(), we are
984      # overloading the use of this field to represent the "import_scope" as
985      # well.
986      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
987          self.saver_def.filename_tensor_name, self._name)
988      self.saver_def.save_tensor_name = ops.prepend_name_scope(
989          self.saver_def.save_tensor_name, self._name)
990      self.saver_def.restore_op_name = ops.prepend_name_scope(
991          self.saver_def.restore_op_name, self._name)
992
993    self._check_saver_def()
994    if not context.executing_eagerly():
995      # Updates next checkpoint time.
996      # Set in __init__ when executing eagerly.
997      self._next_checkpoint_time = (
998          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
999
1000  def _check_saver_def(self):
1001    if not isinstance(self.saver_def, saver_pb2.SaverDef):
1002      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
1003                       self.saver_def)
1004    if not context.executing_eagerly():
1005      if not self.saver_def.save_tensor_name:
1006        raise ValueError("saver_def must specify the save_tensor_name: %s" %
1007                         str(self.saver_def))
1008      if not self.saver_def.restore_op_name:
1009        raise ValueError("saver_def must specify the restore_op_name: %s" %
1010                         str(self.saver_def))
1011
1012  def _CheckpointFilename(self, p):
1013    """Returns the checkpoint filename given a `(filename, time)` pair.
1014
1015    Args:
1016      p: (filename, time) pair.
1017
1018    Returns:
1019      Checkpoint file name.
1020    """
1021    name, _ = p
1022    return name
1023
1024  def _RecordLastCheckpoint(self, latest_save_path):
1025    """Manages the list of the latest checkpoints."""
1026    if not self.saver_def.max_to_keep:
1027      return
1028    # Remove first from list if the same name was used before.
1029    for p in self._last_checkpoints:
1030      if latest_save_path == self._CheckpointFilename(p):
1031        self._last_checkpoints.remove(p)
1032    # Append new path to list
1033    self._last_checkpoints.append((latest_save_path, time.time()))
1034
1035    # If more than max_to_keep, remove oldest.
1036    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
1037      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))
1038
1039  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
1040    """Deletes old checkpoints if necessary.
1041
1042    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
1043    over `max_to_keep`.  They are going to be deleted.  If
1044    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
1045    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
1046    kept for every 0.5 hours of training; if `N` is 10, an additional
1047    checkpoint is kept for every 10 hours of training.
1048
1049    Args:
1050      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
1051    """
1052    if self._checkpoints_to_be_deleted:
1053      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
1055      # have reached N hours of training.
1056      should_keep = p[1] > self._next_checkpoint_time
1057      if should_keep:
1058        self._next_checkpoint_time += (
1059            self.saver_def.keep_checkpoint_every_n_hours * 3600)
1060        return
1061
1062      # Otherwise delete the files.
1063      try:
1064        checkpoint_management.remove_checkpoint(
1065            self._CheckpointFilename(p), self.saver_def.version,
1066            meta_graph_suffix)
1067      except Exception as e:  # pylint: disable=broad-except
1068        logging.warning("Ignoring: %s", str(e))
1069
1070  def as_saver_def(self):
1071    """Generates a `SaverDef` representation of this saver.
1072
1073    Returns:
1074      A `SaverDef` proto.
1075    """
1076    return self.saver_def
1077
1078  def to_proto(self, export_scope=None):
1079    """Converts this `Saver` to a `SaverDef` protocol buffer.
1080
1081    Args:
1082      export_scope: Optional `string`. Name scope to remove.
1083
1084    Returns:
1085      A `SaverDef` protocol buffer.
1086    """
1087    if export_scope is None:
1088      return self.saver_def
1089
1090    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
1091            self.saver_def.save_tensor_name.startswith(export_scope) and
1092            self.saver_def.restore_op_name.startswith(export_scope)):
1093      return None
1094
1095    saver_def = saver_pb2.SaverDef()
1096    saver_def.CopyFrom(self.saver_def)
1097    saver_def.filename_tensor_name = ops.strip_name_scope(
1098        saver_def.filename_tensor_name, export_scope)
1099    saver_def.save_tensor_name = ops.strip_name_scope(
1100        saver_def.save_tensor_name, export_scope)
1101    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
1102                                                     export_scope)
1103    return saver_def
1104
1105  @staticmethod
1106  def from_proto(saver_def, import_scope=None):
1107    """Returns a `Saver` object created from `saver_def`.
1108
1109    Args:
1110      saver_def: a `SaverDef` protocol buffer.
1111      import_scope: Optional `string`. Name scope to use.
1112
1113    Returns:
1114      A `Saver` built from saver_def.
1115    """
1116    return Saver(saver_def=saver_def, name=import_scope)
1117
1118  @property
1119  def last_checkpoints(self):
1120    """List of not-yet-deleted checkpoint filenames.
1121
1122    You can pass any of the returned values to `restore()`.
1123
1124    Returns:
1125      A list of checkpoint filenames, sorted from oldest to newest.
1126    """
1127    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
1128
1129  def set_last_checkpoints(self, last_checkpoints):
1130    """DEPRECATED: Use set_last_checkpoints_with_time.
1131
1132    Sets the list of old checkpoint filenames.
1133
1134    Args:
1135      last_checkpoints: A list of checkpoint filenames.
1136
1137    Raises:
1138      AssertionError: If last_checkpoints is not a list.
1139    """
1140    assert isinstance(last_checkpoints, list)
1141    # We use a timestamp of +inf so that this checkpoint will never be
1142    # deleted.  This is both safe and backwards compatible to a previous
1143    # version of the code which used s[1] as the "timestamp".
1144    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]
1145
1146  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
1147    """Sets the list of old checkpoint filenames and timestamps.
1148
1149    Args:
1150      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
1151        timestamps.
1152
1153    Raises:
1154      AssertionError: If last_checkpoints_with_time is not a list.
1155    """
1156    assert isinstance(last_checkpoints_with_time, list)
1157    self._last_checkpoints = last_checkpoints_with_time
1158
1159  def recover_last_checkpoints(self, checkpoint_paths):
1160    """Recovers the internal saver state after a crash.
1161
1162    This method is useful for recovering the "self._last_checkpoints" state.
1163
1164    Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
1165    exist, use their mtime as the checkpoint timestamp.
1166
1167    Args:
1168      checkpoint_paths: a list of checkpoint paths.
1169    """
1170    checkpoints_with_mtimes = []
1171    for checkpoint_path in checkpoint_paths:
1172      try:
1173        mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
1174      except errors.NotFoundError:
1175        # It's fine if some other thread/process is deleting some older
1176        # checkpoint concurrently.
1177        continue
1178      if mtime:
1179        checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
1180    self.set_last_checkpoints_with_time(checkpoints_with_mtimes)
1181
1182  def save(self,
1183           sess,
1184           save_path,
1185           global_step=None,
1186           latest_filename=None,
1187           meta_graph_suffix="meta",
1188           write_meta_graph=True,
1189           write_state=True,
1190           strip_default_attrs=False,
1191           save_debug_info=False):
1192    # pylint: disable=line-too-long
1193    """Saves variables.
1194
1195    This method runs the ops added by the constructor for saving variables.
1196    It requires a session in which the graph was launched.  The variables to
1197    save must also have been initialized.
1198
1199    The method returns the path prefix of the newly created checkpoint files.
1200    This string can be passed directly to a call to `restore()`.
1201
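    For example (a schematic graph-mode sketch; `saver`, `sess` and `step` are
    assumed to already exist):

    ```python
    save_path = saver.save(sess, '/tmp/model.ckpt', global_step=step)
    # Later, possibly in another process:
    saver.restore(sess, save_path)
    ```
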
1202    Args:
1203      sess: A Session to use to save the variables.
1204      save_path: String.  Prefix of filenames created for the checkpoint.
      global_step: If provided, the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
1208      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoints.  That file, kept in the
1210        same directory as the checkpoint files, is automatically managed by the
1211        saver to keep track of recent checkpoints.  Defaults to 'checkpoint'.
1212      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
1213      write_meta_graph: `Boolean` indicating whether or not to write the meta
1214        graph file.
1215      write_state: `Boolean` indicating whether or not to write the
1216        `CheckpointStateProto`.
1217      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1218        removed from the NodeDefs. For a detailed guide, see [Stripping
1219        Default-Valued
1220        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which is in the same directory as save_path and has `_debug` added
        before the file extension. This is only enabled when `write_meta_graph`
        is `True`.
1225
1226    Returns:
1227      A string: path prefix used for the checkpoint files.  If the saver is
1228        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
1229        is the number of shards created.
1230      If the saver is empty, returns None.
1231
1232    Raises:
1233      TypeError: If `sess` is not a `Session`.
1234      ValueError: If `latest_filename` contains path components, or if it
1235        collides with `save_path`.
1236      RuntimeError: If save and restore ops weren't built.
1237    """
1238    # pylint: enable=line-too-long
1239    start_time = time.time()
1240    if not self._is_built and not context.executing_eagerly():
1241      raise RuntimeError(
1242          "`build()` should be called before save if defer_build==True")
1243    if latest_filename is None:
1244      latest_filename = "checkpoint"
1245    if self._write_version != saver_pb2.SaverDef.V2:
1246      logging.warning("*******************************************************")
1247      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
1248      logging.warning("Consider switching to the more efficient V2 format:")
1249      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
1250      logging.warning("now on by default.")
1251      logging.warning("*******************************************************")
1252
1253    if os.path.split(latest_filename)[0]:
1254      raise ValueError("'latest_filename' must not contain path components")
1255
1256    save_path = compat.as_str(save_path)
1257    if global_step is not None:
1258      if not isinstance(global_step, compat.integral_types):
1259        global_step = training_util.global_step(sess, global_step)
1260      checkpoint_file = "%s-%d" % (save_path, global_step)
1261      if self._pad_step_number:
1262        # Zero-pads the step numbers, so that they are sorted when listed.
1263        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
1264    else:
1265      checkpoint_file = save_path
1266      if os.path.basename(save_path) == latest_filename and not self._sharded:
1267        # Guard against collision between data file and checkpoint state file.
1268        raise ValueError(
1269            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
1270            (latest_filename, save_path))
1271
1272    if (not context.executing_eagerly() and
1273        not isinstance(sess, session.SessionInterface)):
1274      raise TypeError("'sess' must be a Session; %s" % sess)
1275
1276    save_path_parent = os.path.dirname(save_path)
1277    if not self._is_empty:
1278      try:
1279        if context.executing_eagerly():
1280          self._build_eager(
1281              checkpoint_file, build_save=True, build_restore=False)
1282          model_checkpoint_path = self.saver_def.save_tensor_name
1283        else:
1284          model_checkpoint_path = sess.run(
1285              self.saver_def.save_tensor_name,
1286              {self.saver_def.filename_tensor_name: checkpoint_file})
1287
1288        model_checkpoint_path = compat.as_str(model_checkpoint_path)
1289        if write_state:
1290          self._RecordLastCheckpoint(model_checkpoint_path)
1291          checkpoint_management.update_checkpoint_state_internal(
1292              save_dir=save_path_parent,
1293              model_checkpoint_path=model_checkpoint_path,
1294              all_model_checkpoint_paths=self.last_checkpoints,
1295              latest_filename=latest_filename,
1296              save_relative_paths=self._save_relative_paths)
1297          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
1298      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
1299        if not gfile.IsDirectory(save_path_parent):
1300          exc = ValueError(
1301              "Parent directory of {} doesn't exist, can't save.".format(
1302                  save_path))
1303        raise exc
1304
1305    end_time = time.time()
1306    metrics.AddCheckpointWriteDuration(
1307        api_label=_SAVER_LABEL,
1308        microseconds=_get_duration_microseconds(start_time, end_time))
1309    global _END_TIME_OF_LAST_WRITE
1310    with _END_TIME_OF_LAST_WRITE_LOCK:
1311      metrics.AddTrainingTimeSaved(
1312          api_label=_SAVER_LABEL,
1313          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
1314                                                  end_time))
1315      _END_TIME_OF_LAST_WRITE = end_time
1316
1317    if write_meta_graph:
1318      meta_graph_filename = checkpoint_management.meta_graph_filename(
1319          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
1320      if not context.executing_eagerly():
1321        with sess.graph.as_default():
1322          self.export_meta_graph(
1323              meta_graph_filename,
1324              strip_default_attrs=strip_default_attrs,
1325              save_debug_info=save_debug_info)
1326
1327    if self._is_empty:
1328      return None
1329    else:
1330      metrics.RecordCheckpointSize(
1331          api_label=_SAVER_LABEL,
1332          filesize=_get_checkpoint_size(model_checkpoint_path))
1333      return model_checkpoint_path
1334
1335  def export_meta_graph(self,
1336                        filename=None,
1337                        collection_list=None,
1338                        as_text=False,
1339                        export_scope=None,
1340                        clear_devices=False,
1341                        clear_extraneous_savers=False,
1342                        strip_default_attrs=False,
1343                        save_debug_info=False):
1344    # pylint: disable=line-too-long
1345    """Writes `MetaGraphDef` to save_path/filename.
1346
1347    Args:
1348      filename: Optional meta_graph filename including the path.
1349      collection_list: List of string keys to collect.
1350      as_text: If `True`, writes the meta_graph as an ASCII proto.
1351      export_scope: Optional `string`. Name scope to remove.
1352      clear_devices: Whether or not to clear the device field for an `Operation`
1353        or `Tensor` during export.
1354      clear_extraneous_savers: Remove any Saver-related information from the
1355        graph (both Save/Restore ops and SaverDefs) that is not associated with
1356        this Saver.
1357      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1358        removed from the NodeDefs. For a detailed guide, see [Stripping
1359        Default-Valued
1360        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1361      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1362        which is in the same directory as filename and with `_debug` added
1363        before the file extension.
1364
1365    Returns:
1366      A `MetaGraphDef` proto.
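
    A minimal usage sketch (the variable and filename are illustrative):

    ```Python
    v = tf.compat.v1.get_variable("v", shape=[1])
    saver = tf.compat.v1.train.Saver()
    # Writes the MetaGraphDef to '/tmp/my-model.meta' and also returns it.
    meta_graph_def = saver.export_meta_graph('/tmp/my-model.meta')
    ```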
1367    """
1368    # pylint: enable=line-too-long
1369    return export_meta_graph(
1370        filename=filename,
1371        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
1372        saver_def=self.saver_def,
1373        collection_list=collection_list,
1374        as_text=as_text,
1375        export_scope=export_scope,
1376        clear_devices=clear_devices,
1377        clear_extraneous_savers=clear_extraneous_savers,
1378        strip_default_attrs=strip_default_attrs,
1379        save_debug_info=save_debug_info)
1380
1381  def restore(self, sess, save_path):
1382    """Restores previously saved variables.
1383
1384    This method runs the ops added by the constructor for restoring variables.
1385    It requires a session in which the graph was launched.  The variables to
1386    restore do not have to have been initialized, as restoring is itself a way
1387    to initialize variables.
1388
1389    The `save_path` argument is typically a value previously returned from a
1390    `save()` call, or a call to `latest_checkpoint()`.
1391
1392    Args:
1393      sess: A `Session` to use to restore the parameters. None in eager mode.
1394      save_path: Path where parameters were previously saved.
1395
1396    Raises:
1397      ValueError: If save_path is None or not a valid checkpoint.
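
    A minimal usage sketch, assuming the graph already contains the variables
    being restored and '/tmp/ckpt' holds a checkpoint written by `save()`:

    ```Python
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
      # No explicit initialization is needed; restore() assigns saved values.
      saver.restore(sess, tf.train.latest_checkpoint('/tmp/ckpt'))
    ```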
1398    """
1399    start_time = time.time()
1400    if self._is_empty:
1401      return
1402    if save_path is None:
1403      raise ValueError("Can't load save_path when it is None.")
1404
1405    checkpoint_prefix = compat.as_text(save_path)
1406    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
1407      raise ValueError("The passed save_path is not a valid checkpoint: " +
1408                       checkpoint_prefix)
1409
1410    logging.info("Restoring parameters from %s", checkpoint_prefix)
1411    try:
1412      if context.executing_eagerly():
1413        self._build_eager(save_path, build_save=False, build_restore=True)
1414      else:
1415        sess.run(self.saver_def.restore_op_name,
1416                 {self.saver_def.filename_tensor_name: save_path})
1417    except errors.NotFoundError as err:
1418      # There are three common conditions that might cause this error:
1419      # 0. The file is missing. Ignored here since it is checked above.
1420      # 1. This is an object-based checkpoint trying name-based loading.
1421      # 2. The graph has been altered and a variable or other name is missing.
1422
1423      # 1. The checkpoint would not be loaded successfully as is. Try to parse
1424      # it as an object-based checkpoint.
1425      try:
1426        names_to_keys = object_graph_key_mapping(save_path)
1427      except errors.NotFoundError:
1428        # 2. This is not an object-based checkpoint, which likely means there
1429        # is a graph mismatch. Re-raise the original error with
1430        # a helpful message (b/110263146)
1431        raise _wrap_restore_error_with_msg(
1432            err, "a Variable name or other graph key that is missing")
1433
1434      # This is an object-based checkpoint. We'll print a warning and then do
1435      # the restore.
1436      logging.warning(
1437          "Restoring an object-based checkpoint using a name-based saver. This "
1438          "may be somewhat fragile, and will re-build the Saver. Instead, "
1439          "consider loading object-based checkpoints using "
1440          "tf.train.Checkpoint().")
1441      self._object_restore_saver = saver_from_object_based_checkpoint(
1442          checkpoint_path=save_path,
1443          var_list=self._var_list,
1444          builder=self._builder,
1445          names_to_keys=names_to_keys,
1446          cached_saver=self._object_restore_saver)
1447      self._object_restore_saver.restore(sess=sess, save_path=save_path)
1448    except errors.InvalidArgumentError as err:
1449      # There is a mismatch between the graph and the checkpoint being loaded.
1450      # We add a more reasonable error message here to help users (b/110263146)
1451      raise _wrap_restore_error_with_msg(
1452          err, "a mismatch between the current graph and the graph")
1453    metrics.AddCheckpointReadDuration(
1454        api_label=_SAVER_LABEL,
1455        microseconds=_get_duration_microseconds(start_time, time.time()))
1456
1457  @staticmethod
1458  def _add_collection_def(meta_graph_def, key, export_scope=None):
1459    """Adds a collection to MetaGraphDef protocol buffer.
1460
1461    Args:
1462      meta_graph_def: MetaGraphDef protocol buffer.
1463      key: One of the GraphKeys or user-defined string.
1464      export_scope: Optional `string`. Name scope to remove.
1465    """
1466    meta_graph.add_collection_def(
1467        meta_graph_def, key, export_scope=export_scope)
1468
1469
1470@tf_export(v1=["train.import_meta_graph"])
1471def import_meta_graph(meta_graph_or_file,
1472                      clear_devices=False,
1473                      import_scope=None,
1474                      **kwargs):
1475  """Recreates a Graph saved in a `MetaGraphDef` proto.
1476
1477  This function takes a `MetaGraphDef` protocol buffer as input. If
1478  the argument is a file containing a `MetaGraphDef` protocol buffer,
1479  it constructs a protocol buffer from the file content. The function
1480  then adds all the nodes from the `graph_def` field to the
1481  current graph, recreates all the collections, and returns a saver
1482  constructed from the `saver_def` field.
1483
1484  In combination with `export_meta_graph()`, this function can be used to
1485
1486  * Serialize a graph along with other Python objects such as `QueueRunner`,
1487    `Variable` into a `MetaGraphDef`.
1488
1489  * Restart training from a saved graph and checkpoints.
1490
1491  * Run inference from a saved graph and checkpoints.
1492
1493  ```Python
1494  ...
1495  # Create a saver.
1496  saver = tf.compat.v1.train.Saver(...variables...)
1497  # Remember the training_op we want to run by adding it to a collection.
1498  tf.compat.v1.add_to_collection('train_op', train_op)
1499  sess = tf.compat.v1.Session()
1500  for step in range(1000000):
1501      sess.run(train_op)
1502      if step % 1000 == 0:
1503          # Saves checkpoint, which by default also exports a meta_graph
1504          # named 'my-model-global_step.meta'.
1505          saver.save(sess, 'my-model', global_step=step)
1506  ```
1507
1508  Later we can continue training from this saved `meta_graph` without building
1509  the model from scratch.
1510
1511  ```Python
1512  with tf.Session() as sess:
1513    new_saver = tf.train.import_meta_graph(
1514        'my-save-dir/my-model-10000.meta')
1515    new_saver.restore(sess, 'my-save-dir/my-model-10000')
1516    # tf.get_collection() returns a list. In this example we only want
1517    # the first one.
1518    train_op = tf.get_collection('train_op')[0]
1519    for step in range(1000000):
1520      sess.run(train_op)
1521  ```
1522
1523  NOTE: Restarting training from saved `meta_graph` only works if the
1524  device assignments have not changed.
1525
1526  Example:
1527  Variables, placeholders, and independent operations can also be stored, as
1528  shown in the following example.
1529
1530  ```Python
1531  # Saving contents and operations.
1532  v1 = tf.placeholder(tf.float32, name="v1")
1533  v2 = tf.placeholder(tf.float32, name="v2")
1534  v3 = tf.math.multiply(v1, v2)
1535  vx = tf.Variable(10.0, name="vx")
1536  v4 = tf.add(v3, vx, name="v4")
1537  saver = tf.train.Saver([vx])
1538  sess = tf.Session()
1539  sess.run(tf.global_variables_initializer())
1540  sess.run(vx.assign(tf.add(vx, vx)))
1541  result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
1542  print(result)
1543  saver.save(sess, "./model_ex1")
1544  ```
1545
1546  Later this model can be restored and contents loaded.
1547
1548  ```Python
1549  # Restoring variables and running operations.
1550  saver = tf.train.import_meta_graph("./model_ex1.meta")
1551  sess = tf.Session()
1552  saver.restore(sess, "./model_ex1")
1553  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
1554  print(result)
1555  ```
1556
1557  Args:
1558    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
1559      the path) containing a `MetaGraphDef`.
1560    clear_devices: Whether or not to clear the device field for an `Operation`
1561      or `Tensor` during import.
1562    import_scope: Optional `string`. Name scope to add. Only used when
1563      initializing from protocol buffer.
1564    **kwargs: Optional keyed arguments.
1565
1566  Returns:
1567    A saver constructed from `saver_def` in `MetaGraphDef` or None.
1568
1569    A None value is returned if no variables exist in the `MetaGraphDef`
1570    (i.e., there are no variables to restore).
1571
1572  Raises:
1573    RuntimeError: If called with eager execution enabled.
1574
1575  @compatibility(eager)
1576  Exporting/importing meta graphs is not supported. No graph exists when eager
1577  execution is enabled.
1578  @end_compatibility
1579  """  # pylint: disable=g-doc-exception
1580  return _import_meta_graph_with_return_elements(meta_graph_or_file,
1581                                                 clear_devices, import_scope,
1582                                                 **kwargs)[0]
1583
1584
1585def _import_meta_graph_with_return_elements(meta_graph_or_file,
1586                                            clear_devices=False,
1587                                            import_scope=None,
1588                                            return_elements=None,
1589                                            **kwargs):
1590  """Import MetaGraph, and return both a saver and returned elements."""
1591  if context.executing_eagerly():
1592    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1593                       "eager execution is enabled. No graph exists when eager "
1594                       "execution is enabled.")
1595  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
1596    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
1597  else:
1598    meta_graph_def = meta_graph_or_file
1599
1600  imported_vars, imported_return_elements = (
1601      meta_graph.import_scoped_meta_graph_with_return_elements(
1602          meta_graph_def,
1603          clear_devices=clear_devices,
1604          import_scope=import_scope,
1605          return_elements=return_elements,
1606          **kwargs))
1607
1608  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1609                                                 imported_vars)
1610  return saver, imported_return_elements
1611
1612
1613def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1614                                           imported_vars):
1615  """Return a saver for restoring variable values to an imported MetaGraph."""
1616  if meta_graph_def.HasField("saver_def"):
1617    # Infer the scope that is prepended by `import_scoped_meta_graph`.
1618    scope = import_scope
1619    var_names = list(imported_vars.keys())
1620    if var_names:
1621      sample_key = var_names[0]
1622      sample_var = imported_vars[sample_key]
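      # `sample_var.name` ends with `sample_key`; slicing the key off the end
      # leaves whatever scope prefix was prepended (e.g. 'import_scope/').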
1623      scope = sample_var.name[:-len(sample_key)]
1624
1625    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
1626  else:
1627    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
1628      # Return the default saver instance for all graph variables.
1629      return Saver()
1630    else:
1631      # If no graph variables exist, then a Saver cannot be constructed.
1632      logging.info("Saver not created because there are no variables in the"
1633                   " graph to restore")
1634      return None
1635
1636
1637@tf_export(v1=["train.export_meta_graph"])
1638def export_meta_graph(filename=None,
1639                      meta_info_def=None,
1640                      graph_def=None,
1641                      saver_def=None,
1642                      collection_list=None,
1643                      as_text=False,
1644                      graph=None,
1645                      export_scope=None,
1646                      clear_devices=False,
1647                      clear_extraneous_savers=False,
1648                      strip_default_attrs=False,
1649                      save_debug_info=False,
1650                      **kwargs):
1651  # pylint: disable=line-too-long
1652  """Returns `MetaGraphDef` proto.
1653
1654  Optionally writes it to filename.
1655
1656  This function exports the graph, saver, and collection objects into
1657  `MetaGraphDef` protocol buffer with the intention of it being imported
1658  at a later time or location to restart training, run inference, or be
1659  a subgraph.
1660
1661  Args:
1662    filename: Optional filename including the path for writing the generated
1663      `MetaGraphDef` protocol buffer.
1664    meta_info_def: `MetaInfoDef` protocol buffer.
1665    graph_def: `GraphDef` protocol buffer.
1666    saver_def: `SaverDef` protocol buffer.
1667    collection_list: List of string keys to collect.
1668    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
1669    graph: The `Graph` to export. If `None`, use the default graph.
1670    export_scope: Optional `string`. Name scope under which to extract the
1671      subgraph. The scope name will be stripped from the node definitions for
1672      easy import later into new name scopes. If `None`, the whole graph is
1673      exported. graph_def and export_scope cannot both be specified.
1674    clear_devices: Whether or not to clear the device field for an `Operation`
1675      or `Tensor` during export.
1676    clear_extraneous_savers: Remove any Saver-related information from the graph
1677      (both Save/Restore ops and SaverDefs) that is not associated with the
1678      provided SaverDef.
1679    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1680      removed from the NodeDefs. For a detailed guide, see [Stripping
1681      Default-Valued
1682      Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1683    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1684      which is in the same directory as filename and with `_debug` added
1685      before the file extension.
1686    **kwargs: Optional keyed arguments.
1687
1688  Returns:
1689    A `MetaGraphDef` proto.
1690
1691  Raises:
1692    ValueError: When the `GraphDef` is larger than 2GB.
1693    RuntimeError: If called with eager execution enabled.
1694
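  A minimal usage sketch, assuming graph (non-eager) mode; the constant and
  filename are illustrative:

  ```Python
  c = tf.constant(1.0, name="c")
  meta_graph_def = tf.compat.v1.train.export_meta_graph(
      filename='/tmp/my_graph.meta')
  ```
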
1695  @compatibility(eager)
1696  Exporting/importing meta graphs is not supported unless both `graph_def` and
1697  `graph` are provided. No graph exists when eager execution is enabled.
1698  @end_compatibility
1699  """
1700  # pylint: enable=line-too-long
1701  if context.executing_eagerly() and not (graph_def is not None and
1702                                          graph is not None):
1703    raise RuntimeError("Exporting/importing meta graphs is not supported when "
1704                       "eager execution is enabled. No graph exists when eager "
1705                       "execution is enabled.")
1706  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
1707      filename=filename,
1708      meta_info_def=meta_info_def,
1709      graph_def=graph_def,
1710      saver_def=saver_def,
1711      collection_list=collection_list,
1712      as_text=as_text,
1713      graph=graph,
1714      export_scope=export_scope,
1715      clear_devices=clear_devices,
1716      clear_extraneous_savers=clear_extraneous_savers,
1717      strip_default_attrs=strip_default_attrs,
1718      save_debug_info=save_debug_info,
1719      **kwargs)
1720  return meta_graph_def
1721
1722
1723def _wrap_restore_error_with_msg(err, extra_verbiage):
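  """Returns a copy of `err` whose message explains the likely cause."""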
1724  err_msg = ("Restoring from checkpoint failed. This is most likely "
1725             "due to {} from the checkpoint. Please ensure that you "
1726             "have not altered the graph expected based on the checkpoint. "
1727             "Original error:\n\n{}").format(extra_verbiage, err.message)
1728  return err.__class__(err.node_def, err.op, err_msg)
1729
1730
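# Register serialization functions so `Saver` objects stored in the SAVERS
# collection round-trip through `MetaGraphDef` collections.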
1731ops.register_proto_function(
1732    ops.GraphKeys.SAVERS,
1733    proto_type=saver_pb2.SaverDef,
1734    to_proto=Saver.to_proto,
1735    from_proto=Saver.from_proto)
1736
1737
1738def object_graph_key_mapping(checkpoint_path):
1739  """Return name to key mappings from the checkpoint.
1740
1741  Args:
1742    checkpoint_path: string, path to object-based checkpoint
1743
1744  Returns:
1745    Dictionary mapping tensor names to checkpoint keys.
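
  A minimal sketch (the prefix is illustrative and assumed to point at a
  checkpoint written by `tf.train.Checkpoint` or `tf.keras.Model.save_weights`;
  the resulting keys depend on the saved object graph):

  ```Python
  names_to_keys = object_graph_key_mapping('/tmp/tf2_ckpt/ckpt-1')
  # e.g. {'v': 'v/.ATTRIBUTES/VARIABLE_VALUE', ...}
  ```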
1746  """
1747  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
1748  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
1749  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
1750  object_graph_proto.ParseFromString(object_graph_string)
1751  names_to_keys = {}
1752  for node in object_graph_proto.nodes:
1753    for attribute in node.attributes:
1754      names_to_keys[attribute.full_name] = attribute.checkpoint_key
1755  return names_to_keys
1756
1757
1758def saver_from_object_based_checkpoint(checkpoint_path,
1759                                       var_list=None,
1760                                       builder=None,
1761                                       names_to_keys=None,
1762                                       cached_saver=None):
1763  """Return a `Saver` which reads from an object-based checkpoint.
1764
1765  This function validates that all variables in the variables list are remapped
1766  in the object-based checkpoint (or `names_to_keys` dict if provided). A
1767  saver will be created with the list of remapped variables.
1768
1769  The `cached_saver` argument allows the user to pass in a previously created
1770  saver, so multiple `saver.restore()` calls don't pollute the graph when graph
1771  building. This assumes that keys are consistent, meaning that the
1772    1) `checkpoint_path` checkpoint, and
1773    2) checkpoint used to create the `cached_saver`
1774  are the same type of object-based checkpoint. If this argument is set, this
1775  function will simply validate that all variables have been remapped by the
1776  checkpoint at `checkpoint_path`.
1777
1778  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
1779  object-based checkpoint.
1780
1781  Args:
1782    checkpoint_path: string, path to object-based checkpoint
1783    var_list: list of `Variables` that appear in the checkpoint. If `None`,
1784      `var_list` will be set to all saveable objects.
1785    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
1786      will be created.
1787    names_to_keys: dict mapping string tensor names to checkpoint keys. If
1788      `None`, this dict will be generated from the checkpoint file.
1789    cached_saver: Cached `Saver` object with remapped variables.
1790
1791  Returns:
1792    `Saver` with remapped variables for reading from an object-based checkpoint.
1793
1794  Raises:
1795    ValueError: If the checkpoint provided is not an object-based checkpoint.
1796    NotFoundError: If one of the variables in `var_list` can not be found in the
1797      checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
1798      missing the variable.
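
  A minimal usage sketch (the prefix is illustrative; `sess` is assumed to be
  an existing `tf.compat.v1.Session` whose graph contains the variables):

  ```Python
  path = '/tmp/tf2_ckpt/ckpt-1'  # written by tf.train.Checkpoint(...).save()
  saver = saver_from_object_based_checkpoint(path)
  saver.restore(sess, path)
  ```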
1799  """
1800  if names_to_keys is None:
1801    try:
1802      names_to_keys = object_graph_key_mapping(checkpoint_path)
1803    except errors.NotFoundError:
1804      raise ValueError("Checkpoint in %s not an object-based checkpoint." %
1805                       checkpoint_path)
1806  if var_list is None:
1807    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
1808  if builder is None:
1809    builder = BulkSaverBuilder()
1810
1811  saveables = saveable_object_util.validate_and_slice_inputs(var_list)
1812  current_names = set()
1813  for saveable in saveables:
1814    for spec in saveable.specs:
1815      current_names.add(spec.name)
1816  previous_names = set(names_to_keys.keys())
1817  missing_names = current_names - previous_names
1818  if missing_names:
1819    extra_names = previous_names - current_names
1820    intersecting_names = previous_names.intersection(current_names)
1821    raise errors.NotFoundError(
1822        None,
1823        None,
1824        message=(
1825            "\n\nExisting variables not in the checkpoint: %s\n\n"
1826            "Variables names when this checkpoint was written which don't "
1827            "exist now: %s\n\n"
1828            "(%d variable name(s) did match)\n\n"
1829            "Could not find some variables in the checkpoint (see names "
1830            "above). Saver was attempting to load an object-based checkpoint "
1831            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
1832            "using variable names. If the checkpoint was written with eager "
1833            "execution enabled, it's possible that variable names have "
1834            "changed (for example missing a '_1' suffix). It's also "
1835            "possible that there are new variables which did not exist "
1836            "when the checkpoint was written. You can construct a "
1837            "Saver(var_list=...) with only the variables which previously "
1838            "existed, and if variable names have changed you may need to "
1839            "make this a dictionary with the old names as keys. If you're "
1840            "using an Estimator, you'll need to return a tf.train.Saver "
1841            "inside a tf.train.Scaffold from your model_fn.") %
1842        (", ".join(sorted(missing_names)), ", ".join(
1843            sorted(extra_names)), len(intersecting_names)))
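  # Remap each spec to its checkpoint key so the name-based restore reads the
  # tensors written by the object-based checkpoint.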
1844  for saveable in saveables:
1845    for spec in saveable.specs:
1846      spec.name = names_to_keys[spec.name]
1847  if cached_saver is None:
1848    return Saver(saveables)
1849  return cached_saver
1850