# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restore variables inside traced @tf.functions."""

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest


class _SingleDeviceSaver(object):
  """Saves and restores checkpoints from the current device."""

  __slots__ = ["_tensor_slice_dict"]

  def __init__(self, tensor_slice_dict):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      tensor_slice_dict: A dict mapping checkpoint key -> slice_spec -> tensor.
    """
    self._tensor_slice_dict = tensor_slice_dict

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    slice_specs = []
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          tensor_value = tensor.tensor
          # A tensor value of `None` indicates that this SaveableObject gets
          # recorded in the object graph, but that no value is saved in the
          # checkpoint.
          if tensor_value is not None:
            tensor_names.append(tensor.name)
            tensors.append(tensor_value)
            slice_specs.append(tensor.slice_spec)
        else:
          tensor_names.append(checkpoint_key)
          tensors.append(tensor)
          slice_specs.append(slice_spec)
    save_device = options.experimental_io_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensor_dtypes = []
    slice_specs = []

    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        tensor_dtypes.append(tensor.dtype)
        if isinstance(tensor, saveable_object.SaveSpec):
          slice_specs.append(tensor.slice_spec)
          tensor_names.append(tensor.name)
        else:
          slice_specs.append(slice_spec)
          tensor_names.append(checkpoint_key)

    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, slice_specs, tensor_dtypes)

    # restore_v2 returns tensors in the same order they were requested above,
    # so pop from the front while re-walking the same dict iteration order.
    restored_tensor_dict = {}
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec in tensor_slices:
        restored_tensor = restored_tensors.pop(0)
        restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
            restored_tensor)
    return restored_tensor_dict


def sharded_filename(filename_tensor, shard, num_shards):
  """Append sharding information to a filename.

  Args:
    filename_tensor: A string tensor.
    shard: Integer. The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
  """
  return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)


def registered_saver_filename(filename_tensor, saver_name):
  # Each registered saver writes to its own file, suffixed with the saver name.
  return string_ops.string_join(
      [filename_tensor, constant_op.constant(f"-{saver_name}")])


def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
  """Converts the function to a python or tf.function with a single file arg."""

  def save_fn(file_prefix):
    return fn(trackables=trackables, file_prefix=file_prefix)
  if call_with_mapped_captures is None:
    return save_fn
  else:
    # Trace the save function so its captures can be remapped by the caller.
    tf_fn = def_function.function(save_fn, autograph=False)
    concrete = tf_fn.get_concrete_function(
        file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

    def save_fn_with_replaced_captures(file_prefix):
      return call_with_mapped_captures(concrete, [file_prefix])

    return save_fn_with_replaced_captures


def _get_mapped_registered_restore_fn(fn, trackables,
                                      call_with_mapped_captures):
  """Converts the function to a python or tf.function with a single file arg."""

  def restore_fn(merged_prefix):
    return fn(trackables=trackables, merged_prefix=merged_prefix)
  if call_with_mapped_captures is None:
    return restore_fn
  else:
    # Trace the restore function so its captures can be remapped by the caller.
    tf_fn = def_function.function(restore_fn, autograph=False)
    concrete = tf_fn.get_concrete_function(
        merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

    def restore_fn_with_replaced_captures(merged_prefix):
      return call_with_mapped_captures(concrete, [merged_prefix])

    return restore_fn_with_replaced_captures


class MultiDeviceSaver(object):
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(self,
               saveable_objects,
               registered_savers=None,
               call_with_mapped_captures=None):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: Optional callable taking a traced
        `ConcreteFunction` and a list of arguments; it invokes the function
        with its captured tensors replaced. When provided, the registered
        save/restore functions are traced and routed through it (presumably
        so the correct captures are used when saving inside another
        FuncGraph — confirm with callers).
    """
    saveable_objects = list(saveable_objects)

    # Keep these two data structures so that we can map restored tensors to
    # the Trackable restore functions.
    self._keys_to_restore_fn = {}
    self._restore_fn_to_keys = {}

    # Extract serialized tensors and separate by device.
    tensors_by_device = {}  # device -> checkpoint key -> (slice_spec ->) tensor
    for saveable in saveable_objects:
      tensor_dict = saveable_object_util.saveable_object_to_tensor_dict(
          [saveable])
      restore_fn = saveable_object_util.saveable_object_to_restore_fn(
          [saveable])

      # Divide tensor_dict by device.
      for checkpoint_key, maybe_tensor in tensor_dict.items():
        if not isinstance(maybe_tensor, dict):
          # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
          maybe_tensor = {"": maybe_tensor}

        for slice_spec, tensor in maybe_tensor.items():
          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "Checkpoint key-generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          host_device = saveable_object_util.set_cpu0(tensor.device)
          (tensors_by_device
           .setdefault(host_device, {})
           .setdefault(checkpoint_key, {})[slice_spec]) = tensor
    self._single_device_savers = {
        device: _SingleDeviceSaver(tensor_slice_dict)
        for device, tensor_slice_dict in tensors_by_device.items()}

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix):
    # Returns `file_prefix` with a control dependency on the save op so the
    # returned tensor is only produced after the checkpoint is written.
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(self, file_prefix):
    # Returns `file_prefix` with control dependencies on all restore ops.
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3), don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pickup the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when saving on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          # Load values from checkpoint
          restored_tensor_dict = saver.restore(file_prefix, options)

          # Map restored tensors to the corresponding restore_fn, and see if
          # all inputs have all been loaded. Call `restore_fn` if that is the
          # case.
          for checkpoint_key, slice_and_tensor in restored_tensor_dict.items():
            for slice_spec, tensor in slice_and_tensor.items():
              restore_fn = self._keys_to_restore_fn[(checkpoint_key,
                                                     slice_spec)]
              (restore_fn_inputs
               .setdefault(restore_fn, {})
               .setdefault(checkpoint_key, {})[slice_spec]) = tensor
              restore_fn_input_count[restore_fn] -= 1

              if restore_fn_input_count[restore_fn] == 0:
                ret = restore_fn(restore_fn_inputs[restore_fn])
                if isinstance(ret, dict):
                  restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, restore_fn) in self._registered_savers.items():
        restore_fn(file_prefix)
      return restore_ops

    restore_device = options.experimental_io_device or "cpu:0"

    # Since this will cause a function re-trace on each restore, limit this to
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pickup the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and (len(self._single_device_savers) > 1 or
                                        options.experimental_io_device):
      @def_function.function(jit_compile=False)
      def tf_function_restore():
        restore_fn()
        return {}

      with ops.device(restore_device):
        restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops