# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest


class _SingleDeviceSaver(object):
  """Saves and restores checkpoints from the current device."""

  __slots__ = ["_tensor_slice_dict"]

  def __init__(self, tensor_slice_dict):
    """Specify a dict of tensors to save and restore.

    Args:
      tensor_slice_dict: A dict mapping checkpoint key -> slice_spec -> tensor.
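
    Example of `tensor_slice_dict` (an illustrative sketch; the checkpoint
    keys and tensor values below are hypothetical, and "" is the empty slice
    spec used for unpartitioned tensors):

      {"dense/kernel/.ATTRIBUTES/VARIABLE_VALUE": {"": kernel_tensor},
       "dense/bias/.ATTRIBUTES/VARIABLE_VALUE": {"": bias_tensor}}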
    """
    self._tensor_slice_dict = tensor_slice_dict

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    slice_specs = []
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          tensor_value = tensor.tensor
          # A tensor value of `None` indicates that this SaveableObject gets
          # recorded in the object graph, but that no value is saved in the
          # checkpoint.
          if tensor_value is not None:
            tensor_names.append(tensor.name)
            tensors.append(tensor_value)
            slice_specs.append(tensor.slice_spec)
        else:
          tensor_names.append(checkpoint_key)
          tensors.append(tensor)
          slice_specs.append(slice_spec)
    save_device = options.experimental_io_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensor_dtypes = []
    slice_specs = []

    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        tensor_dtypes.append(tensor.dtype)
        if isinstance(tensor, saveable_object.SaveSpec):
          slice_specs.append(tensor.slice_spec)
          tensor_names.append(tensor.name)
        else:
          slice_specs.append(slice_spec)
          tensor_names.append(checkpoint_key)

    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, slice_specs, tensor_dtypes)

    restored_tensor_dict = {}
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec in tensor_slices:
        restored_tensor = restored_tensors.pop(0)
        restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
            restored_tensor)
    return restored_tensor_dict


def sharded_filename(filename_tensor, shard, num_shards):
  """Append sharding information to a filename.

  Args:
    filename_tensor: A string tensor.
    shard: Integer.  The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
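
  Example (illustrative; the "%s-%05d-of-%05d" layout comes from the
  underlying ShardedFilename op):

    sharded_filename("ckpt_temp/part", 0, 2) evaluates to the string tensor
    "ckpt_temp/part-00000-of-00002".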
  """
  return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)


def registered_saver_filename(filename_tensor, saver_name):
  return string_ops.string_join(
      [filename_tensor, constant_op.constant(f"-{saver_name}")])


def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
  """Converts the function to a Python or tf.function with a single file arg."""

  def save_fn(file_prefix):
    return fn(trackables=trackables, file_prefix=file_prefix)
  if call_with_mapped_captures is None:
    return save_fn
  else:
    tf_fn = def_function.function(save_fn, autograph=False)
    concrete = tf_fn.get_concrete_function(
        file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

    def save_fn_with_replaced_captures(file_prefix):
      return call_with_mapped_captures(concrete, [file_prefix])

    return save_fn_with_replaced_captures


def _get_mapped_registered_restore_fn(fn, trackables,
                                      call_with_mapped_captures):
  """Converts the function to a Python or tf.function with a single file arg."""

  def restore_fn(merged_prefix):
    return fn(trackables=trackables, merged_prefix=merged_prefix)
  if call_with_mapped_captures is None:
    return restore_fn
  else:
    tf_fn = def_function.function(restore_fn, autograph=False)
    concrete = tf_fn.get_concrete_function(
        merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

    def restore_fn_with_replaced_captures(merged_prefix):
      return call_with_mapped_captures(concrete, [merged_prefix])

    return restore_fn_with_replaced_captures


class MultiDeviceSaver(object):
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
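
  Example (an illustrative sketch; `my_saveables` stands in for a list of
  `SaveableObject`s produced elsewhere, e.g. by `saveable_object_util`):

    saver = MultiDeviceSaver(my_saveables)
    save_path = constant_op.constant("/tmp/ckpt/prefix")
    saver.save(save_path)
    saver.restore(save_path)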
  """

  def __init__(self,
               saveable_objects,
               registered_savers=None,
               call_with_mapped_captures=None):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: TODO
    """
    saveable_objects = list(saveable_objects)

    # Keep these two data structures so that we can map restored tensors to
    # the Trackable restore functions.
    self._keys_to_restore_fn = {}
    self._restore_fn_to_keys = {}

    # Extract serialized tensors and separate by device.
    tensors_by_device = {}  # device -> checkpoint key -> (slice_spec ->) tensor
    for saveable in saveable_objects:
      tensor_dict = saveable_object_util.saveable_object_to_tensor_dict(
          [saveable])
      restore_fn = saveable_object_util.saveable_object_to_restore_fn(
          [saveable])

      # Divide tensor_dict by device.
      for checkpoint_key, maybe_tensor in tensor_dict.items():
        if not isinstance(maybe_tensor, dict):
          # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
          maybe_tensor = {"": maybe_tensor}

        for slice_spec, tensor in maybe_tensor.items():
          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "Checkpoint key-generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          host_device = saveable_object_util.set_cpu0(tensor.device)
          (tensors_by_device
           .setdefault(host_device, {})
           .setdefault(checkpoint_key, {})[slice_spec]) = tensor
    self._single_device_savers = {
        device: _SingleDeviceSaver(tensor_slice_dict)
        for device, tensor_slice_dict in tensors_by_device.items()}

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph."""
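    # Illustrative graph-mode use of the returned proto (a sketch, not part of
    # this module's API): the recorded names can be fed and fetched directly,
    # e.g.
    #   saver_def = saver.to_proto()
    #   session.run(saver_def.save_tensor_name,
    #               feed_dict={saver_def.filename_tensor_name: "/tmp/ckpt"})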
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix):
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(self, file_prefix):
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
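    # Illustrative sketch with hypothetical values: for file_prefix "myckpt" on
    # a non-S3 filesystem with two single-device savers, the shard prefixes fed
    # to SaveV2 are "myckpt_temp/part-00000-of-00002" and
    # "myckpt_temp/part-00001-of-00002"; MergeV2Checkpoints then merges them
    # into "myckpt.index" plus the per-shard "myckpt.data-*" files and deletes
    # the temporary directory.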
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function retrace on each save, limit this to the
    # cases where it is needed: eager execution and when there are multiple
    # tasks/single-device savers. Note that the retrace is needed to ensure we
    # pick up the latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when saving on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          # Load values from checkpoint
          restored_tensor_dict = saver.restore(file_prefix, options)

          # Map restored tensors to the corresponding restore_fn, and check
          # whether all of its inputs have been loaded. Call `restore_fn` if
          # that is the case.
          for checkpoint_key, slice_and_tensor in restored_tensor_dict.items():
            for slice_spec, tensor in slice_and_tensor.items():
              restore_fn = self._keys_to_restore_fn[(checkpoint_key,
                                                     slice_spec)]
              (restore_fn_inputs
               .setdefault(restore_fn, {})
               .setdefault(checkpoint_key, {})[slice_spec]) = tensor
              restore_fn_input_count[restore_fn] -= 1

              if restore_fn_input_count[restore_fn] == 0:
                ret = restore_fn(restore_fn_inputs[restore_fn])
                if isinstance(ret, dict):
                  restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, restore_fn) in self._registered_savers.items():
        restore_fn(file_prefix)
      return restore_ops

    restore_device = options.experimental_io_device or "cpu:0"

    # Since this will cause a function retrace on each restore, limit this to
    # cases where it is needed: eager execution and when there are multiple
    # tasks/single-device savers. Note that the retrace is needed to ensure we
    # pick up the latest values of options like experimental_io_device.
    if context.executing_eagerly() and (len(self._single_device_savers) > 1 or
                                        options.experimental_io_device):
      @def_function.function(jit_compile=False)
      def tf_function_restore():
        restore_fn()
        return {}

      with ops.device(restore_device):
        restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops