# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Save and restore variables.

Symbols in this file are deprecated. See replacements in
tensorflow/python/training/trackable and tensorflow/python/training/saving.
"""
import collections
import glob
import os.path
import threading
import time

import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
generate_checkpoint_state_proto = (
    checkpoint_management.generate_checkpoint_state_proto)
latest_checkpoint = checkpoint_management.latest_checkpoint
checkpoint_exists = checkpoint_management.checkpoint_exists
get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint

# Captures the timestamp of the first Saver object instantiation or end of a
# save operation. Can be accessed by multiple Saver instances.
_END_TIME_OF_LAST_WRITE = None
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()

# API label for cell name used in checkpoint metrics.
_SAVER_LABEL = "saver_v1"


def _get_duration_microseconds(start_time_seconds, end_time_seconds):
  """Returns the elapsed time between two timestamps in whole microseconds.

  Used for checkpoint write-duration metrics; clamps to 0 so clock skew
  between the two readings never produces a negative duration.
  """
  if end_time_seconds < start_time_seconds:
    # Avoid returning negative value in case of clock skew.
    return 0
  return round((end_time_seconds - start_time_seconds) * 1000000)


def _get_checkpoint_size(prefix):
  """Calculates filesize of checkpoint based on prefix."""
  size = 0
  # Gather all files beginning with prefix (.index plus sharded data files).
  files = glob.glob("{}*".format(prefix))
  for file in files:
    # Use TensorFlow's C++ FileSystem API.
90 size += metrics.CalculateFileSize(file) 91 return size 92 93 94class BaseSaverBuilder: 95 """Base class for Savers. 96 97 Can be extended to create different Ops. 98 """ 99 100 SaveSpec = saveable_object.SaveSpec 101 SaveableObject = saveable_object.SaveableObject 102 103 # Aliases for code which was moved but still has lots of users. 104 VariableSaveable = saveable_object_util.ReferenceVariableSaveable 105 ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable 106 107 def __init__(self, write_version=saver_pb2.SaverDef.V2): 108 self._write_version = write_version 109 110 def save_op(self, filename_tensor, saveables): 111 """Create an Op to save 'saveables'. 112 113 This is intended to be overridden by subclasses that want to generate 114 different Ops. 115 116 Args: 117 filename_tensor: String Tensor. 118 saveables: A list of BaseSaverBuilder.SaveableObject objects. 119 120 Returns: 121 An Operation that save the variables. 122 123 Raises: 124 RuntimeError: (implementation detail) if "self._write_version" is an 125 unexpected value. 126 """ 127 # pylint: disable=protected-access 128 tensor_names = [] 129 tensors = [] 130 tensor_slices = [] 131 for saveable in saveables: 132 for spec in saveable.specs: 133 tensor_names.append(spec.name) 134 tensors.append(spec.tensor) 135 tensor_slices.append(spec.slice_spec) 136 if self._write_version == saver_pb2.SaverDef.V1: 137 return io_ops._save( 138 filename=filename_tensor, 139 tensor_names=tensor_names, 140 tensors=tensors, 141 tensor_slices=tensor_slices) 142 elif self._write_version == saver_pb2.SaverDef.V2: 143 # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix 144 # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>". 
145 return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices, 146 tensors) 147 else: 148 raise RuntimeError("Unexpected write_version: " + self._write_version) 149 150 def bulk_restore(self, filename_tensor, saveables, preferred_shard, 151 restore_sequentially): 152 """Restore all tensors contained in saveables. 153 154 By default, this issues separate calls to `restore_op` for each saveable. 155 Subclasses may override to load multiple saveables in a single call. 156 157 Args: 158 filename_tensor: String Tensor. 159 saveables: List of BaseSaverBuilder.SaveableObject objects. 160 preferred_shard: Int. Shard to open first when loading a sharded file. 161 restore_sequentially: Unused. Bool. If true, each restore is sequential. 162 163 Returns: 164 A list of Tensors resulting from reading 'saveable' from 165 'filename'. 166 167 """ 168 del restore_sequentially 169 all_tensors = [] 170 for saveable in saveables: 171 if saveable.device: 172 device = saveable_object_util.set_cpu0(saveable.device) 173 else: 174 device = None 175 with ops.device(device): 176 all_tensors.extend( 177 self.restore_op(filename_tensor, saveable, preferred_shard)) 178 return all_tensors 179 180 # pylint: disable=unused-argument 181 def restore_op(self, filename_tensor, saveable, preferred_shard): 182 """Create ops to restore 'saveable'. 183 184 This is intended to be overridden by subclasses that want to generate 185 different Ops. 186 187 Args: 188 filename_tensor: String Tensor. 189 saveable: A BaseSaverBuilder.SaveableObject object. 190 preferred_shard: Int. Shard to open first when loading a sharded file. 191 192 Returns: 193 A list of Tensors resulting from reading 'saveable' from 194 'filename'. 
    """
    # pylint: disable=protected-access
    tensors = []
    for spec in saveable.specs:
      # One restore_v2 call per spec; each call reads a single slice.
      tensors.append(
          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0])

    return tensors

  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer. The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    # Returns filename_tensor, but gated on completion of the save op.
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A
        FILENAME*, but as a prefix of a V2 checkpoint;
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
        "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is a S3 bucket path ".part" is appended to it
    # * Otherwise _temp/part is appended which is normalized relative to the OS
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3), don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      # The suffix is chosen at graph-run time via array_ops.where, since the
      # prefix (and therefore whether it is an S3 path) may be fed at runtime.
      _SHARDED_SUFFIX = array_ops.where(
          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant(os.path.normpath("_temp/part")))
      tmp_checkpoint_prefix = string_ops.string_join(
          [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
322 """ 323 if self._write_version == saver_pb2.SaverDef.V2: 324 return self._AddShardedSaveOpsForV2(filename_tensor, per_device) 325 326 num_shards = len(per_device) 327 sharded_saves = [] 328 num_shards_tensor = constant_op.constant(num_shards, name="num_shards") 329 for shard, (device, saveables) in enumerate(per_device): 330 with ops.device(device): 331 sharded_filename = self.sharded_filename(filename_tensor, shard, 332 num_shards_tensor) 333 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables)) 334 # Return the sharded name for the save path. 335 with ops.control_dependencies([x.op for x in sharded_saves]): 336 return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor) 337 338 def _AddRestoreOps(self, 339 filename_tensor, 340 saveables, 341 restore_sequentially, 342 reshape, 343 preferred_shard=-1, 344 name="restore_all"): 345 """Add operations to restore saveables. 346 347 Args: 348 filename_tensor: Tensor for the path of the file to load. 349 saveables: A list of SaveableObject objects. 350 restore_sequentially: True if we want to restore variables sequentially 351 within a shard. 352 reshape: True if we want to reshape loaded tensors to the shape of the 353 corresponding variable. 354 preferred_shard: Shard to open first when loading a sharded file. 355 name: Name for the returned op. 356 357 Returns: 358 An Operation that restores the variables. 359 """ 360 all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard, 361 restore_sequentially) 362 363 assign_ops = [] 364 idx = 0 365 # Load and optionally reshape on the CPU, as string tensors are not 366 # available on the GPU. 367 # TODO(touts): Re-enable restore on GPU when we can support annotating 368 # string tensors as "HostMemory" inputs. 369 for saveable in saveables: 370 shapes = None 371 if reshape: 372 # Compute the shapes, let the restore op decide if and how to do 373 # the reshape. 
374 shapes = [] 375 for spec in saveable.specs: 376 v = spec.tensor 377 shape = v.get_shape() 378 if not shape.is_fully_defined(): 379 shape = array_ops.shape(v) 380 shapes.append(shape) 381 saveable_tensors = all_tensors[idx:idx + len(saveable.specs)] 382 idx += len(saveable.specs) 383 assign_ops.append(saveable.restore(saveable_tensors, shapes)) 384 385 # Create a Noop that has control dependencies from all the updates. 386 return control_flow_ops.group(*assign_ops, name=name) 387 388 def _AddShardedRestoreOps(self, filename_tensor, per_device, 389 restore_sequentially, reshape): 390 """Add Ops to restore variables from multiple devices. 391 392 Args: 393 filename_tensor: Tensor for the path of the file to load. 394 per_device: A list of (device, SaveableObject) pairs, as returned by 395 _GroupByDevices(). 396 restore_sequentially: True if we want to restore variables sequentially 397 within a shard. 398 reshape: True if we want to reshape loaded tensors to the shape of the 399 corresponding variable. 400 401 Returns: 402 An Operation that restores the variables. 403 """ 404 sharded_restores = [] 405 for shard, (device, saveables) in enumerate(per_device): 406 with ops.device(device): 407 sharded_restores.append( 408 self._AddRestoreOps( 409 filename_tensor, 410 saveables, 411 restore_sequentially, 412 reshape, 413 preferred_shard=shard, 414 name="restore_shard")) 415 return control_flow_ops.group(*sharded_restores, name="restore_all") 416 417 def _GroupByDevices(self, saveables): 418 """Group Variable tensor slices per device. 419 420 TODO(touts): Make sure that all the devices found are on different 421 job/replica/task/cpu|gpu. It would be bad if 2 were on the same device. 422 It can happen if the devices are unspecified. 423 424 Args: 425 saveables: A list of BaseSaverBuilder.SaveableObject objects. 426 427 Returns: 428 A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples. 429 The list is sorted by ascending device_name. 
430 431 Raises: 432 ValueError: If the tensors of a saveable are on different devices. 433 """ 434 per_device = collections.defaultdict(lambda: []) 435 for saveable in saveables: 436 canonical_device = set( 437 pydev.canonical_name(spec.device) for spec in saveable.specs) 438 if len(canonical_device) != 1: 439 raise ValueError("All tensors of a saveable object must be " 440 "on the same device: %s" % saveable.name) 441 per_device[canonical_device.pop()].append(saveable) 442 return sorted(per_device.items(), key=lambda t: t[0]) 443 444 def build(self, 445 names_to_saveables, 446 reshape=False, 447 sharded=False, 448 max_to_keep=5, 449 keep_checkpoint_every_n_hours=10000.0, 450 name=None, 451 restore_sequentially=False, 452 filename="model"): 453 """Builds save/restore graph nodes or runs save/restore in eager mode. 454 455 Args: 456 names_to_saveables: A dictionary mapping name to a Variable or 457 SaveableObject. Each name will be associated with the corresponding 458 variable in the checkpoint. 459 reshape: If True, allow restoring parameters from a checkpoint that where 460 the parameters have a different shape. This is only needed when you try 461 to restore from a Dist-Belief checkpoint, and only some times. 462 sharded: If True, shard the checkpoints, one per device that has Variable 463 nodes. 464 max_to_keep: Maximum number of checkpoints to keep. As new checkpoints 465 are created, old ones are deleted. If None or 0, no checkpoints are 466 deleted from the filesystem but only the last one is kept in the 467 `checkpoint` file. Presently the number is only roughly enforced. For 468 example in case of restarts more than max_to_keep checkpoints may be 469 kept. 470 keep_checkpoint_every_n_hours: How often checkpoints should be kept. 471 Defaults to 10,000 hours. 472 name: String. Optional name to use as a prefix when adding operations. 
473 restore_sequentially: A Bool, which if true, causes restore of different 474 variables to happen sequentially within each device. 475 filename: If known at graph construction time, filename used for variable 476 loading/saving. If None, then the default name "model" will be used. 477 478 Returns: 479 A SaverDef proto. 480 481 Raises: 482 TypeError: If 'names_to_saveables' is not a dictionary mapping string 483 keys to variable Tensors. 484 ValueError: If any of the keys or values in 'names_to_saveables' is not 485 unique. 486 """ 487 return self._build_internal( 488 names_to_saveables=names_to_saveables, 489 reshape=reshape, 490 sharded=sharded, 491 max_to_keep=max_to_keep, 492 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 493 name=name, 494 restore_sequentially=restore_sequentially, 495 filename=filename) 496 497 def _build_internal(self, 498 names_to_saveables, 499 reshape=False, 500 sharded=False, 501 max_to_keep=5, 502 keep_checkpoint_every_n_hours=10000.0, 503 name=None, 504 restore_sequentially=False, 505 filename="model", 506 build_save=True, 507 build_restore=True): 508 """build() with option to only perform save and restore.""" 509 if not context.executing_eagerly() and (not build_save or 510 not build_restore): 511 raise ValueError("save and restore operations need to be built together " 512 " when eager execution is not enabled.") 513 514 saveables = saveable_object_util.validate_and_slice_inputs( 515 names_to_saveables) 516 if max_to_keep is None: 517 max_to_keep = 0 518 519 with ops.name_scope(name, "save", 520 [saveable.op for saveable in saveables]) as name: 521 # Add a placeholder string tensor for the filename. 522 filename_tensor = array_ops.placeholder_with_default( 523 filename or "model", shape=(), name="filename") 524 # Keep the name "Const" for backwards compatibility. 525 filename_tensor = array_ops.placeholder_with_default( 526 filename_tensor, shape=(), name="Const") 527 528 # Add the save ops. 
529 if sharded: 530 per_device = self._GroupByDevices(saveables) 531 if build_save: 532 save_tensor = self._AddShardedSaveOps(filename_tensor, per_device) 533 if build_restore: 534 restore_op = self._AddShardedRestoreOps(filename_tensor, per_device, 535 restore_sequentially, reshape) 536 else: 537 if build_save: 538 save_tensor = self._AddSaveOps(filename_tensor, saveables) 539 if build_restore: 540 restore_op = self._AddRestoreOps(filename_tensor, saveables, 541 restore_sequentially, reshape) 542 543 # In the following use case, it's possible to have restore_ops be called 544 # something else: 545 # - Build inference graph and export a meta_graph. 546 # - Import the inference meta_graph 547 # - Extend the inference graph to a train graph. 548 # - Export a new meta_graph. 549 # Now the second restore_op will be called "restore_all_1". 550 # As such, comment out the assert for now until we know whether supporting 551 # such usage model makes sense. 552 # 553 # assert restore_op.name.endswith("restore_all"), restore_op.name 554 if context.executing_eagerly(): 555 # Store the tensor values to the tensor_names. 556 save_tensor_name = save_tensor.numpy() if build_save else "" 557 return saver_pb2.SaverDef( 558 filename_tensor_name=filename_tensor.numpy(), 559 save_tensor_name=save_tensor_name, 560 restore_op_name="", 561 max_to_keep=max_to_keep, 562 sharded=sharded, 563 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 564 version=self._write_version) 565 else: 566 graph = ops.get_default_graph() 567 # Do some sanity checking on collections containing 568 # PartitionedVariables. If a saved collection has a PartitionedVariable, 569 # the GraphDef needs to include concat ops to get the value (or there'll 570 # be a lookup error on load). 
571 check_collection_list = graph.get_all_collection_keys() 572 for collection_type in check_collection_list: 573 for element in graph.get_collection(collection_type): 574 if isinstance(element, variables.PartitionedVariable): 575 try: 576 graph.get_operation_by_name(element.name) 577 except KeyError: 578 # Create a concat op for this PartitionedVariable. The user may 579 # not need it, but we'll try looking it up on MetaGraph restore 580 # since it's in a collection. 581 element.as_tensor() 582 return saver_pb2.SaverDef( 583 filename_tensor_name=filename_tensor.name, 584 save_tensor_name=save_tensor.name, 585 restore_op_name=restore_op.name, 586 max_to_keep=max_to_keep, 587 sharded=sharded, 588 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, 589 version=self._write_version) 590 591 592class BulkSaverBuilder(BaseSaverBuilder): 593 """SaverBuilder with support for bulk restoring multiple saveables.""" 594 595 def bulk_restore(self, filename_tensor, saveables, preferred_shard, 596 restore_sequentially): 597 598 # Ignored: bulk restore is internally sequential. 599 del restore_sequentially 600 restore_specs = [] 601 for saveable in saveables: 602 for spec in saveable.specs: 603 restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) 604 605 names, slices, dtypes = zip(*restore_specs) 606 # Load all tensors onto CPU 0 for compatibility with existing code. 607 with ops.device("cpu:0"): 608 return io_ops.restore_v2(filename_tensor, names, slices, dtypes) 609 610 611def _get_saver_or_default(): 612 """Returns the saver from SAVERS collection, or creates a default one. 613 614 This method is used by other members of the training module, such as 615 `Scaffold`, or `CheckpointSaverHook`. 616 617 Returns: 618 `Saver`. 619 620 Raises: 621 RuntimeError: If the SAVERS collection already has more than one items. 
622 """ 623 collection_key = ops.GraphKeys.SAVERS 624 savers = ops.get_collection(collection_key) 625 if savers: 626 if len(savers) > 1: 627 raise RuntimeError( 628 "More than one item in collection {}. " 629 "Please indicate which one to use by passing it to the constructor." 630 .format(collection_key)) 631 return savers[0] 632 saver = Saver(sharded=True, allow_empty=True) 633 if saver is not None: 634 ops.add_to_collection(collection_key, saver) 635 return saver 636 637 638@tf_export(v1=["train.Saver"]) 639class Saver: 640 # pylint: disable=line-too-long 641 """Saves and restores variables. 642 643 @compatibility(TF2) 644 `tf.compat.v1.train.Saver` is not supported for saving and restoring 645 checkpoints in TF2. Please switch to `tf.train.Checkpoint` or 646 `tf.keras.Model.save_weights`, which perform a more robust [object-based 647 saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics). 648 649 ### How to Rewrite Checkpoints 650 651 Please rewrite your checkpoints immediately using the object-based checkpoint 652 APIs. 653 654 You can load a name-based checkpoint written by `tf.compat.v1.train.Saver` 655 using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However, 656 you may have to change the names of the variables in your model to match the 657 variable names in the name-based checkpoint, which can be viewed with 658 `tf.train.list_variables(path)`. 659 660 Another option is to create an `assignment_map` that maps the name of the 661 variables in the name-based checkpoint to the variables in your model, eg: 662 ``` 663 { 664 'sequential/dense/bias': model.variables[0], 665 'sequential/dense/kernel': model.variables[1] 666 } 667 ``` 668 and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to 669 restore the name-based checkpoint. 670 671 After restoring, re-encode your checkpoint 672 using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`. 
673 674 See the [Checkpoint compatibility]( 675 https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) 676 section of the migration guide for more details. 677 678 679 ### Checkpoint Management in TF2 680 681 Use `tf.train.CheckpointManager` to manage checkpoints in TF2. 682 `tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours` 683 and `max_to_keep` parameters. 684 685 To recover the latest checkpoint, 686 687 ``` 688 checkpoint = tf.train.Checkpoint(model) 689 manager = tf.train.CheckpointManager(checkpoint) 690 status = checkpoint.restore(manager.latest_checkpoint) 691 ``` 692 693 `tf.train.CheckpointManager` also writes a [`CheckpointState` proto] 694 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto) 695 which contains the timestamp when each checkpoint was created. 696 697 ### Writing `MetaGraphDef`s in TF2 698 699 To replace, `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use 700 `tf.saved_model.save` to write the `MetaGraphDef` (which is contained in 701 `saved_model.pb`). 702 703 @end_compatibility 704 705 See [Variables](https://tensorflow.org/guide/variables) 706 for an overview of variables, saving and restoring. 707 708 The `Saver` class adds ops to save and restore variables to and from 709 *checkpoints*. It also provides convenience methods to run these ops. 710 711 Checkpoints are binary files in a proprietary format which map variable names 712 to tensor values. The best way to examine the contents of a checkpoint is to 713 load it using a `Saver`. 714 715 Savers can automatically number checkpoint filenames with a provided counter. 716 This lets you keep multiple checkpoints at different steps while training a 717 model. For example you can number the checkpoint filenames with the training 718 step number. To avoid filling up disks, savers manage checkpoint files 719 automatically. 
For example, they can keep only the N most recent files, or 720 one checkpoint for every N hours of training. 721 722 You number checkpoint filenames by passing a value to the optional 723 `global_step` argument to `save()`: 724 725 ```python 726 saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0' 727 ... 728 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000' 729 ``` 730 731 Additionally, optional arguments to the `Saver()` constructor let you control 732 the proliferation of checkpoint files on disk: 733 734 * `max_to_keep` indicates the maximum number of recent checkpoint files to 735 keep. As new files are created, older files are deleted. If None or 0, 736 no checkpoints are deleted from the filesystem but only the last one is 737 kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent 738 checkpoint files are kept.) 739 740 * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent 741 `max_to_keep` checkpoint files, you might want to keep one checkpoint file 742 for every N hours of training. This can be useful if you want to later 743 analyze how a model progressed during a long training session. For 744 example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep 745 one checkpoint file for every 2 hours of training. The default value of 746 10,000 hours effectively disables the feature. 747 748 Note that you still have to call the `save()` method to save the model. 749 Passing these arguments to the constructor will not save variables 750 automatically for you. 751 752 A training program that saves regularly looks like: 753 754 ```python 755 ... 756 # Create a saver. 757 saver = tf.compat.v1.train.Saver(...variables...) 758 # Launch the graph and train, saving the model every 1,000 steps. 759 sess = tf.compat.v1.Session() 760 for step in range(1000000): 761 sess.run(..training_op..) 
      if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)
  ```

  In addition to checkpoint files, savers keep a protocol buffer on disk with
  the list of recent checkpoints. This is used to manage numbered checkpoint
  files and by `latest_checkpoint()`, which makes it easy to discover the path
  to the most recent checkpoint. That protocol buffer is stored in a file named
  'checkpoint' next to the checkpoint files.

  If you create several savers, you can specify a different filename for the
  protocol buffer file in the call to `save()`.
  """

  # pylint: enable=line-too-long

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.compat.v1.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    Note: the newer `AutoTrackable` API is not supported by `Saver`. In this
    case, the `tf.train.Checkpoint` class should be used.

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type. This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint where
        the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
        10,000 hours.
      name: String. Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device. This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate a
        `Saver` object for a previously built `Graph` that had a `Saver`. The
        `saver_def` proto should be the one returned by the `as_saver_def()`
        call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no variables
        in the graph. Otherwise, construct the saver anyway and make it a no-op.
      write_version: controls what format to use when saving checkpoints. It
        also affects certain filepath matching logic. The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of memory
        required and latency incurred during restore. Regardless of this flag,
        the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default). This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, filename used for variable
        loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or `dict`
    of variables to save. Otherwise, a `RuntimeError` will be raised.

    Although Saver works in some cases when executing eagerly, it is
    fragile. Please switch to `tf.train.Checkpoint` or
    `tf.keras.Model.save_weights`, which perform a more robust object-based
    saving. These APIs will load checkpoints written by `Saver`.
    @end_compatibility
    """
    # Lazily initialize the module-level timestamp used by the
    # training-time-saved metric emitted from save(); guarded by a lock since
    # several Saver instances may be constructed concurrently.
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      if _END_TIME_OF_LAST_WRITE is None:
        _END_TIME_OF_LAST_WRITE = time.time()

    if defer_build and var_list:
      raise ValueError(
          "If `var_list` is provided then build cannot be deferred. "
          "Either set defer_build=False or var_list=None.")
    if context.executing_eagerly():
      logging.warning(
          "Saver is deprecated, please switch to tf.train.Checkpoint or "
          "tf.keras.Model.save_weights for training checkpoints. When "
          "executing eagerly variables do not necessarily have unique names, "
          "and so the variable.name-based lookups Saver performs are "
          "error-prone.")
      if var_list is None:
        raise RuntimeError(
            "When eager execution is enabled, `var_list` must specify a list "
            "or dict of variables to save")
    self._var_list = var_list
    self._reshape = reshape
    self._sharded = sharded
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._name = name
    self._restore_sequentially = restore_sequentially
    self.saver_def = saver_def
    self._builder = builder
    self._is_built = False
    self._allow_empty = allow_empty
    self._is_empty = None
    self._write_version = write_version
    self._pad_step_number = pad_step_number
    self._filename = filename
    self._last_checkpoints = []
    self._checkpoints_to_be_deleted = []
    if context.executing_eagerly():
      # In eager mode build() is never called, so the next-checkpoint deadline
      # must be initialized here instead of in _build().
      self._next_checkpoint_time = (
          time.time() + self._keep_checkpoint_every_n_hours * 3600)
    elif not defer_build:
      self.build()
    if self.saver_def:
      self._check_saver_def()
      # A provided saver_def overrides the write_version argument.
      self._write_version = self.saver_def.version
    self._save_relative_paths = save_relative_paths
    # For compatibility with object-based checkpoints, we may build a second
    # Saver to read the renamed keys.
    self._object_restore_saver = None

  def build(self):
    """Builds the save/restore ops for the default graph (graph mode only)."""
    if context.executing_eagerly():
      raise RuntimeError("Use save/restore instead of build in eager mode.")
    self._build(self._filename, build_save=True, build_restore=True)

  def _build_eager(self, checkpoint_path, build_save, build_restore):
    """Eager-mode entry point to _build; may be invoked once per save/restore."""
    self._build(
        checkpoint_path, build_save=build_save, build_restore=build_restore)

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def."""
    if not context.executing_eagerly():
      # Graph mode: building is idempotent, only run once.
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.executing_eagerly():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          # No variables and allow_empty=True: make this saver a no-op.
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save,
          build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      # Rewrite the tensor/op names recorded in the provided saver_def so they
      # resolve under the import scope stored in self._name.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    if not context.executing_eagerly():
      # Updates next checkpoint time.
      # Set in __init__ when executing eagerly.
      self._next_checkpoint_time = (
          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)

  def _check_saver_def(self):
    """Validates self.saver_def; tensor-name checks only apply in graph mode."""
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
      raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                       self.saver_def)
    if not context.executing_eagerly():
      if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
      if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))

  def _CheckpointFilename(self, p):
    """Returns the checkpoint filename given a `(filename, time)` pair.

    Args:
      p: (filename, time) pair.

    Returns:
      Checkpoint file name.
    """
    name, _ = p
    return name

  def _RecordLastCheckpoint(self, latest_save_path):
    """Manages the list of the latest checkpoints."""
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list
    self._last_checkpoints.append((latest_save_path, time.time()))

    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      # Overflow checkpoints are queued for deletion; the actual file removal
      # happens later in _MaybeDeleteOldCheckpoints().
      self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))

  def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    over `max_to_keep`. They are going to be deleted. If
    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    kept for every 0.5 hours of training; if `N` is 10, an additional
    checkpoint is kept for every 10 hours of training.

    Args:
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if self._checkpoints_to_be_deleted:
      p = self._checkpoints_to_be_deleted.pop(0)
      # Do not delete the file if we keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return

      # Otherwise delete the files.
      try:
        checkpoint_management.remove_checkpoint(
            self._CheckpointFilename(p), self.saver_def.version,
            meta_graph_suffix)
      except Exception as e:  # pylint: disable=broad-except
        # Deletion is best-effort: another process may have already removed
        # the files, so failures are logged and ignored.
        logging.warning("Ignoring: %s", str(e))

  def as_saver_def(self):
    """Generates a `SaverDef` representation of this saver.

    Returns:
      A `SaverDef` proto.
    """
    return self.saver_def

  def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaverDef` protocol buffer, or `None` if this saver's tensors/ops do
      not all live under `export_scope`.
    """
    if export_scope is None:
      return self.saver_def

    if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
            self.saver_def.save_tensor_name.startswith(export_scope) and
            self.saver_def.restore_op_name.startswith(export_scope)):
      return None

    # Work on a copy so this saver's own saver_def is left untouched.
    saver_def = saver_pb2.SaverDef()
    saver_def.CopyFrom(self.saver_def)
    saver_def.filename_tensor_name = ops.strip_name_scope(
        saver_def.filename_tensor_name, export_scope)
    saver_def.save_tensor_name = ops.strip_name_scope(
        saver_def.save_tensor_name, export_scope)
    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
                                                     export_scope)
    return saver_def

  @staticmethod
  def from_proto(saver_def, import_scope=None):
    """Returns a `Saver` object created from `saver_def`.

    Args:
      saver_def: a `SaverDef` protocol buffer.
      import_scope: Optional `string`. Name scope to use.

    Returns:
      A `Saver` built from saver_def.
    """
    return Saver(saver_def=saver_def, name=import_scope)

  @property
  def last_checkpoints(self):
    """List of not-yet-deleted checkpoint filenames.

    You can pass any of the returned values to `restore()`.

    Returns:
      A list of checkpoint filenames, sorted from oldest to newest.
    """
    return list(self._CheckpointFilename(p) for p in self._last_checkpoints)

  def set_last_checkpoints(self, last_checkpoints):
    """DEPRECATED: Use set_last_checkpoints_with_time.

    Sets the list of old checkpoint filenames.

    Args:
      last_checkpoints: A list of checkpoint filenames.

    Raises:
      AssertionError: If last_checkpoints is not a list.
    """
    assert isinstance(last_checkpoints, list)
    # We use a timestamp of +inf so that this checkpoint will never be
    # deleted. This is both safe and backwards compatible to a previous
    # version of the code which used s[1] as the "timestamp".
    self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]

  def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
    """Sets the list of old checkpoint filenames and timestamps.

    Args:
      last_checkpoints_with_time: A list of tuples of checkpoint filenames and
        timestamps.

    Raises:
      AssertionError: If last_checkpoints_with_time is not a list.
    """
    assert isinstance(last_checkpoints_with_time, list)
    self._last_checkpoints = last_checkpoints_with_time

  def recover_last_checkpoints(self, checkpoint_paths):
    """Recovers the internal saver state after a crash.

    This method is useful for recovering the "self._last_checkpoints" state.

    Globs for the checkpoints pointed to by `checkpoint_paths`. If the files
    exist, use their mtime as the checkpoint timestamp.

    Args:
      checkpoint_paths: a list of checkpoint paths.
    """
    checkpoints_with_mtimes = []
    for checkpoint_path in checkpoint_paths:
      try:
        mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
      except errors.NotFoundError:
        # It's fine if some other thread/process is deleting some older
        # checkpoint concurrently.
        continue
      if mtime:
        checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
    self.set_last_checkpoints_with_time(checkpoints_with_mtimes)

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched. The variables to
    save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String. Prefix of filenames created for the checkpoint.
      global_step: If provided the global step number is appended to `save_path`
        to create the checkpoint filenames. The optional argument can be a
        `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contains the list of most recent checkpoints. That file, kept in the
        same directory as the checkpoint files, is automatically managed by the
        saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see [Stripping
        Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which in the same directory of save_path and with `_debug` added before
        the file extension. This is only enabled when `write_meta_graph` is
        `True`

    Returns:
      A string: path prefix used for the checkpoint files. If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    start_time = time.time()
    if not self._is_built and not context.executing_eagerly():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("now on by default.")
      logging.warning("*******************************************************")

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    save_path = compat.as_str(save_path)
    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        # global_step may be a Tensor or Tensor name; resolve it to an int.
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(save_path) == latest_filename and not self._sharded:
        # Guard against collision between data file and checkpoint state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        # A missing parent directory is the most common cause; surface a
        # clearer error in that case.
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_SAVER_LABEL,
        microseconds=_get_duration_microseconds(start_time, end_time))
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      metrics.AddTrainingTimeSaved(
          api_label=_SAVER_LABEL,
          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
                                                  end_time))
      _END_TIME_OF_LAST_WRITE = end_time

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename,
              strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      metrics.RecordCheckpointSize(
          api_label=_SAVER_LABEL,
          filesize=_get_checkpoint_size(model_checkpoint_path))
      return model_checkpoint_path

  def export_meta_graph(self,
                        filename=None,
                        collection_list=None,
                        as_text=False,
                        export_scope=None,
                        clear_devices=False,
                        clear_extraneous_savers=False,
                        strip_default_attrs=False,
                        save_debug_info=False):
    # pylint: disable=line-too-long
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an `Operation`
        or `Tensor` during export.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated with
        this Saver.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see [Stripping
        Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which in the same directory of filename and with `_debug` added before
        the file extension.

    Returns:
      A `MetaGraphDef` proto.
    """
    # pylint: enable=line-too-long
    # Delegates to the module-level export_meta_graph with this saver's
    # saver_def and the current default graph.
    return export_meta_graph(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
        clear_extraneous_savers=clear_extraneous_savers,
        strip_default_attrs=strip_default_attrs,
        save_debug_info=save_debug_info)

  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched. The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters. None in eager mode.
      save_path: Path where parameters were previously saved.

    Raises:
      ValueError: If save_path is None or not a valid checkpoint.
    """
    start_time = time.time()
    if self._is_empty:
      return
    if save_path is None:
      raise ValueError("Can't load save_path when it is None.")

    checkpoint_prefix = compat.as_text(save_path)
    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
      raise ValueError("The passed save_path is not a valid checkpoint: " +
                       checkpoint_prefix)

    logging.info("Restoring parameters from %s", checkpoint_prefix)
    try:
      if context.executing_eagerly():
        self._build_eager(save_path, build_save=False, build_restore=True)
      else:
        sess.run(self.saver_def.restore_op_name,
                 {self.saver_def.filename_tensor_name: save_path})
    except errors.NotFoundError as err:
      # There are three common conditions that might cause this error:
      # 0. The file is missing. We ignore here, as this is checked above.
      # 1. This is an object-based checkpoint trying name-based loading.
      # 2. The graph has been altered and a variable or other name is missing.

      # 1. The checkpoint would not be loaded successfully as is. Try to parse
      # it as an object-based checkpoint.
      try:
        names_to_keys = object_graph_key_mapping(save_path)
      except errors.NotFoundError:
        # 2. This is not an object-based checkpoint, which likely means there
        # is a graph mismatch. Re-raise the original error with
        # a helpful message (b/110263146)
        raise _wrap_restore_error_with_msg(
            err, "a Variable name or other graph key that is missing")

      # This is an object-based checkpoint. We'll print a warning and then do
      # the restore.
      logging.warning(
          "Restoring an object-based checkpoint using a name-based saver. This "
          "may be somewhat fragile, and will re-build the Saver. Instead, "
          "consider loading object-based checkpoints using "
          "tf.train.Checkpoint().")
      # Build (or reuse) a secondary Saver keyed on the checkpoint's renamed
      # keys, then retry the restore through it.
      self._object_restore_saver = saver_from_object_based_checkpoint(
          checkpoint_path=save_path,
          var_list=self._var_list,
          builder=self._builder,
          names_to_keys=names_to_keys,
          cached_saver=self._object_restore_saver)
      self._object_restore_saver.restore(sess=sess, save_path=save_path)
    except errors.InvalidArgumentError as err:
      # There is a mismatch between the graph and the checkpoint being loaded.
      # We add a more reasonable error message here to help users (b/110263146)
      raise _wrap_restore_error_with_msg(
          err, "a mismatch between the current graph and the graph")
    metrics.AddCheckpointReadDuration(
        api_label=_SAVER_LABEL,
        microseconds=_get_duration_microseconds(start_time, time.time()))

  @staticmethod
  def _add_collection_def(meta_graph_def, key, export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
    meta_graph.add_collection_def(
        meta_graph_def, key, export_scope=export_scope)


@tf_export(v1=["train.import_meta_graph"])
def import_meta_graph(meta_graph_or_file,
                      clear_devices=False,
                      import_scope=None,
                      **kwargs):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a saver
  constructed from the `saver_def` field.

  In combination with `export_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  ```Python
  ...
  # Create a saver.
  saver = tf.compat.v1.train.Saver(...variables...)
  # Remember the training_op we want to run by adding it to a collection.
  tf.compat.v1.add_to_collection('train_op', train_op)
  sess = tf.compat.v1.Session()
  for step in range(1000000):
      sess.run(train_op)
      if step % 1000 == 0:
          # Saves checkpoint, which by default also exports a meta_graph
          # named 'my-model-global_step.meta'.
          saver.save(sess, 'my-model', global_step=step)
  ```

  Later we can continue training from this saved `meta_graph` without building
  the model from scratch.

  ```Python
  with tf.Session() as sess:
    new_saver =
        tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
    # tf.get_collection() returns a list. In this example we only want
    # the first one.
    train_op = tf.get_collection('train_op')[0]
    for step in range(1000000):
      sess.run(train_op)
  ```

  NOTE: Restarting training from saved `meta_graph` only works if the
  device assignments have not changed.

  Example:
  Variables, placeholders, and independent operations can also be stored, as
  shown in the following example.

  ```Python
  # Saving contents and operations.
  v1 = tf.placeholder(tf.float32, name="v1")
  v2 = tf.placeholder(tf.float32, name="v2")
  v3 = tf.math.multiply(v1, v2)
  vx = tf.Variable(10.0, name="vx")
  v4 = tf.add(v3, vx, name="v4")
  saver = tf.train.Saver([vx])
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  sess.run(vx.assign(tf.add(vx, vx)))
  result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
  print(result)
  saver.save(sess, "./model_ex1")
  ```

  Later this model can be restored and contents loaded.

  ```Python
  # Restoring variables and running operations.
  saver = tf.train.import_meta_graph("./model_ex1.meta")
  sess = tf.Session()
  saver.restore(sess, "./model_ex1")
  result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
  print(result)
  ```

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during import.
    import_scope: Optional `string`. Name scope to add. Only used when
      initializing from protocol buffer.
    **kwargs: Optional keyed arguments.

  Returns:
    A saver constructed from `saver_def` in `MetaGraphDef` or None.

    A None value is returned if no variables exist in the `MetaGraphDef`
    (i.e., there are no variables to restore).

  Raises:
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported. No graph exists when eager
  execution is enabled.
  @end_compatibility
  """  # pylint: disable=g-doc-exception
  return _import_meta_graph_with_return_elements(meta_graph_or_file,
                                                 clear_devices, import_scope,
                                                 **kwargs)[0]


def _import_meta_graph_with_return_elements(meta_graph_or_file,
                                            clear_devices=False,
                                            import_scope=None,
                                            return_elements=None,
                                            **kwargs):
  """Import MetaGraph, and return both a saver and returned elements."""
  if context.executing_eagerly():
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    # Treat the argument as a path and parse the MetaGraphDef from the file.
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
  else:
    meta_graph_def = meta_graph_or_file

  imported_vars, imported_return_elements = (
      meta_graph.import_scoped_meta_graph_with_return_elements(
          meta_graph_def,
          clear_devices=clear_devices,
          import_scope=import_scope,
          return_elements=return_elements,
          **kwargs))

  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                                 imported_vars)
  return saver, imported_return_elements


def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                           imported_vars):
  """Return a saver for restoring variable values to an imported MetaGraph."""
  if meta_graph_def.HasField("saver_def"):
    # Infer the scope that is prepended by `import_scoped_meta_graph`.
    scope = import_scope
    var_names = list(imported_vars.keys())
    if var_names:
      # The imported variable's full name ends with its (unscoped) key, so the
      # prefix left after removing the key is the effective import scope.
      sample_key = var_names[0]
      sample_var = imported_vars[sample_key]
      scope = sample_var.name[:-len(sample_key)]

    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
  else:
    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
      # Return the default saver instance for all graph variables.
      return Saver()
    else:
      # If no graph variables exist, then a Saver cannot be constructed.
      logging.info("Saver not created because there are no variables in the"
                   " graph to restore")
      return None


@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      save_debug_info=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto.

  Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the generated
      `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name will be stripped from the node definitions for
      easy import later into new name scopes. If `None`, the whole graph is
      exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the graph
      (both Save/Restore ops and SaverDefs) that are not associated with the
      provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see [Stripping
      Default-Valued
      Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which in the same directory of filename and with `_debug` added before the
      file extension.
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported unless both `graph_def` and
  `graph` are provided. No graph exists when eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      save_debug_info=save_debug_info,
      **kwargs)
  return meta_graph_def


def _wrap_restore_error_with_msg(err, extra_verbiage):
  """Rebuilds `err` (same exception class) with a more helpful message."""
  err_msg = ("Restoring from checkpoint failed. This is most likely "
             "due to {} from the checkpoint. Please ensure that you "
             "have not altered the graph expected based on the checkpoint. "
             "Original error:\n\n{}").format(extra_verbiage, err.message)
  return err.__class__(err.node_def, err.op, err_msg)


# Register SaverDef serialization for the SAVERS graph collection so Savers
# round-trip through MetaGraphDef export/import.
ops.register_proto_function(
    ops.GraphKeys.SAVERS,
    proto_type=saver_pb2.SaverDef,
    to_proto=Saver.to_proto,
    from_proto=Saver.from_proto)


def object_graph_key_mapping(checkpoint_path):
  """Return name to key mappings from the checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint

  Returns:
    Dictionary mapping tensor names to checkpoint keys.
