# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with name-based checkpoints.

While some of these symbols also work with the TF2 object-based checkpoints,
they are not recommended for TF2. Please check `tensorflow/python/checkpoint`
for newer utilities built to work with TF2 checkpoints.
"""

from collections import abc
import os
import time

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "load_checkpoint", "load_variable", "list_variables",
    "checkpoints_iterator", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints,
  a reader for the latest checkpoint is returned.
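
  Example usage (a minimal sketch; the checkpoint directory
  `/tmp/training_ckpts` is only a placeholder):

  ```python
  import tensorflow as tf

  reader = tf.train.load_checkpoint('/tmp/training_ckpts')
  # Inspect what the checkpoint contains.
  shape_map = reader.get_variable_to_shape_map()
  dtype_map = reader.get_variable_to_dtype_map()
  for key in sorted(shape_map):
    print(key, shape_map[key], dtype_map[key])
  # Read the value stored under one of those keys as a numpy array.
  first_key = sorted(shape_map)[0]
  value = reader.get_tensor(first_key)
  ```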

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return py_checkpoint_reader.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.

  Example usage:

  ```python
  import tensorflow as tf
  import os
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(key, shape)`.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or None if the timeout was reached.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info("Found new checkpoint at %s", checkpoint_path)
      return checkpoint_path


@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  returned to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint. It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints. At this point the timeout logic applies again.

  This behavior gives control to callers on what to do if checkpoints do not
  come fast enough or stop being generated. For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped. If they know that the training is still going on
  they return `False` instead.
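
  Example (a minimal sketch; the checkpoint directory and the `stop_requested`
  helper are placeholders for whatever signal your setup uses to detect that
  training has finished):

  ```python
  import tensorflow as tf

  def stop_requested():
    # E.g. a sentinel file written by the training job once it is done.
    return tf.io.gfile.exists('/tmp/training_ckpts/TRAINING_DONE')

  for ckpt_path in tf.train.checkpoints_iterator(
      '/tmp/training_ckpts', min_interval_secs=30, timeout=600,
      timeout_fn=stop_requested):
    print('New checkpoint:', ckpt_path)  # E.g. run an evaluation pass here.
  ```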

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # timed out
        logging.info("Timed-out waiting for a checkpoint.")
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      else:
        # The timeout_fn indicated that more checkpoints may come.
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  @compatibility(TF2)
  `tf.compat.v1.train.init_from_checkpoint` is not recommended for restoring
  variable values in TF2.

  To restore checkpoints in TF2, please use
  `tf.keras.Model.load_weights` or `tf.train.Checkpoint.restore`. These APIs
  use an [object-based method of checkpointing]
  (https://www.tensorflow.org/guide/checkpoint#loading_mechanics), while
  `tf.compat.v1.init_from_checkpoint` relies on a more-fragile variable-name
  based method of checkpointing. There is no object-based equivalent of
  `init_from_checkpoint` in TF2.

  Please re-write your checkpoints immediately using the object-based APIs;
  see the [migration guide]
  (https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) for more
  details.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
  you may have to change the names of the variables in your model to match the
  variable names in the name-based checkpoint, which can be viewed with
  `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the name of the
  variables in the name-based checkpoint to the variables in your model, e.g.:
  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```
  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint using `tf.train.Checkpoint.save`
  or `tf.keras.Model.save_weights`.

  @end_compatibility

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  The assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (e.g. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  The assignment map can be a dict, or a list of pairs. The latter is
  necessary to initialize multiple variables in the current graph from
  the same variable in the checkpoint.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  # -- name='old_scope_1/var1', shape=[20, 2]
  # -- name='old_scope_1/var2', shape=[50, 4]
  # -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                                     initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                                     initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                                     initializer=tf.compat.v1.zeros_initializer(),
                                     partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```
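
  As noted above, the checkpoint values are only read when the variable
  initializers run. Continuing the example, in graph mode that typically looks
  like the following sketch:

  ```python
  with tf.compat.v1.Session() as sess:
    # The restore ops created by `init_from_checkpoint` execute here.
    sess.run(tf.compat.v1.global_variables_initializer())
  ```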

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, or a list of key-value pairs, where keys are names
      of the variables in the checkpoint and values are current variables or
      names of current variables (in default graph).

  Raises:
    ValueError: If variables are missing from the current graph, or if the
      checkpoint or tensors in the checkpoint are missing.
  """
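  # Run the initializer rewrite in the cross-replica context: call it directly
  # if we are already in one, otherwise escape the replica context via
  # `merge_call` first.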
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  if isinstance(assignment_map, abc.Mapping):
    assignment_map = assignment_map.items()

  # We only want to sort by tensor names.
  sort_key = lambda pair: pair[0]

  for tensor_name_in_ckpt, current_var_or_name in sorted(
      assignment_map, key=sort_key):
    var = None
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
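        # E.g. with tensor_name_in_ckpt='old_scope_1/',
        # current_var_or_name='new_scope_1/' (so scopes='new_scope_1') and
        # var_name='new_scope_1/var1', this computes
        # full_tensor_name='old_scope_1/var1'.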
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if isinstance(ckpt_dir_or_file, os.PathLike):
    ckpt_dir_or_file = os.fspath(ckpt_dir_or_file)
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
    init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
    # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: If the objects in `variable_or_list` are not all partitions of
      the same large variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None