xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/save_util_v1.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extracts tensors for checkpointing while updating a TrackableObjectGraph.

This is labelled "v1" because the methods here use SaveableObject, which will
soon be deprecated.
"""
20
21import collections
22
23from tensorflow.core.protobuf import trackable_object_graph_pb2
24from tensorflow.python.checkpoint import saveable_compat
25from tensorflow.python.checkpoint import util
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.saved_model import registration
30from tensorflow.python.trackable import base
31from tensorflow.python.trackable import python_state
32from tensorflow.python.trackable import trackable_utils
33from tensorflow.python.training.saving import saveable_object as saveable_object_lib
34from tensorflow.python.training.saving import saveable_object_util
35from tensorflow.python.util import object_identity
36
# Record describing how to build the SaveableObject that saves a Trackable to a
# checkpoint: the saveable factory, the attribute name, and the checkpoint key.
_CheckpointFactoryData = collections.namedtuple(
    typename="_CheckpointFactoryData",
    field_names=("factory", "name", "checkpoint_key"))
41
42
def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Collects saveable factories and registered savers for named trackables.

  Each trackable in `object_names` is handled in one of two ways. If a
  registered saver name is found for the object that will be saved, the
  trackable is recorded under that saver name. Otherwise one
  `_CheckpointFactoryData` entry is created per saveable factory of that
  object.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from `Trackable._map_resources()` which
      copies the object into another graph. Generally only resource objects
      (e.g. Variables, Tables) will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  factories_by_trackable = object_identity.ObjectIdentityDictionary()
  trackables_by_saver = collections.defaultdict(dict)

  for trackable, object_name in object_names.items():
    # The mapped object only supplies the saving functionality. Keys and other
    # data are always derived from the original `trackable`.
    object_to_save = util.get_mapped_trackable(trackable, object_map)

    saver_name = registration.get_registered_saver_name(object_to_save)
    if saver_name:
      # Store the original trackable (not `object_to_save`), because the
      # original is needed for writing the object proto.
      trackables_by_saver[saver_name][object_name] = trackable
      continue

    factories_by_trackable[trackable] = []
    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        object_to_save)
    for name, saveable_factory in saveable_factories.items():
      # A legacy saveable name, when present, takes precedence in the
      # checkpoint key (for compatibility during SaveableObject deprecation).
      key_suffix = saveable_compat.get_saveable_name(object_to_save) or name
      checkpoint_key = trackable_utils.checkpoint_key(object_name, key_suffix)

      # The legacy saveable name is also used as the attribute name, but only
      # while checkpoint conversion is disabled.
      factory_name = (
          name if saveable_compat.force_checkpoint_conversion_enabled()
          else key_suffix)

      factories_by_trackable[trackable].append(
          _CheckpointFactoryData(
              factory=saveable_factory,
              name=factory_name,
              checkpoint_key=checkpoint_key))

  return factories_by_trackable, trackables_by_saver
93
94
def _add_attributes_to_object_graph(trackable_objects, object_graph_proto,
                                    node_ids, object_names, object_map,
                                    call_with_mapped_captures, saveables_cache):
  """Adds registered-saver and saveable attributes to the object graph nodes.

  Returns:
    A tuple of (named_saveable_objects, feed_additions, registered_savers), as
    produced by the two `_add_attributes_to_object_graph_for_*` helpers.
  """
  # Consistency check: the id recorded for each trackable must equal its
  # position. The node protos themselves are filled in by the
  # `_add_attributes_to_object_graph_for_*` helpers called below.
  paired_nodes = zip(trackable_objects, object_graph_proto.nodes)
  for position, (trackable, _) in enumerate(paired_nodes):
    assert node_ids[trackable] == position

  factory_map, unmapped_savers = get_checkpoint_factories_and_keys(
      object_names, object_map)

  # Attributes describe which values are saved in the checkpoint for each
  # trackable.
  registered_savers = _add_attributes_to_object_graph_for_registered_savers(
      unmapped_savers, object_graph_proto, node_ids, object_map)
  saveable_results = _add_attributes_to_object_graph_for_saveable_objects(
      factory_map, object_graph_proto, node_ids, object_map,
      call_with_mapped_captures, saveables_cache)
  named_saveable_objects, feed_additions = saveable_results
  return named_saveable_objects, feed_additions, registered_savers
117
118
def _add_attributes_to_object_graph_for_registered_savers(
    unmapped_registered_savers, object_graph_proto, node_ids, object_map):
  """Records registered saver info on the object graph nodes.

  Args:
    unmapped_registered_savers: dictionary mapping registered saver name ->
      {object name -> trackable}.
    object_graph_proto: proto whose `nodes[node_ids[trackable]]` entry is
      updated with the saver name and object name.
    node_ids: mapping from trackable to its index in `object_graph_proto.nodes`.
    object_map: passed to `util.get_mapped_trackable` to look up the object
      that is saved in place of each trackable.

  Returns:
    A `defaultdict(dict)` mapping registered saver name ->
    {object name -> object returned by `util.get_mapped_trackable`}.
  """
  registered_savers = collections.defaultdict(dict)
  for saver_name, named_trackables in unmapped_registered_savers.items():
    for object_name, trackable in named_trackables.items():
      saver_proto = object_graph_proto.nodes[node_ids[trackable]].registered_saver
      saver_proto.name = saver_name
      saver_proto.object_name = object_name

      # The entry for `saver_name` is created lazily here, so a saver without
      # any trackables never appears in the result.
      registered_savers[saver_name][object_name] = util.get_mapped_trackable(
          trackable, object_map)
  return registered_savers
132
133
def _add_attributes_to_object_graph_for_saveable_objects(
    checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures, saveables_cache):
  """Builds SaveableObjects and records their attributes in the object graph.

  Args:
    checkpoint_factory_map: mapping from trackable to a list of factory data
      entries, each exposing `factory`, `name` and `checkpoint_key`.
    object_graph_proto: proto whose `nodes[node_ids[trackable]]` entry receives
      the attributes of `trackable`.
    node_ids: mapping from trackable to its index in `object_graph_proto.nodes`.
    object_map: passed to `util.get_mapped_trackable` to look up the object
      that is saved in place of each trackable.
    call_with_mapped_captures: forwarded to
      `saveable_object_util.create_saveable_object` for callable factories.
    saveables_cache: optional mapping from saved object to a dictionary of
      {attribute name -> tuple of SaveableObjects}. `None` disables caching.

  Returns:
    A tuple of (list of SaveableObjects, feed additions). Feed additions is
    `None` when `saveables_cache` is `None`; otherwise it is a dictionary
    updated with `feed_dict_additions()` of each `PythonState` saveable.

  Raises:
    AssertionError: if a newly created SaveableObject has a name that does not
      contain the expected checkpoint key.
  """
  caching_enabled = saveables_cache is not None
  named_saveable_objects = []
  # Without caching we are either executing eagerly or building a static save
  # specialized to the current Python state, so no feed dict is needed. With
  # caching, the feed dict carries functions computing volatile Python state to
  # be saved with the checkpoint.
  feed_additions = {} if caching_enabled else None

  for trackable, factory_entries in checkpoint_factory_map.items():
    node_proto = object_graph_proto.nodes[node_ids[trackable]]
    object_to_save = util.get_mapped_trackable(trackable, object_map)
    attribute_cache = (
        saveables_cache.setdefault(object_to_save, {})
        if caching_enabled else None)

    for entry in factory_entries:
      name = entry.name
      key = entry.checkpoint_key
      saveable_factory = entry.factory

      # Reuse cached SaveableObjects only if every one of them still matches
      # the checkpoint key; otherwise drop the cache entry and rebuild.
      saveables = attribute_cache.get(name) if attribute_cache else None
      if saveables is not None and any(
          key not in cached.name for cached in saveables):
        saveables = None
        del attribute_cache[name]

      if saveables is None:
        candidate = (
            saveable_object_util.create_saveable_object(
                name, key, saveable_factory, call_with_mapped_captures)
            if callable(saveable_factory) else saveable_factory)
        if isinstance(candidate, saveable_object_lib.SaveableObject):
          saveables = (candidate,)
        else:
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=candidate, name=key))
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if attribute_cache is not None:
          attribute_cache[name] = saveables

      if isinstance(object_to_save, python_state.PythonState):
        assert len(saveables) == 1
        saveable = saveables[0]

        if feed_additions is not None:
          feed_additions.update(saveable.feed_dict_additions())
        else:
          assert saveables_cache is None
          # Not caching saveables means we are either executing eagerly or
          # building a static save/restore (e.g. for a SavedModel). In both
          # cases the current Python state is embedded in the graph instead
          # of being supplied through a feed dict.
          saveables = (saveable.freeze(),)
      named_saveable_objects.extend(saveables)

      # Update the object proto. Trackables that override
      # serialize_to_tensors get one attribute per serialized tensor, while
      # Trackables with SaveableObjects or a legacy saveable name get a single
      # attribute.
      first_saveable = saveables[0]
      if (isinstance(first_saveable, saveable_object_util.TrackableSaveable) and
          (saveable_compat.force_checkpoint_conversion_enabled() or
           saveable_compat.get_saveable_name(object_to_save) is None)):
        proto_entries = first_saveable.get_proto_names_and_checkpoint_keys()
      else:
        proto_entries = ((name, key),)
      for attribute_name, attribute_key in proto_entries:
        node_proto.attributes.add(
            name=attribute_name,
            checkpoint_key=attribute_key,
            full_name=util.get_full_name(object_to_save))

  return named_saveable_objects, feed_additions
229
230
def _fill_object_graph_proto(graph_view,
                             trackable_objects,
                             node_ids,
                             slot_variables):
  """Builds a `TrackableObjectGraph` with one node per trackable.

  Args:
    graph_view: object whose `list_children(trackable)` yields children
      exposing `ref` and `name`.
    trackable_objects: sequence of trackables; one node is added for each, in
      order.
    node_ids: mapping from trackable to node id. Each id is asserted to equal
      the trackable's position in `trackable_objects`.
    slot_variables: mapping from trackable to the slot variable entries that
      are added to its node.

  Returns:
    The populated `TrackableObjectGraph` proto.
  """
  graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for position, trackable in enumerate(trackable_objects):
    assert node_ids[trackable] == position
    node_proto = graph_proto.nodes.add()
    node_proto.slot_variables.extend(slot_variables.get(trackable, ()))
    # One child entry per dependency, pointing at the child's node id and
    # carrying its local name.
    for dependency in graph_view.list_children(trackable):
      edge_proto = node_proto.children.add()
      edge_proto.node_id = node_ids[dependency.ref]
      edge_proto.local_name = dependency.name
  return graph_proto
246
247
def serialize_gathered_objects(graph_view,
                               object_map=None,
                               call_with_mapped_captures=None,
                               saveables_cache=None):
  """Creates SaveableObjects and the object graph proto for a graph view.

  Args:
    graph_view: object providing `breadth_first_traversal()` and
      `list_children()`.
    object_map: optional mapping forwarded to the attribute helpers.
    call_with_mapped_captures: optional callable forwarded to the attribute
      helpers.
    saveables_cache: optional cache of SaveableObjects forwarded to the
      attribute helpers.

  Returns:
    A tuple of (named_saveable_objects, object_graph_proto, feed_additions,
    registered_savers).
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()

  # Name every object after its path.
  object_names = object_identity.ObjectIdentityDictionary()
  for trackable, path in node_paths.items():
    object_names[trackable] = trackable_utils.object_path_to_string(path)

  # Node ids are the positions in traversal order.
  node_ids = object_identity.ObjectIdentityDictionary()
  for position, trackable in enumerate(trackable_objects):
    node_ids[trackable] = position

  slot_variables = util.serialize_slot_variables(
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      object_names=object_names)
  object_graph_proto = _fill_object_graph_proto(
      graph_view, trackable_objects, node_ids, slot_variables)
  attribute_results = _add_attributes_to_object_graph(
      trackable_objects, object_graph_proto, node_ids, object_names,
      object_map, call_with_mapped_captures, saveables_cache)
  named_saveable_objects, feed_additions, registered_savers = attribute_results

  # Gather all trackables that have checkpoint values or descendants with
  # checkpoint values, and add that info to the proto.
  util.add_checkpoint_values_check(object_graph_proto)
  return (named_saveable_objects, object_graph_proto, feed_additions,
          registered_savers)