1746 """ 1747 reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path) 1748 object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY) 1749 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 1750 object_graph_proto.ParseFromString(object_graph_string) 1751 names_to_keys = {} 1752 for node in object_graph_proto.nodes: 1753 for attribute in node.attributes: 1754 names_to_keys[attribute.full_name] = attribute.checkpoint_key 1755 return names_to_keys 1756 1757 1758def saver_from_object_based_checkpoint(checkpoint_path, 1759 var_list=None, 1760 builder=None, 1761 names_to_keys=None, 1762 cached_saver=None): 1763 """Return a `Saver` which reads from an object-based checkpoint. 1764 1765 This function validates that all variables in the variables list are remapped 1766 in the object-based checkpoint (or `names_to_keys` dict if provided). A 1767 saver will be created with the list of remapped variables. 1768 1769 The `cached_saver` argument allows the user to pass in a previously created 1770 saver, so multiple `saver.restore()` calls don't pollute the graph when graph 1771 building. This assumes that keys are consistent, meaning that the 1772 1) `checkpoint_path` checkpoint, and 1773 2) checkpoint used to create the `cached_saver` 1774 are the same type of object-based checkpoint. If this argument is set, this 1775 function will simply validate that all variables have been remapped by the 1776 checkpoint at `checkpoint_path`. 1777 1778 Note that in general, `tf.train.Checkpoint` should be used to restore/save an 1779 object-based checkpoint. 1780 1781 Args: 1782 checkpoint_path: string, path to object-based checkpoint 1783 var_list: list of `Variables` that appear in the checkpoint. If `None`, 1784 `var_list` will be set to all saveable objects. 1785 builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder` 1786 will be created. 1787 names_to_keys: dict mapping string tensor names to checkpoint keys. 
If 1788 `None`, this dict will be generated from the checkpoint file. 1789 cached_saver: Cached `Saver` object with remapped variables. 1790 1791 Returns: 1792 `Saver` with remapped variables for reading from an object-based checkpoint. 1793 1794 Raises: 1795 ValueError if the checkpoint provided is not an object-based checkpoint. 1796 NotFoundError: If one of the variables in `var_list` can not be found in the 1797 checkpoint. This could mean the checkpoint or `names_to_keys` mapping is 1798 missing the variable. 1799 """ 1800 if names_to_keys is None: 1801 try: 1802 names_to_keys = object_graph_key_mapping(checkpoint_path) 1803 except errors.NotFoundError: 1804 raise ValueError("Checkpoint in %s not an object-based checkpoint." % 1805 checkpoint_path) 1806 if var_list is None: 1807 var_list = variables._all_saveable_objects() # pylint: disable=protected-access 1808 if builder is None: 1809 builder = BulkSaverBuilder() 1810 1811 saveables = saveable_object_util.validate_and_slice_inputs(var_list) 1812 current_names = set() 1813 for saveable in saveables: 1814 for spec in saveable.specs: 1815 current_names.add(spec.name) 1816 previous_names = set(names_to_keys.keys()) 1817 missing_names = current_names - previous_names 1818 if missing_names: 1819 extra_names = previous_names - current_names 1820 intersecting_names = previous_names.intersection(current_names) 1821 raise errors.NotFoundError( 1822 None, 1823 None, 1824 message=( 1825 "\n\nExisting variables not in the checkpoint: %s\n\n" 1826 "Variables names when this checkpoint was written which don't " 1827 "exist now: %s\n\n" 1828 "(%d variable name(s) did match)\n\n" 1829 "Could not find some variables in the checkpoint (see names " 1830 "above). Saver was attempting to load an object-based checkpoint " 1831 "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) " 1832 "using variable names. 
If the checkpoint was written with eager " 1833 "execution enabled, it's possible that variable names have " 1834 "changed (for example missing a '_1' suffix). It's also " 1835 "possible that there are new variables which did not exist " 1836 "when the checkpoint was written. You can construct a " 1837 "Saver(var_list=...) with only the variables which previously " 1838 "existed, and if variable names have changed you may need to " 1839 "make this a dictionary with the old names as keys. If you're " 1840 "using an Estimator, you'll need to return a tf.train.Saver " 1841 "inside a tf.train.Scaffold from your model_fn.") % 1842 (", ".join(sorted(missing_names)), ", ".join( 1843 sorted(extra_names)), len(intersecting_names))) 1844 for saveable in saveables: 1845 for spec in saveable.specs: 1846 spec.name = names_to_keys[spec.name] 1847 if cached_saver is None: 1848 return Saver(saveables) 1849 return cached_saver 1850