# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extracts tensors for checkpointing while updating a TrackableObjectGraph.

This is labelled "v1" because the methods here use SaveableObject, which will
soon be deprecated.
"""

import collections

from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.checkpoint import util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import object_identity

# Factory and related info used to build a SaveableObject that saves a Trackable
# to checkpoint.
_CheckpointFactoryData = collections.namedtuple(
    "_CheckpointFactoryData", ["factory", "name", "checkpoint_key"])


def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Gets a map of saveable factories and corresponding checkpoint keys.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from `Trackable._map_resources()` which
      copies the object into another graph. Generally only resource objects
      (e.g. Variables, Tables) will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
  unmapped_registered_savers = collections.defaultdict(dict)
  for trackable, object_name in object_names.items():
    # object_to_save is only used to retrieve the saving functionality. For keys
    # and other data, use the original `trackable`.
    object_to_save = util.get_mapped_trackable(trackable, object_map)

    saver_name = registration.get_registered_saver_name(object_to_save)
    if saver_name:
      # Add the original trackable instead of `object_to_save` to the returned
      # dict because the original is needed for writing the object proto.
      unmapped_registered_savers[saver_name][object_name] = trackable
    else:
      checkpoint_factory_map[trackable] = []
      for name, saveable_factory in (
          saveable_object_util.saveable_objects_from_trackable(
              object_to_save).items()):
        # Retrieve the legacy saveable name (for compatibility purposes during
        # SaveableObject deprecation).
        key_suffix = saveable_compat.get_saveable_name(object_to_save) or name
        checkpoint_key = trackable_utils.checkpoint_key(object_name, key_suffix)

        if not saveable_compat.force_checkpoint_conversion_enabled():
          # Make sure to set the name as the legacy saveable name if there
          # is one (only when checkpoint conversion is disabled).
          name = key_suffix

        checkpoint_factory_map[trackable].append(
            _CheckpointFactoryData(
                factory=saveable_factory,
                name=name,
                checkpoint_key=checkpoint_key))
  return checkpoint_factory_map, unmapped_registered_savers


def _add_attributes_to_object_graph(trackable_objects, object_graph_proto,
                                    node_ids, object_names, object_map,
                                    call_with_mapped_captures, saveables_cache):
  """Create saveables/savers and corresponding protos in the object graph.

  Args:
    trackable_objects: sequence of trackables; the position of each trackable
      is asserted to equal its entry in `node_ids`.
    object_graph_proto: the object graph proto whose `nodes` are filled in.
    node_ids: mapping from trackable to its index in `object_graph_proto.nodes`.
    object_names: mapping from trackable to its string name.
    object_map: optional mapping from trackable to the object that should
      actually be saved; passed through to `util.get_mapped_trackable`.
    call_with_mapped_captures: passed through to
      `saveable_object_util.create_saveable_object`.
    saveables_cache: optional mapping used to cache created SaveableObjects
      per saved object, or None to disable caching.

  Returns:
    A tuple of (named_saveable_objects, feed_additions, registered_savers).
  """
  # Sanity check only: the TrackableObject protos were already created in the
  # TrackableObjectGraph and are filled in by the
  # `_add_attributes_to_object_graph_for_*` functions below.
  for checkpoint_id, (trackable, unused_object_proto) in enumerate(
      zip(trackable_objects, object_graph_proto.nodes)):
    assert node_ids[trackable] == checkpoint_id

  checkpoint_factory_map, unmapped_registered_savers = (
      get_checkpoint_factories_and_keys(object_names, object_map))

  # Add attributes, which describe what values are saved in checkpoint for
  # this trackable.
  registered_savers = _add_attributes_to_object_graph_for_registered_savers(
      unmapped_registered_savers, object_graph_proto, node_ids, object_map)
  named_saveable_objects, feed_additions = (
      _add_attributes_to_object_graph_for_saveable_objects(
          checkpoint_factory_map, object_graph_proto, node_ids, object_map,
          call_with_mapped_captures, saveables_cache))
  return named_saveable_objects, feed_additions, registered_savers


def _add_attributes_to_object_graph_for_registered_savers(
    unmapped_registered_savers, object_graph_proto, node_ids, object_map):
  """Fills the object graph proto with data about the registered savers.

  Args:
    unmapped_registered_savers: dictionary mapping registered saver name ->
      {object name -> original trackable}.
    object_graph_proto: the object graph proto; the `registered_saver` field of
      each affected node is set in place.
    node_ids: mapping from trackable to its index in `object_graph_proto.nodes`.
    object_map: optional mapping from trackable to the object to save.

  Returns:
    Dictionary mapping registered saver name -> {object name -> object to
    save}, where the object to save is `util.get_mapped_trackable(...)` of the
    original trackable.
  """
  registered_savers = collections.defaultdict(dict)
  for saver_name, trackables in unmapped_registered_savers.items():
    for object_name, trackable in trackables.items():
      # The proto is written for the original trackable ...
      object_proto = object_graph_proto.nodes[node_ids[trackable]]
      object_proto.registered_saver.name = saver_name
      object_proto.registered_saver.object_name = object_name

      # ... while the returned dict holds the (possibly mapped) object.
      object_to_save = util.get_mapped_trackable(trackable, object_map)
      registered_savers[saver_name][object_name] = object_to_save
  return registered_savers


def _add_attributes_to_object_graph_for_saveable_objects(
    checkpoint_factory_map, object_graph_proto, node_ids, object_map,
    call_with_mapped_captures, saveables_cache):
  """Create SaveableObjects and corresponding SerializedTensor protos.

  Args:
    checkpoint_factory_map: dictionary mapping trackable -> list of
      `_CheckpointFactoryData`.
    object_graph_proto: the object graph proto; attributes are added in place
      to the node of each trackable.
    node_ids: mapping from trackable to its index in `object_graph_proto.nodes`.
    object_map: optional mapping from trackable to the object to save.
    call_with_mapped_captures: passed through to
      `saveable_object_util.create_saveable_object`.
    saveables_cache: optional mapping from saved object to a dict of
      {attribute name -> tuple of SaveableObjects}; updated in place. None
      disables caching.

  Returns:
    A tuple of (list of SaveableObjects, feed_additions), where
    `feed_additions` is None when `saveables_cache` is None and a dict
    otherwise.

  Raises:
    AssertionError: if a newly created SaveableObject has a name that does not
      contain the expected checkpoint key.
  """
  named_saveable_objects = []
  if saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    object_proto = object_graph_proto.nodes[node_ids[trackable]]
    object_to_save = util.get_mapped_trackable(trackable, object_map)
    if saveables_cache is not None:
      cached_attributes = saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      name = factory_data.name
      key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      saveables = cached_attributes.get(name) if cached_attributes else None
      if saveables is not None and any(
          key not in cached.name for cached in saveables):
        # The checkpoint key for a cached SaveableObject is different. We
        # need to re-create it.
        saveables = None
        del cached_attributes[name]

      if saveables is None:
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              name, key, saveable_factory, call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable,)
        else:
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=maybe_saveable, name=key))
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      if isinstance(object_to_save, python_state.PythonState):
        assert len(saveables) == 1
        saveable = saveables[0]

        if feed_additions is None:
          assert saveables_cache is None
          # If we're not caching saveables, then we're either executing
          # eagerly or building a static save/restore (e.g. for a
          # SavedModel). In either case, we should embed the current Python
          # state in the graph rather than relying on a feed dict.
          saveables = (saveable.freeze(),)
        else:
          feed_additions.update(saveable.feed_dict_additions())
      named_saveable_objects.extend(saveables)

      # Update the object proto.
      # For updated Trackables that override serialize_to_tensors, add an
      # attribute for each tensor that is serialized.
      # For Trackables that have SaveableObjects or a legacy saveable name,
      # add a single attribute to the proto.
      if (isinstance(saveables[0], saveable_object_util.TrackableSaveable) and
          (saveable_compat.force_checkpoint_conversion_enabled() or
           saveable_compat.get_saveable_name(object_to_save) is None)):
        for local_name, local_key in (
            saveables[0].get_proto_names_and_checkpoint_keys()):
          object_proto.attributes.add(
              name=local_name,
              checkpoint_key=local_key,
              full_name=util.get_full_name(object_to_save))
      else:
        object_proto.attributes.add(
            name=name,
            checkpoint_key=key,
            full_name=util.get_full_name(object_to_save))

  return named_saveable_objects, feed_additions


def _fill_object_graph_proto(graph_view,
                             trackable_objects,
                             node_ids,
                             slot_variables):
  """Name non-slot `Trackable`s and add them to `object_graph_proto`.

  Args:
    graph_view: object providing `list_children(trackable)`; each returned
      child exposes `.ref` and `.name`.
    trackable_objects: sequence of trackables; the position of each trackable
      is asserted to equal its entry in `node_ids`.
    node_ids: mapping from trackable to its node index.
    slot_variables: mapping from trackable to the slot variable protos to
      attach to its node; trackables without an entry get none.

  Returns:
    A new `TrackableObjectGraph` proto with one node per trackable, in the
    order of `trackable_objects`.
  """
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for checkpoint_id, trackable in enumerate(trackable_objects):
    assert node_ids[trackable] == checkpoint_id
    object_proto = object_graph_proto.nodes.add()
    object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
    for child in graph_view.list_children(trackable):
      child_proto = object_proto.children.add()
      child_proto.node_id = node_ids[child.ref]
      child_proto.local_name = child.name
  return object_graph_proto


def serialize_gathered_objects(graph_view,
                               object_map=None,
                               call_with_mapped_captures=None,
                               saveables_cache=None):
  """Create SaveableObjects and protos for gathered objects.

  Args:
    graph_view: object providing `breadth_first_traversal()` and
      `list_children()`.
    object_map: optional mapping from trackable to the object to save.
    call_with_mapped_captures: passed through to SaveableObject creation.
    saveables_cache: optional cache of SaveableObjects, or None.

  Returns:
    A tuple of (named_saveable_objects, object_graph_proto, feed_additions,
    registered_savers).
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()
  object_names = object_identity.ObjectIdentityDictionary()
  for obj, path in node_paths.items():
    object_names[obj] = trackable_utils.object_path_to_string(path)
  # Node ids are the positions in the breadth-first traversal order.
  node_ids = object_identity.ObjectIdentityDictionary()
  for node_id, node in enumerate(trackable_objects):
    node_ids[node] = node_id
  slot_variables = util.serialize_slot_variables(
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      object_names=object_names)
  object_graph_proto = _fill_object_graph_proto(
      graph_view=graph_view,
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      slot_variables=slot_variables)
  named_saveable_objects, feed_additions, registered_savers = (
      _add_attributes_to_object_graph(
          trackable_objects=trackable_objects,
          object_graph_proto=object_graph_proto,
          node_ids=node_ids,
          object_names=object_names,
          object_map=object_map,
          call_with_mapped_captures=call_with_mapped_captures,
          saveables_cache=saveables_cache))
  # Gather all trackables that have checkpoint values or descendants with
  # checkpoint values, and add that info to the proto.
  util.add_checkpoint_values_check(object_graph_proto)
  return (named_saveable_objects, object_graph_proto, feed_additions,
          registered_savers)


