1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Import a trackable object from a SavedModel.""" 16 17import collections 18import functools 19import os 20import sys 21 22from tensorflow.core.protobuf import graph_debug_info_pb2 23from tensorflow.python.checkpoint import checkpoint 24from tensorflow.python.checkpoint import checkpoint_options 25from tensorflow.python.checkpoint import graph_view 26from tensorflow.python.checkpoint import restore 27from tensorflow.python.distribute import distribute_utils 28from tensorflow.python.distribute import distribution_strategy_context as ds_context 29from tensorflow.python.distribute import values_util 30from tensorflow.python.eager import context 31from tensorflow.python.eager import function 32from tensorflow.python.eager import function_saved_model_utils 33from tensorflow.python.framework import config 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import control_flow_ops 40from tensorflow.python.ops import lookup_ops 41from tensorflow.python.ops import resource_variable_ops 42from tensorflow.python.ops import variables 43from tensorflow.python.saved_model import function_deserialization 44from tensorflow.python.saved_model import load_options 45from tensorflow.python.saved_model import load_v1_in_v2 46from tensorflow.python.saved_model import loader_impl 47from tensorflow.python.saved_model import registration 48from tensorflow.python.saved_model import revived_types 49from tensorflow.python.saved_model import utils_impl as saved_model_utils 50from tensorflow.python.saved_model.pywrap_saved_model import metrics 51from tensorflow.python.trackable import asset 52from tensorflow.python.trackable import autotrackable 53from tensorflow.python.trackable import base 54from tensorflow.python.trackable import data_structures 55from tensorflow.python.trackable import resource 56from tensorflow.python.trackable import trackable_utils 57from tensorflow.python.training.saving import saveable_object_util 58from tensorflow.python.util import nest 59from tensorflow.python.util.tf_export import tf_export 60 61# API label for SavedModel metrics. 62_LOAD_V2_LABEL = "load_v2" 63# Built-in registrations use the "oneof kind" field in the SavedObject proto, 64# instead of "registered_name" field. The "kind" field has almost the same 65# functionality as the registered_name, but only contains built-in TensorFlow 66# types (like variable, functions, assets). 67_BUILT_IN_REGISTRATIONS = { 68 "asset": asset.Asset, 69 "resource": resource.RestoredResource, 70 "constant": function_saved_model_utils.TrackableConstant} 71 72 73def _unused_handle(): 74 """Returns a placeholder as a handle that is not supposed to be accessed.""" 75 error_message = ("Trying to access a placeholder that is not supposed to be " 76 "executed. This means you are executing a graph generated " 77 "from the cross-replica context in an in-replica context.") 78 save_error_message = ( 79 "It seems that you are trying to save a " 80 "tf.types.experimental.ConcreteFunction that involves a distributed " 81 "model, and the model contains parts that are loaded form a SavedModel. " 82 "It's not supported to save such tf.types.experimental.ConcreteFunction. " 83 "Try saving a tf.function with input_signature instead, and file a bug if" 84 " there are still issues.") 85 86 assert_op = control_flow_ops.Assert( 87 array_ops.placeholder_with_default(False, shape=()), [error_message]) 88 if (not context.executing_eagerly() 89 ) and ops.get_default_graph().building_function: 90 ops.get_default_graph().mark_as_unsaveable(save_error_message) 91 92 with ops.control_dependencies([assert_op]): 93 return array_ops.placeholder(dtype=dtypes.resource) 94 95 96class _WrapperFunction(function.ConcreteFunction): 97 """A class wraps a concrete function to handle different distributed contexts. 98 99 The reason for wrapping a concrete function is because the _captured_inputs 100 fields used for in-replica context and cross-replica context are different. 101 When `load()` is called from within a tf.distribute.strategy scope, the 102 captured inputs are distributed variables. When using these distributed 103 variables during calling the function, we need different approaches when it is 104 in-replica and when it is not in-replica. When it is in replica, naturally we 105 should use the corresponding component of the distributed variable; when it is 106 not in-replica, calling the function should mean that it is constructing a 107 graph that is not actually going to be used. A typical use case is when 108 constructing a functional model. In this case, return a placeholder with a 109 control dependency to ensure that is never accessed. 110 """ 111 112 def __init__(self, concrete_function): 113 # Shallow copy the concrete_function 114 self.__dict__.update(vars(concrete_function)) 115 116 def _call_flat(self, args, captured_inputs, cancellation_manager=None): 117 118 def get_handle(x): 119 return x.handle if distribute_utils.is_distributed_variable(x) else x 120 121 def get_unused_handle(x): 122 return _unused_handle() if distribute_utils.is_distributed_variable(x) \ 123 else x 124 125 if (ds_context.get_replica_context() is not None or 126 values_util.is_saving_non_distributed()): 127 # If we're in the replica context or are saving a non-distributed version 128 # of the model, we resolve the captured variables to the corresponding 129 # resource handle. In both situation we call var.handle, but it has 130 # different behavior. In the replica context, var.handle resolves the 131 # replica local variable handle if the variable is replicated. When saving 132 # a non-distributed version of the model, var.handle resolves to the 133 # primary variable handle, since we only save one copy of a replicated 134 # variable. 135 captured_inputs = list(map(get_handle, captured_inputs)) 136 else: # cross-replica context 137 captured_inputs = list(map(get_unused_handle, captured_inputs)) 138 return super(_WrapperFunction, self)._call_flat(args, captured_inputs, 139 cancellation_manager) 140 141 142class Loader(object): 143 """Helper class to load an object-based SavedModel.""" 144 145 def __init__(self, object_graph_proto, saved_model_proto, export_dir, 146 ckpt_options, save_options, filters): 147 meta_graph = saved_model_proto.meta_graphs[0] 148 self._asset_file_def = meta_graph.asset_file_def 149 self._operation_attributes = { 150 node.name: node.attr for node in meta_graph.graph_def.node} 151 self._proto = object_graph_proto 152 self._export_dir = export_dir 153 self._concrete_functions = ( 154 function_deserialization.load_function_def_library( 155 library=meta_graph.graph_def.library, 156 saved_object_graph=self._proto, 157 wrapper_function=_WrapperFunction)) 158 # Store a set of all concrete functions that have been set up with 159 # captures. 160 self._restored_concrete_functions = set() 161 self._checkpoint_options = ckpt_options 162 self._save_options = save_options 163 164 self._pretty_printer = checkpoint.ObjectGraphProtoPrettyPrinter(self._proto) 165 166 # Stores user-defined node_filters argument. 167 self._node_filters = filters 168 # Stores map of string paths to integers. 169 self._node_path_to_id = self._convert_node_paths_to_ints() 170 self._loaded_nodes = {} 171 if isinstance(filters, dict): 172 # If node_filters is a dict, then the values may contain already created 173 # trackable objects. In this case, create a dictionary mapping node IDs to 174 # the already created nodes. This dict will be updated in 175 # `_retrieve_all_filtered_nodes` with tracked children. 176 for node_path, node in filters.items(): 177 if isinstance(node, tuple): 178 self._loaded_nodes[self._node_path_to_id[node_path]] = node 179 else: 180 self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr) 181 182 # Get a list of all integer node ids to load, or None if all nodes should be 183 # loaded. This list includes ids of child nodes. 184 self._filtered_nodes = self._retrieve_all_filtered_nodes() 185 186 # Order all nodes or filtered nodes using the dependencies. 187 self._ordered_node_ids = self._generate_ordered_node_ids() 188 189 self._load_all() 190 191 if not save_options.experimental_skip_checkpoint: 192 self._restore_checkpoint() 193 for node in self._nodes: 194 if isinstance(node, resource.CapturableResource): 195 init_op = node._initialize() # pylint: disable=protected-access 196 if not context.executing_eagerly(): 197 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 198 199 def _convert_node_paths_to_ints(self): 200 """Maps all string node paths in node_filters to the int node ids.""" 201 if self._node_filters is None: 202 return None 203 path_to_int = {} 204 for node_id in self._node_filters: 205 int_node_id = None 206 if isinstance(node_id, str): 207 node_path = node_id.split(".") 208 if node_path[0] != "root": 209 raise ValueError( 210 "When passing string identifiers to node_filters, the first name" 211 f" must be root. Received {node_path[0]}.") 212 int_node_id = 0 213 for n, name in enumerate(node_path[1:]): 214 int_node_id = self._find_node_child( 215 int_node_id, name, ".".join(node_path[:n+2])) 216 path_to_int[node_id] = int_node_id 217 else: 218 raise TypeError("Elements in node_filters must be strings.") 219 return path_to_int 220 221 def _retrieve_all_filtered_nodes(self): 222 """Traverses through the object graph to get the IDs of all nodes to load. 223 224 As a side-effect, if node_filters is a dictionary that contains already- 225 created objects, then the children tracked by those objects will be 226 added to node_filters. 227 228 Returns: 229 List of all nodes to load, or None if all nodes should be loaded. 230 231 """ 232 if self._node_filters is None: 233 return None # All nodes should be loaded. 234 235 all_filtered_nodes = set() 236 nodes_to_visit = list(self._node_filters) 237 238 while nodes_to_visit: 239 node_path = nodes_to_visit.pop(0) 240 node_id = self._node_path_to_id[node_path] 241 if node_id in all_filtered_nodes: 242 continue 243 all_filtered_nodes.add(node_id) 244 245 node, setter = self._loaded_nodes.get(node_id, (None, None)) 246 if node is not None: 247 if not isinstance(node, base.Trackable): 248 raise TypeError( 249 "Error when processing dictionary values passed to nodes_to_load." 250 f"Object at {node_path} is expected to be a checkpointable (i.e. " 251 "'trackable') TensorFlow object (e.g. tf.Variable, tf.Module or " 252 "Keras layer).") 253 node._maybe_initialize_trackable() # pylint: disable=protected-access 254 255 for reference in self._proto.nodes[node_id].children: 256 child_object, _ = self._loaded_nodes.get( 257 reference.node_id, (None, None)) 258 259 # See if node already tracks the child reference, in which case add the 260 # child to the loaded_nodes dict. 261 if child_object is None and node is not None: 262 child_object = node._lookup_dependency(reference.local_name) # pylint: disable=protected-access 263 if isinstance(child_object, data_structures.TrackableDataStructure): 264 # Make setattr a noop to avoid overwriting already existing data 265 # structures. 266 setter = lambda *args: None 267 268 self._loaded_nodes[reference.node_id] = (child_object, setter) 269 270 child_path = "{}.{}".format(node_path, reference.local_name) 271 self._node_path_to_id[child_path] = reference.node_id 272 nodes_to_visit.append(child_path) 273 274 if 0 in all_filtered_nodes: 275 return None 276 return all_filtered_nodes 277 278 def _find_node_child(self, node_id, child_name, path): 279 for reference in self._proto.nodes[node_id].children: 280 if reference.local_name == child_name: 281 return reference.node_id 282 raise ValueError(f"Unable to find node {path}.") 283 284 def _load_all(self): 285 """Loads all nodes and functions from the SavedModel and their edges.""" 286 self._load_nodes() 287 self._load_edges() 288 289 # Set up concrete functions that aren't part of the object graph 290 # (e.g. gradient functions) 291 self._setup_remaining_functions() 292 self._load_checkpoint_save_and_restore_functions() 293 294 def _load_checkpoint_save_and_restore_functions(self): 295 """Restores the checkpoint-related save/restore functions to all nodes.""" 296 for node_id, proto in self._iter_all_nodes(): 297 node = self.get(node_id) 298 if proto.saveable_objects.keys() == { 299 trackable_utils.SERIALIZE_TO_TENSORS_NAME}: 300 # Restore Trackable serialize- and restore-from-tensor functions. 301 assert len(proto.saveable_objects) == 1 302 saveable_object_proto = next(iter(proto.saveable_objects.values())) 303 save_fn_id = saveable_object_proto.save_function 304 restore_fn_id = saveable_object_proto.restore_function 305 node._serialize_to_tensors = self.get(save_fn_id) # pylint: disable=protected-access 306 node._restore_from_tensors = self.get(restore_fn_id) # pylint: disable=protected-access 307 else: 308 # Restore legacy SaveableObject functions. 309 saveable_fn_by_name = {} 310 for name, saveable_object_proto in proto.saveable_objects.items(): 311 save_fn_id = saveable_object_proto.save_function 312 restore_fn_id = saveable_object_proto.restore_function 313 saveable_fn_by_name[name] = (self.get(save_fn_id), 314 self.get(restore_fn_id)) 315 316 node._self_saveable_object_factories = ( # pylint: disable=protected-access 317 saveable_object_util.recreate_saveable_objects(saveable_fn_by_name)) 318 319 def _load_edges(self): 320 """Adds edges from objects to other objects and functions.""" 321 for node_id, object_proto in self._iter_all_nodes(): 322 self._add_object_graph_edges(object_proto, node_id) 323 324 # If root object isn't loaded, then create edges from the root for 325 # checkpoint compatibility. 326 if self._filtered_nodes is not None and 0 not in self._filtered_nodes: 327 root = self.get(0) 328 for node_path in self._node_filters: 329 loaded_node = self._nodes[self._node_path_to_id[node_path]] 330 path = node_path.split(".") 331 current_node = root 332 for name in path[1:-1]: 333 if not hasattr(current_node, name): 334 setattr(current_node, name, self._recreate_base_user_object()[0]) 335 current_node = getattr(current_node, name) 336 if not hasattr(current_node, path[-1]): 337 setattr(current_node, path[-1], loaded_node) 338 339 def _add_object_graph_edges(self, proto, node_id): 340 """Adds edges from an object to its children.""" 341 obj = self._nodes[node_id] 342 setter = self._node_setters[node_id] 343 344 for reference in proto.children: 345 setter(obj, reference.local_name, self._nodes[reference.node_id]) 346 # Note: if an object has an attribute `__call__` add a class method 347 # that allows `obj()` syntax to work. This is done per-instance to 348 # allow `callable` to be used to find out if an object is callable. 349 if reference.local_name == "__call__" and not callable(obj): 350 setattr(type(obj), "__call__", _call_attribute) 351 352 def _setup_remaining_functions(self): 353 concrete_function_names = sorted(self._proto.concrete_functions.keys()) 354 for name in concrete_function_names: 355 if name in self._restored_concrete_functions: 356 continue 357 self._setup_function_captures(name, self._nodes) 358 359 def _setup_function_captures(self, concrete_function_name, nodes): 360 """Setup captures and variables in a restored function.""" 361 if concrete_function_name in self._restored_concrete_functions: 362 return 363 self._restored_concrete_functions.add(concrete_function_name) 364 concrete_function = self._concrete_functions[concrete_function_name] 365 proto = self._proto.concrete_functions[concrete_function_name] 366 inputs = [nodes[node_id] for node_id in proto.bound_inputs] 367 function_saved_model_utils.restore_captures(concrete_function, inputs) 368 369 def _initialize_loaded_nodes(self): 370 nodes = {} 371 node_setters = {} 372 for node_id, (node, setter) in self._loaded_nodes.items(): 373 nodes[node_id] = node 374 node_setters[node_id] = setter 375 return nodes, node_setters 376 377 def _get_node_dependencies(self, proto): 378 """Returns a dictionary of all dependencies of an object. 379 380 Args: 381 proto: A SavedObject proto. 382 383 Returns: 384 Dict mapping string dependency name *or* int node id to the node id. 385 The int node id key is used for mapping function captures. 386 """ 387 dependencies = {ref.local_name: ref.node_id for ref in proto.dependencies} 388 kind = proto.WhichOneof("kind") 389 if kind == "function": 390 concrete_functions = proto.function.concrete_functions 391 for fn_name in concrete_functions: 392 for bound_input in self._proto.concrete_functions[fn_name].bound_inputs: 393 dependencies[bound_input] = bound_input 394 elif kind == "bare_concrete_function": 395 fn_name = proto.bare_concrete_function.concrete_function_name 396 for bound_input in self._proto.concrete_functions[fn_name].bound_inputs: 397 dependencies[bound_input] = bound_input 398 elif kind == "resource": 399 # Make sure that the resource creator is listed as a dependency. 400 for child in proto.children: 401 if child.local_name == "_create_resource": 402 dependencies["_create_resource"] = child.node_id 403 return dependencies 404 405 def _generate_ordered_node_ids(self): 406 """Orders the node ids so that dependencies appear first.""" 407 if self._filtered_nodes is None: 408 unordered_ids = range(len(self._proto.nodes)) 409 else: 410 unordered_ids = list(self._filtered_nodes) 411 412 # Maps node ids -> list of dependencies (ids of other nodes that must be 413 # loaded before it). 414 dependency_map = collections.defaultdict(list) 415 for node_id in unordered_ids: 416 deps = dependency_map[node_id] 417 if self._loaded_nodes.get(node_id) is not None: 418 # Deps are only used if the node has not been created. 419 continue 420 proto = self._proto.nodes[node_id] 421 for dep in set(self._get_node_dependencies(proto).values()): 422 deps.append(dep) 423 if self._filtered_nodes is not None and dep not in self._filtered_nodes: 424 raise ValueError( 425 "Unable to partially load SavedModel since the specified filter " 426 "does not include all required objects for loading (e.g. " 427 "variables used in functions or deserialization dependencies). " 428 "Please include this path in the filter: " 429 f"{self._pretty_printer.node_names[dep]}") 430 431 # Add optimizer slot variable to dependency map. 432 prev_slot = None 433 for slot_variable_proto in proto.slot_variables: 434 slot_variable_node_id = slot_variable_proto.slot_variable_node_id 435 # The optimizer and original variable must be created before the slot 436 # variable, since the slot variable is generated using the Optimizer's 437 # add_slot API. 438 slot_deps = dependency_map[slot_variable_node_id] 439 slot_deps.append(node_id) 440 slot_deps.append(slot_variable_proto.original_variable_node_id) 441 442 if prev_slot is not None: 443 # Add previous slot to deps so that the optimizer slot variables are 444 # added in order. The ordering is needed because the slot name and 445 # variable are both added to ordered lists, which are exposed to the 446 # user via `Optimizer.get_slot_names()` and `Optimizer.weights`. 447 # TODO(kathywu): Maybe enforce some sort of deterministic ordering in 448 # `order_by_dependency` to avoid doing this? 449 slot_deps.append(prev_slot) 450 prev_slot = slot_variable_node_id 451 try: 452 return list(trackable_utils.order_by_dependency(dependency_map)) 453 except trackable_utils.CyclicDependencyError: 454 # This should not happen since there is already a validation for cycles 455 # when saving, but raise an error just in case. 456 raise ValueError("Encountered a cycle in the deserialization dependencies" 457 "in the SavedModel. This is extremely unexpected, please" 458 "file a bug and make sure you are not manually modifying" 459 " the SavedModel.") 460 461 def _iter_all_nodes(self): 462 for node_id in self._ordered_node_ids: 463 yield node_id, self._proto.nodes[node_id] 464 465 def _load_nodes(self): 466 """Load all saved objects.""" 467 # `nodes` maps from node ids to recreated objects 468 # `node_setters` maps from node ids to setter functions 469 # (same signature as setattr) for setting children. 470 nodes, node_setters = self._initialize_loaded_nodes() 471 472 # Figure out which objects are slot variables. These objects are created 473 # with Optimizer.add_slot rather than _recreate_variable. 474 # Maps slot node id -> optimizer node id, SlotVariableReference proto 475 slot_variable_node_ids = {} 476 477 for node_id, proto in self._iter_all_nodes(): 478 for slot_variable_proto in proto.slot_variables: 479 slot_variable_node_id = slot_variable_proto.slot_variable_node_id 480 slot_variable_node_ids[slot_variable_node_id] = (node_id, 481 slot_variable_proto) 482 483 # Re-create everything. 484 for node_id, proto in self._iter_all_nodes(): 485 if nodes.get(node_id) is not None: 486 continue 487 elif node_id in slot_variable_node_ids: 488 # Use the public Optimizer interface when creating slot variables. 489 optimizer_node_id, slot_variable_proto = slot_variable_node_ids[node_id] 490 optimizer_object = nodes[optimizer_node_id] 491 optimized_variable = nodes[ 492 slot_variable_proto.original_variable_node_id] 493 slot_variable = optimizer_object.add_slot( 494 var=optimized_variable, 495 slot_name=slot_variable_proto.slot_name) 496 nodes[slot_variable_proto.slot_variable_node_id] = slot_variable 497 node_setters[slot_variable_proto.slot_variable_node_id] = setattr 498 else: 499 node, setter = self._recreate(proto, node_id, nodes) 500 nodes[node_id] = node 501 node_setters[node_id] = setter 502 503 # If root object is not loaded, add a dummy root object for checkpoint 504 # compatibility. 505 if 0 not in nodes: 506 nodes[0] = self._recreate_base_user_object()[0] 507 508 self._nodes = [nodes.get(node_id) 509 for node_id in range(len(self._proto.nodes))] 510 self._node_setters = node_setters 511 512 def _restore_checkpoint(self): 513 """Load state from checkpoint into the deserialized objects.""" 514 variables_path = saved_model_utils.get_variables_path(self._export_dir) 515 # TODO(b/205010730): Clean use of private methods of TrackableSaver. 516 # pylint: disable=protected-access 517 saver = checkpoint.TrackableSaver(graph_view.ObjectGraphView(self.get(0))) 518 with ops.device("CPU"): 519 saver._file_prefix_placeholder = constant_op.constant(variables_path) 520 if self._save_options.allow_partial_checkpoint: 521 load_status = saver.restore(variables_path, 522 self._checkpoint_options).expect_partial() 523 load_status.assert_nontrivial_match() 524 else: 525 load_status = saver.restore(variables_path, self._checkpoint_options) 526 load_status.assert_existing_objects_matched() 527 ckpt = load_status._checkpoint 528 529 if not context.executing_eagerly(): 530 # When running in eager mode, the `restore` call above has already run and 531 # restored the state of trackables, and calling `position.restore_ops()` 532 # would re-run the restore. In graph mode, that will return a cached list 533 # of ops that must run to restore the object on that position. We have to 534 # wire them in the initializers of the objects so that they get 535 # initialized properly when using common practices (e.g. the ones used by 536 # ManagedSession) without further user action. 537 for object_id, obj in dict(ckpt.object_by_proto_id).items(): 538 position = restore.CheckpointPosition(checkpoint=ckpt, 539 proto_id=object_id) 540 registered_saver = position.get_registered_saver_name() 541 if registered_saver: 542 raise NotImplementedError( 543 "Loading a SavedModel that uses registered checkpoint saver is " 544 f"not supported in graph mode. The loaded object {obj} uses the " 545 f"saver registered with the name {registered_saver}.") 546 547 restore_ops = position.restore_ops() 548 if restore_ops: 549 if resource_variable_ops.is_resource_variable(obj): 550 if len(restore_ops) == 1: 551 obj._initializer_op = restore_ops[0] 552 else: 553 obj._initializer_op = control_flow_ops.group(*restore_ops) 554 elif isinstance(obj, lookup_ops.LookupInterface): 555 # We don't need to check for eager execution here, since this code 556 # path should only be taken if we are restoring in graph mode. 557 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops) 558 else: 559 raise NotImplementedError( 560 f"Unable to restore state of object {obj} from the checkpoint.") 561 562 def adjust_debug_info_func_names(self, debug_info): 563 """Rewrite func names in the debug info by using the concrete func names.""" 564 output_debug_info = graph_debug_info_pb2.GraphDebugInfo() 565 output_debug_info.files[:] = debug_info.files 566 for key in debug_info.traces: 567 node, func = key.split("@") 568 new_func = "" 569 if func in self._concrete_functions: 570 new_func = self._concrete_functions[func].function_def.signature.name 571 output_debug_info.traces[node + "@" + new_func].CopyFrom( 572 debug_info.traces[key]) 573 return output_debug_info 574 575 def get(self, node_id): 576 if isinstance(node_id, str): 577 node_id = self._node_path_to_id[node_id] 578 return self._nodes[node_id] 579 580 def _recreate(self, proto, node_id, nodes): 581 """Creates a Python object from a SavedObject protocol buffer. 582 583 Args: 584 proto: a SavedObject proto 585 node_id: int, the index of this object in the SavedObjectGraph node list. 586 nodes: dict mapping int node_ids -> created objects. 587 588 Returns: 589 The recreated object, and the set-attribute function for reconnecting 590 the trackable children. 591 """ 592 registered_class = registration.get_registered_class(proto.registered_name) 593 if registered_class is None: 594 registered_class = _BUILT_IN_REGISTRATIONS.get(proto.WhichOneof("kind")) 595 596 dependencies = {} 597 for key, dep_node_id in self._get_node_dependencies(proto).items(): 598 dependencies[key] = nodes[dep_node_id] 599 600 if registered_class: 601 obj = registered_class._deserialize_from_proto( # pylint: disable=protected-access 602 proto=proto.serialized_user_proto, 603 object_proto=proto, 604 dependencies=dependencies, 605 export_dir=self._export_dir, 606 asset_file_def=self._asset_file_def, 607 operation_attributes=self._operation_attributes) 608 if isinstance(obj, base.Trackable): 609 setter = type(obj)._add_trackable_child # pylint: disable=protected-access 610 else: 611 # Returned object may be non-Trackable (e.g. when restoring captures). 612 setter = setattr 613 return obj, setter 614 else: 615 return self._recreate_default(proto, node_id, dependencies) 616 617 def _recreate_default(self, proto, node_id, deps): 618 """Creates a Python object from a SavedObject protocol buffer.""" 619 factory = { 620 "user_object": ( 621 lambda: self._recreate_user_object(proto.user_object, node_id)), 622 "function": lambda: self._recreate_function(proto.function, deps), 623 "bare_concrete_function": functools.partial( 624 self._recreate_bare_concrete_function, 625 proto=proto.bare_concrete_function, dependencies=deps), 626 "variable": lambda: self._recreate_variable(proto.variable), 627 "captured_tensor": functools.partial( 628 self._get_tensor_from_fn, proto.captured_tensor), 629 } 630 kind = proto.WhichOneof("kind") 631 if kind not in factory: 632 raise ValueError(f"Unknown SavedObject type: {kind}. Expected one of " 633 f"{list(factory.keys())}.") 634 return factory[kind]() 635 636 def _recreate_user_object(self, proto, node_id): 637 """Instantiates a SavedUserObject.""" 638 looked_up = revived_types.deserialize(proto) 639 if looked_up is None: 640 return self._recreate_base_user_object(proto, node_id) 641 return looked_up 642 643 def _recreate_base_user_object(self, proto=None, node_id=None): 644 del proto, node_id 645 # Note: each user object has its own class. This allows making each one 646 # individually callable by adding a `__call__` method to the classes of 647 # the objects instances that have a `__call__` property. 648 649 class _UserObject(autotrackable.AutoTrackable): 650 pass 651 652 return _UserObject(), setattr 653 654 def _recreate_function(self, proto, dependencies): 655 fn = function_deserialization.recreate_function( 656 proto, self._concrete_functions) 657 for name in proto.concrete_functions: 658 self._setup_function_captures(name, dependencies) 659 return fn, setattr 660 661 def _recreate_bare_concrete_function(self, proto, dependencies): 662 fn = function_deserialization.setup_bare_concrete_function( 663 proto, self._concrete_functions) 664 self._setup_function_captures(proto.concrete_function_name, dependencies) 665 return fn, setattr 666 667 def _recreate_variable(self, proto): 668 name = proto.name if proto.name else None 669 if name is not None: 670 dbg_name = name 671 else: 672 dbg_name = "<variable loaded from saved model>" 673 synchronization, aggregation, trainable = ( 674 variables.validate_synchronization_aggregation_trainable( 675 proto.synchronization, proto.aggregation, proto.trainable, 676 name=dbg_name)) 677 678 def uninitialized_variable_creator(next_creator, **kwargs): 679 """A variable creator that creates uninitialized variables.""" 680 del next_creator 681 return resource_variable_ops.UninitializedVariable(**kwargs) 682 683 # Create a variable_creator_scope that creates uninitialized variables with 684 # a lower priority such that a potential distributed variable_creator_scope 685 # can take precedence. 686 with ops.get_default_graph()._variable_creator_scope( # pylint: disable=protected-access 687 uninitialized_variable_creator, 688 priority=50): 689 saved_device = proto.device 690 load_with_device = ( 691 self._save_options.experimental_variable_policy 692 ._save_variable_devices() and config.get_soft_device_placement() and 693 saved_device) 694 if load_with_device: 695 with ops.device(saved_device): 696 return variables.Variable( 697 shape=proto.shape, 698 dtype=proto.dtype, 699 name=name, 700 trainable=trainable, 701 synchronization=synchronization, 702 aggregation=aggregation), setattr 703 else: 704 return variables.Variable( 705 shape=proto.shape, 706 dtype=proto.dtype, 707 name=name, 708 trainable=trainable, 709 synchronization=synchronization, 710 aggregation=aggregation), setattr 711 712 def _get_tensor_from_fn(self, proto): 713 outer_graph = self._concrete_functions[proto.concrete_function].graph 714 captured_tensor = outer_graph.get_tensor_by_name(proto.name) 715 return captured_tensor, setattr 716 717 718def _call_attribute(instance, *args, **kwargs): 719 return instance.__call__(*args, **kwargs) 720 721 722@tf_export("saved_model.load", v1=["saved_model.load_v2"]) 723def load(export_dir, tags=None, options=None): 724 """Load a SavedModel from `export_dir`. 725 726 Signatures associated with the SavedModel are available as functions: 727 728 ```python 729 imported = tf.saved_model.load(path) 730 f = imported.signatures["serving_default"] 731 print(f(x=tf.constant([[1.]]))) 732 ``` 733 734 Objects exported with `tf.saved_model.save` additionally have trackable 735 objects and functions assigned to attributes: 736 737 ```python 738 exported = tf.train.Checkpoint(v=tf.Variable(3.)) 739 exported.f = tf.function( 740 lambda x: exported.v * x, 741 input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 742 tf.saved_model.save(exported, path) 743 imported = tf.saved_model.load(path) 744 assert 3. == imported.v.numpy() 745 assert 6. == imported.f(x=tf.constant(2.)).numpy() 746 ``` 747 748 _Loading Keras models_ 749 750 Keras models are trackable, so they can be saved to SavedModel. The object 751 returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have 752 `.fit`, `.predict`, etc. methods). A few attributes and functions are still 753 available: `.variables`, `.trainable_variables` and `.__call__`. 754 755 ```python 756 model = tf.keras.Model(...) 757 tf.saved_model.save(model, path) 758 imported = tf.saved_model.load(path) 759 outputs = imported(inputs) 760 ``` 761 762 Use `tf.keras.models.load_model` to restore the Keras model. 763 764 _Importing SavedModels from TensorFlow 1.x_ 765 766 SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat 767 graph instead of `tf.function` objects. These SavedModels will be loaded with 768 the following attributes: 769 770 * `.signatures`: A dictionary mapping signature names to functions. 771 * `.prune(feeds, fetches) `: A method which allows you to extract 772 functions for new subgraphs. This is equivalent to importing the SavedModel 773 and naming feeds and fetches in a Session from TensorFlow 1.x. 774 775 ```python 776 imported = tf.saved_model.load(path_to_v1_saved_model) 777 pruned = imported.prune("x:0", "out:0") 778 pruned(tf.ones([])) 779 ``` 780 781 See `tf.compat.v1.wrap_function` for details. 782 * `.variables`: A list of imported variables. 783 * `.graph`: The whole imported graph. 784 * `.restore(save_path)`: A function that restores variables from a checkpoint 785 saved from `tf.compat.v1.Saver`. 786 787 _Consuming SavedModels asynchronously_ 788 789 When consuming SavedModels asynchronously (the producer is a separate 790 process), the SavedModel directory will appear before all files have been 791 written, and `tf.saved_model.load` will fail if pointed at an incomplete 792 SavedModel. Rather than checking for the directory, check for 793 "saved_model_dir/saved_model.pb". This file is written atomically as the last 794 `tf.saved_model.save` file operation. 795 796 Args: 797 export_dir: The SavedModel directory to load from. 798 tags: A tag or sequence of tags identifying the MetaGraph to load. Optional 799 if the SavedModel contains a single MetaGraph, as for those exported from 800 `tf.saved_model.save`. 801 options: `tf.saved_model.LoadOptions` object that specifies options for 802 loading. 803 804 Returns: 805 A trackable object with a `signatures` attribute mapping from signature 806 keys to functions. If the SavedModel was exported by `tf.saved_model.save`, 807 it also points to trackable objects, functions, debug info which it has been 808 saved. 809 810 Raises: 811 ValueError: If `tags` don't match a MetaGraph in the SavedModel. 812 """ 813 if isinstance(export_dir, os.PathLike): 814 export_dir = os.fspath(export_dir) 815 result = load_partial(export_dir, None, tags, options)["root"] 816 return result 817 818 819@tf_export("__internal__.saved_model.load_partial", v1=[]) 820def load_partial(export_dir, filters, tags=None, options=None): 821 """Partially load a SavedModel (saved from V2). 822 823 Similar to `tf.saved_model.load`, but with an additional argument that 824 lets you specify which nodes to load. 825 `tf.saved_model.load_partial(export_dir, ["root"])` and 826 `tf.saved_model.load(export_dir)` are equivalent. 827 828 Note: This only works for SavedModels saved with TensorFlow V2 from 829 `tf.saved_model.save` or Keras. This will not load SavedModels save from 830 the Estimator API. 831 832 In Tensorflow V2, SavedModel stores the **object graph** of the saved object. 833 The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras 834 layers, etc.) and edges that are the name of the attributes connecting the 835 objects. 836 837 *Example 1* 838 839 ``` 840 model = tf.Module() 841 model.child_layer = tf.Module() 842 model.child_layer.v = tf.Variable(5.) 843 tf.saved_model.save(model, '/tmp/model') 844 loaded = tf.__internal__.saved_model.load_partial( 845 ... '/tmp/model', 846 ... ['root.child_layer', 'root.child_layer.v']) 847 loaded['root.child_layer'].v.numpy() 848 5. 849 loaded['root.child_layer'].v is loaded['root.child_layer.v'] 850 True 851 852 *Example 2* 853 model = tf.Module() 854 model.child_layer = tf.Module() 855 model.child_layer.v = tf.Variable(5.) 856 >>> 857 tf.saved_model.save(model, '/tmp/model') 858 # Create a variable 859 new_variable = tf.Variable(0.) 860 loaded = tf.__internal__.saved_model.load_partial( 861 ... '/tmp/model', 862 ... {'root.child_layer': None, 'root.child_layer.v': new_variable}) 863 loaded['root.child_layer'].v.numpy() 864 5. 865 new_variable.numpy() 866 5. 867 ``` 868 869 **Loading under different distribution strategies** 870 You can load different parts of the model under different distribution 871 strategies. Note that this is very experimental so use with care. 872 873 ``` 874 model = tf.Module() 875 model.layer_1 = tf.Module() 876 model.layer_1.v = tf.Variable(5.) 877 model.layer_2 = tf.Module() 878 model.layer_2.v = tf.Variable(7.) 879 tf.saved_model.save(model, '/tmp/model') 880 # Load with no strategy 881 loaded = tf.__internal__.saved_model.load_partial( 882 ... '/tmp/model', 883 ... ['root.layer_1']) 884 loaded['root.layer_1'].v 885 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0> 886 strategy = tf.distribute.MirroredStrategy() 887 with strategy.scope(): 888 ... loaded2 = tf.__internal__.saved_model.load_partial( 889 ... '/tmp/model', 890 ... ['root.layer_2']) 891 loaded2['root.layer_2'].v 892 MirroredVariable:{ 893 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0> 894 } 895 ``` 896 897 Args: 898 export_dir: The SavedModel directory to load from. 899 filters: A list or dictionary where each element or key is a string 900 path to nodes that should be loaded. Node paths consist of all the child 901 attribute names to reach that node in the form: `root.{attribute_name}`. 902 The loader will load all of the specified nodes and their recursive 903 descendants. When this option is defined, the loader will return a 904 dictionary mapping the node paths to the loaded objects. 905 tags: A tag or sequence of tags identifying the MetaGraph to load. Optional 906 if the SavedModel contains a single MetaGraph, as for those exported from 907 `tf.saved_model.save`. 908 options: `tf.saved_model.LoadOptions` object that specifies options for 909 loading. 910 911 Returns: 912 A dictionary mapping node paths from the filter to loaded objects. 913 """ 914 options = options or load_options.LoadOptions() 915 if tags is not None and not isinstance(tags, set): 916 # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered 917 # sequences for nest.flatten, so we put those through as-is. 918 tags = nest.flatten(tags) 919 saved_model_proto, debug_info = ( 920 loader_impl.parse_saved_model_with_debug_info(export_dir)) 921 922 if (len(saved_model_proto.meta_graphs) == 1 and 923 saved_model_proto.meta_graphs[0].HasField("object_graph_def")): 924 metrics.IncrementReadApi(_LOAD_V2_LABEL) 925 meta_graph_def = saved_model_proto.meta_graphs[0] 926 # tensor_content field contains raw bytes in litle endian format 927 # which causes problems when loaded on big-endian systems 928 # requiring byteswap 929 if sys.byteorder == "big": 930 saved_model_utils.swap_function_tensor_content(meta_graph_def, "little", 931 "big") 932 if (tags is not None 933 and set(tags) != set(meta_graph_def.meta_info_def.tags)): 934 raise ValueError( 935 f"Got an incompatible argument to `tags`: {tags}. The SavedModel at " 936 f"{export_dir} has one MetaGraph with tags " 937 f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, " 938 "pass 'None', or pass matching tags.") 939 object_graph_proto = meta_graph_def.object_graph_def 940 941 ckpt_options = checkpoint_options.CheckpointOptions( 942 experimental_io_device=options.experimental_io_device) 943 with ops.init_scope(): 944 try: 945 loader = Loader(object_graph_proto, saved_model_proto, export_dir, 946 ckpt_options, options, filters) 947 except errors.NotFoundError as err: 948 raise FileNotFoundError( 949 str(err) + "\n You may be trying to load on a different device " 950 "from the computational device. Consider setting the " 951 "`experimental_io_device` option in `tf.saved_model.LoadOptions` " 952 "to the io_device such as '/job:localhost'.") 953 root = loader.get(0) 954 root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info) 955 root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version 956 root.tensorflow_git_version = ( 957 meta_graph_def.meta_info_def.tensorflow_git_version) 958 metrics.IncrementRead(write_version="2") 959 else: 960 if filters: 961 raise ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" 962 " version) cannot be loaded with node filters.") 963 with ops.init_scope(): 964 root = load_v1_in_v2.load(export_dir, tags) 965 root.graph_debug_info = debug_info 966 967 if filters: 968 return {node_id: loader.get(node_id) for node_id in filters} 969 else: 970 return {"root": root} 971