283
284
def serialize_object_graph_with_registered_savers(graph_view, saveables_cache):
  """Serializes `graph_view` without an object map or mapped captures.

  Returns:
    The result of `serialize_gathered_objects`: a tuple of
    (named_saveable_objects, object_graph_proto, feed_additions,
    registered_savers).
  """
  return serialize_gathered_objects(
      graph_view=graph_view,
      object_map=None,
      call_with_mapped_captures=None,
      saveables_cache=saveables_cache)
288
289
def frozen_saveables_and_savers(graph_view,
                                object_map=None,
                                to_graph=None,
                                call_with_mapped_captures=None,
                                saveables_cache=None):
  """Generates SaveableObjects and registered savers in the frozen graph.

  The objects are serialized inside `to_graph.as_default()` when `to_graph` is
  provided. The serialized object graph proto is added to the returned
  saveables as a string constant created under the "/cpu:0" device scope. The
  feed additions from serialization are discarded.

  Returns:
    A tuple of (named_saveable_objects, registered_savers).
  """
  target_context = to_graph.as_default if to_graph else ops.NullContextmanager
  with target_context():
    serialized = serialize_gathered_objects(
        graph_view, object_map, call_with_mapped_captures, saveables_cache)
    named_saveable_objects, graph_proto, _, registered_savers = serialized
    with ops.device("/cpu:0"):
      object_graph_tensor = constant_op.constant(
          graph_proto.SerializeToString(), dtype=dtypes.string)
    graph_saveable = base.NoRestoreSaveable(
        tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)
    named_saveable_objects.append(graph_saveable)
  return named_saveable_objects, registered_savers
311