def serialize_object_graph_with_registered_savers(graph_view, saveables_cache):
  """Determine checkpoint keys for variables and build a serialized graph.

  Thin wrapper around `serialize_gathered_objects` that passes only
  `graph_view` and `saveables_cache`, leaving the other arguments at their
  defaults.

  Args:
    graph_view: forwarded to `serialize_gathered_objects`.
    saveables_cache: forwarded to `serialize_gathered_objects`.

  Returns:
    The tuple returned by `serialize_gathered_objects`.
  """
  return serialize_gathered_objects(graph_view, saveables_cache=saveables_cache)


def frozen_saveables_and_savers(graph_view,
                                object_map=None,
                                to_graph=None,
                                call_with_mapped_captures=None,
                                saveables_cache=None):
  """Generates SaveableObjects and registered savers in the frozen graph.

  Args:
    graph_view: forwarded to `serialize_gathered_objects`.
    object_map: forwarded to `serialize_gathered_objects`.
    to_graph: optional graph; when truthy, everything is built inside
      `to_graph.as_default()`, otherwise inside a null context manager.
    call_with_mapped_captures: forwarded to `serialize_gathered_objects`.
    saveables_cache: forwarded to `serialize_gathered_objects`.

  Returns:
    A tuple of (named_saveable_objects, registered_savers). The list of
    saveables additionally ends with a `NoRestoreSaveable` holding the
    serialized object graph proto under `base.OBJECT_GRAPH_PROTO_KEY`.
  """
  if to_graph:
    target_context = to_graph.as_default
  else:
    target_context = ops.NullContextmanager
  with target_context():
    named_saveable_objects, graph_proto, _, registered_savers = (
        serialize_gathered_objects(graph_view, object_map,
                                   call_with_mapped_captures, saveables_cache))
    # The serialized object graph is stored as a string constant on the CPU.
    with ops.device("/cpu:0"):
      object_graph_tensor = constant_op.constant(
          graph_proto.SerializeToString(), dtype=dtypes.string)
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
  return named_saveable_objects, registered_savers