# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""

import collections
import functools
import os
import sys

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint import graph_view
from tensorflow.python.checkpoint import restore
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import function_saved_model_utils
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base
from tensorflow.python.trackable import data_structures
from tensorflow.python.trackable import resource
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics.
_LOAD_V2_LABEL = "load_v2"
# Built-in registrations use the "oneof kind" field in the SavedObject proto,
# instead of the "registered_name" field. The "kind" field has almost the same
# functionality as registered_name, but only covers built-in TensorFlow
# types (such as variables, functions, and assets).
_BUILT_IN_REGISTRATIONS = {
    "asset": asset.Asset,
    "resource": resource.RestoredResource,
    "constant": function_saved_model_utils.TrackableConstant}


def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed."""
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")
  save_error_message = (
      "It seems that you are trying to save a "
      "tf.types.experimental.ConcreteFunction that involves a distributed "
      "model, and the model contains parts that are loaded from a SavedModel. "
      "Saving such a tf.types.experimental.ConcreteFunction is not supported. "
      "Try saving a tf.function with input_signature instead, and file a bug "
      "if there are still issues.")

  assert_op = control_flow_ops.Assert(
      array_ops.placeholder_with_default(False, shape=()), [error_message])
  if (not context.executing_eagerly()
     ) and ops.get_default_graph().building_function:
    ops.get_default_graph().mark_as_unsaveable(save_error_message)

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)


class _WrapperFunction(function.ConcreteFunction):
  """A class that wraps a concrete function to handle distributed contexts.

  The reason for wrapping a concrete function is that the `_captured_inputs`
  fields used in the in-replica and cross-replica contexts are different. When
  `load()` is called from within a tf.distribute.strategy scope, the captured
  inputs are distributed variables. When those distributed variables are used
  to call the function, we need different approaches depending on whether we
  are in-replica or not. In-replica, we naturally use the corresponding
  component of the distributed variable; outside a replica, calling the
  function means constructing a graph that is not actually going to be used.
  A typical use case is constructing a functional model. In that case, we
  return a placeholder with a control dependency to ensure that it is never
  accessed.
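
  For example (an illustrative sketch; assumes a SavedModel at `/tmp/model`
  whose functions capture variables):

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      loaded = tf.saved_model.load("/tmp/model")
    # In-replica (e.g. inside `strategy.run`), a captured distributed variable
    # resolves to the local replica's component handle. In the cross-replica
    # context, captures are replaced with `_unused_handle()` placeholders.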
110  """
111
112  def __init__(self, concrete_function):
113    # Shallow copy the concrete_function
114    self.__dict__.update(vars(concrete_function))
115
  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_handle(x):
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_unused_handle(x):
      return (_unused_handle()
              if distribute_utils.is_distributed_variable(x) else x)

    if (ds_context.get_replica_context() is not None or
        values_util.is_saving_non_distributed()):
      # If we're in the replica context or are saving a non-distributed version
      # of the model, we resolve the captured variables to the corresponding
      # resource handle. In both situations we call var.handle, but it has
      # different behavior. In the replica context, var.handle resolves to the
      # replica-local variable handle if the variable is replicated. When
      # saving a non-distributed version of the model, var.handle resolves to
      # the primary variable handle, since we only save one copy of a
      # replicated variable.
      captured_inputs = list(map(get_handle, captured_inputs))
    else:  # cross-replica context
      captured_inputs = list(map(get_unused_handle, captured_inputs))
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
                                                    cancellation_manager)


class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, save_options, filters):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            library=meta_graph.graph_def.library,
            saved_object_graph=self._proto,
            wrapper_function=_WrapperFunction))
    # Store a set of all concrete functions that have been set up with
    # captures.
    self._restored_concrete_functions = set()
    self._checkpoint_options = ckpt_options
    self._save_options = save_options

    self._pretty_printer = checkpoint.ObjectGraphProtoPrettyPrinter(self._proto)

    # Stores user-defined node_filters argument.
    self._node_filters = filters
    # Stores map of string paths to integers.
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs
      # to the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked children.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get a list of all integer node ids to load, or None if all nodes should
    # be loaded. This list includes ids of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    # Order all nodes or filtered nodes using the dependencies.
    self._ordered_node_ids = self._generate_ordered_node_ids()

    self._load_all()

    if not save_options.experimental_skip_checkpoint:
      self._restore_checkpoint()
    for node in self._nodes:
      if isinstance(node, resource.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _convert_node_paths_to_ints(self):
    """Maps all string node paths in node_filters to the int node ids."""
    if self._node_filters is None:
      return None
    path_to_int = {}
    for node_id in self._node_filters:
      int_node_id = None
      if isinstance(node_id, str):
        node_path = node_id.split(".")
        if node_path[0] != "root":
          raise ValueError(
              "When passing string identifiers to node_filters, the first name"
              f" must be root. Received {node_path[0]}.")
        int_node_id = 0
        for n, name in enumerate(node_path[1:]):
          int_node_id = self._find_node_child(
              int_node_id, name, ".".join(node_path[:n+2]))
        path_to_int[node_id] = int_node_id
      else:
        raise TypeError("Elements in node_filters must be strings.")
    return path_to_int

  def _retrieve_all_filtered_nodes(self):
    """Traverses through the object graph to get the IDs of all nodes to load.

    As a side-effect, if node_filters is a dictionary that contains already-
    created objects, then the children tracked by those objects will be
    added to node_filters.

    Returns:
      List of all nodes to load, or None if all nodes should be loaded.
    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to nodes_to_load."
              f" Object at {node_path} is expected to be a checkpointable "
              "(i.e. 'trackable') TensorFlow object (e.g. tf.Variable, "
              "tf.Module or Keras layer).")
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if node already tracks the child reference, in which case add the
        # child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            setter = lambda *args: None

            self._loaded_nodes[reference.node_id] = (child_object, setter)

        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes

  def _find_node_child(self, node_id, child_name, path):
    for reference in self._proto.nodes[node_id].children:
      if reference.local_name == child_name:
        return reference.node_id
    raise ValueError(f"Unable to find node {path}.")

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()

    # Set up concrete functions that aren't part of the object graph
    # (e.g. gradient functions)
    self._setup_remaining_functions()
    self._load_checkpoint_save_and_restore_functions()

  def _load_checkpoint_save_and_restore_functions(self):
    """Restores the checkpoint-related save/restore functions to all nodes."""
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      if proto.saveable_objects.keys() == {
          trackable_utils.SERIALIZE_TO_TENSORS_NAME}:
        # Restore Trackable serialize- and restore-from-tensor functions.
        assert len(proto.saveable_objects) == 1
        saveable_object_proto = next(iter(proto.saveable_objects.values()))
        save_fn_id = saveable_object_proto.save_function
        restore_fn_id = saveable_object_proto.restore_function
        node._serialize_to_tensors = self.get(save_fn_id)  # pylint: disable=protected-access
        node._restore_from_tensors = self.get(restore_fn_id)  # pylint: disable=protected-access
      else:
        # Restore legacy SaveableObject functions.
        saveable_fn_by_name = {}
        for name, saveable_object_proto in proto.saveable_objects.items():
          save_fn_id = saveable_object_proto.save_function
          restore_fn_id = saveable_object_proto.restore_function
          saveable_fn_by_name[name] = (self.get(save_fn_id),
                                       self.get(restore_fn_id))

        node._self_saveable_object_factories = (  # pylint: disable=protected-access
            saveable_object_util.recreate_saveable_objects(saveable_fn_by_name))

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)

  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children."""
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__`, add a class method
      # that allows `obj()` syntax to work. This is done per-instance to
      # allow `callable` to be used to find out if an object is callable.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_remaining_functions(self):
    concrete_function_names = sorted(self._proto.concrete_functions.keys())
    for name in concrete_function_names:
      if name in self._restored_concrete_functions:
        continue
      self._setup_function_captures(name, self._nodes)

  def _setup_function_captures(self, concrete_function_name, nodes):
    """Sets up captures and variables in a restored function."""
    if concrete_function_name in self._restored_concrete_functions:
      return
    self._restored_concrete_functions.add(concrete_function_name)
    concrete_function = self._concrete_functions[concrete_function_name]
    proto = self._proto.concrete_functions[concrete_function_name]
    inputs = [nodes[node_id] for node_id in proto.bound_inputs]
    function_saved_model_utils.restore_captures(concrete_function, inputs)

  def _initialize_loaded_nodes(self):
    nodes = {}
    node_setters = {}
    for node_id, (node, setter) in self._loaded_nodes.items():
      nodes[node_id] = node
      node_setters[node_id] = setter
    return nodes, node_setters

  def _get_node_dependencies(self, proto):
    """Returns a dictionary of all dependencies of an object.

    Args:
      proto: A SavedObject proto.

    Returns:
      Dict mapping string dependency name *or* int node id to the node id.
      The int node id key is used for mapping function captures.
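
      For example (illustrative, with hypothetical ids): a function object
      with a named dependency `"variables" -> 3` that also captures node 12
      would yield `{"variables": 3, 12: 12}`.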
386    """
387    dependencies = {ref.local_name: ref.node_id for ref in proto.dependencies}
388    kind = proto.WhichOneof("kind")
389    if kind == "function":
390      concrete_functions = proto.function.concrete_functions
391      for fn_name in concrete_functions:
392        for bound_input in self._proto.concrete_functions[fn_name].bound_inputs:
393          dependencies[bound_input] = bound_input
394    elif kind == "bare_concrete_function":
395      fn_name = proto.bare_concrete_function.concrete_function_name
396      for bound_input in self._proto.concrete_functions[fn_name].bound_inputs:
397        dependencies[bound_input] = bound_input
398    elif kind == "resource":
399      # Make sure that the resource creator is listed as a dependency.
400      for child in proto.children:
401        if child.local_name == "_create_resource":
402          dependencies["_create_resource"] = child.node_id
403    return dependencies
404
  def _generate_ordered_node_ids(self):
    """Orders the node ids so that dependencies appear first."""
    if self._filtered_nodes is None:
      unordered_ids = range(len(self._proto.nodes))
    else:
      unordered_ids = list(self._filtered_nodes)

    # Maps node ids -> list of dependencies (ids of other nodes that must be
    # loaded before it).
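    # Illustrative shape (hypothetical ids): {3: [], 5: [3], 7: [3, 5]} means
    # node 5 must be created after node 3, and node 7 after nodes 3 and 5.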
    dependency_map = collections.defaultdict(list)
    for node_id in unordered_ids:
      deps = dependency_map[node_id]
      if self._loaded_nodes.get(node_id) is not None:
        # Deps are only used if the node has not been created.
        continue
      proto = self._proto.nodes[node_id]
      for dep in set(self._get_node_dependencies(proto).values()):
        deps.append(dep)
        if self._filtered_nodes is not None and dep not in self._filtered_nodes:
          raise ValueError(
              "Unable to partially load SavedModel since the specified filter "
              "does not include all required objects for loading (e.g. "
              "variables used in functions or deserialization dependencies). "
              "Please include this path in the filter: "
              f"{self._pretty_printer.node_names[dep]}")

      # Add optimizer slot variable to dependency map.
      prev_slot = None
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_id = slot_variable_proto.slot_variable_node_id
        # The optimizer and original variable must be created before the slot
        # variable, since the slot variable is generated using the Optimizer's
        # add_slot API.
        slot_deps = dependency_map[slot_variable_node_id]
        slot_deps.append(node_id)
        slot_deps.append(slot_variable_proto.original_variable_node_id)

        if prev_slot is not None:
          # Add previous slot to deps so that the optimizer slot variables are
          # added in order. The ordering is needed because the slot name and
          # variable are both added to ordered lists, which are exposed to the
          # user via `Optimizer.get_slot_names()` and `Optimizer.weights`.
          # TODO(kathywu): Maybe enforce some sort of deterministic ordering in
          # `order_by_dependency` to avoid doing this?
          slot_deps.append(prev_slot)
        prev_slot = slot_variable_node_id
    try:
      return list(trackable_utils.order_by_dependency(dependency_map))
    except trackable_utils.CyclicDependencyError:
      # This should not happen since there is already a validation for cycles
      # when saving, but raise an error just in case.
      raise ValueError(
          "Encountered a cycle in the deserialization dependencies in the "
          "SavedModel. This is extremely unexpected, please file a bug and "
          "make sure you are not manually modifying the SavedModel.")

  def _iter_all_nodes(self):
    for node_id in self._ordered_node_ids:
      yield node_id, self._proto.nodes[node_id]

  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting children.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    # Maps slot node id -> optimizer node id, SlotVariableReference proto
    slot_variable_node_ids = {}

    for node_id, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_id = slot_variable_proto.slot_variable_node_id
        slot_variable_node_ids[slot_variable_node_id] = (node_id,
                                                         slot_variable_proto)

    # Re-create everything.
    for node_id, proto in self._iter_all_nodes():
      if nodes.get(node_id) is not None:
        continue
      elif node_id in slot_variable_node_ids:
        # Use the public Optimizer interface when creating slot variables.
        optimizer_node_id, slot_variable_proto = slot_variable_node_ids[node_id]
        optimizer_object = nodes[optimizer_node_id]
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr
      else:
        node, setter = self._recreate(proto, node_id, nodes)
        nodes[node_id] = node
        node_setters[node_id] = setter

    # If root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(b/205010730): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = checkpoint.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._save_options.allow_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
      load_status.assert_nontrivial_match()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
      load_status.assert_existing_objects_matched()
    ckpt = load_status._checkpoint

    if not context.executing_eagerly():
      # In eager mode, the `restore` call above has already run and restored
      # the state of trackables, and calling `position.restore_ops()` would
      # re-run the restore. In graph mode, that call returns a cached list of
      # ops that must run to restore the object on that position. We have to
      # wire them into the initializers of the objects so that they get
      # initialized properly when using common practices (e.g. the ones used
      # by ManagedSession) without further user action.
      for object_id, obj in dict(ckpt.object_by_proto_id).items():
        position = restore.CheckpointPosition(checkpoint=ckpt,
                                              proto_id=object_id)
        registered_saver = position.get_registered_saver_name()
        if registered_saver:
          raise NotImplementedError(
              "Loading a SavedModel that uses a registered checkpoint saver "
              f"is not supported in graph mode. The loaded object {obj} uses "
              f"the saver registered with the name {registered_saver}.")

        restore_ops = position.restore_ops()
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif isinstance(obj, lookup_ops.LookupInterface):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
          else:
            raise NotImplementedError(
                f"Unable to restore state of object {obj} from the checkpoint.")

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrites func names in the debug info using the concrete func names."""
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info

  def get(self, node_id):
    if isinstance(node_id, str):
      node_id = self._node_path_to_id[node_id]
    return self._nodes[node_id]

  def _recreate(self, proto, node_id, nodes):
    """Creates a Python object from a SavedObject protocol buffer.

    Args:
      proto: a SavedObject proto
      node_id: int, the index of this object in the SavedObjectGraph node list.
      nodes: dict mapping int node_ids -> created objects.

    Returns:
      The recreated object, and the set-attribute function for reconnecting
      the trackable children.
    """
    registered_class = registration.get_registered_class(proto.registered_name)
    if registered_class is None:
      registered_class = _BUILT_IN_REGISTRATIONS.get(proto.WhichOneof("kind"))

    dependencies = {}
    for key, dep_node_id in self._get_node_dependencies(proto).items():
      dependencies[key] = nodes[dep_node_id]

    if registered_class:
      obj = registered_class._deserialize_from_proto(  # pylint: disable=protected-access
          proto=proto.serialized_user_proto,
          object_proto=proto,
          dependencies=dependencies,
          export_dir=self._export_dir,
          asset_file_def=self._asset_file_def,
          operation_attributes=self._operation_attributes)
      if isinstance(obj, base.Trackable):
        setter = type(obj)._add_trackable_child  # pylint: disable=protected-access
      else:
        # Returned object may be non-Trackable (e.g. when restoring captures).
        setter = setattr
      return obj, setter
    else:
      return self._recreate_default(proto, node_id, dependencies)

  def _recreate_default(self, proto, node_id, deps):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "function": lambda: self._recreate_function(proto.function, deps),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto=proto.bare_concrete_function, dependencies=deps),
        "variable": lambda: self._recreate_variable(proto.variable),
        "captured_tensor": functools.partial(
            self._get_tensor_from_fn, proto.captured_tensor),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError(f"Unknown SavedObject type: {kind}. Expected one of "
                       f"{list(factory.keys())}.")
    return factory[kind]()

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto=None, node_id=None):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the classes of
    # those object instances that have a `__call__` property.

    class _UserObject(autotrackable.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_function(self, proto, dependencies):
    fn = function_deserialization.recreate_function(
        proto, self._concrete_functions)
    for name in proto.concrete_functions:
      self._setup_function_captures(name, dependencies)
    return fn, setattr

  def _recreate_bare_concrete_function(self, proto, dependencies):
    fn = function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions)
    self._setup_function_captures(proto.concrete_function_name, dependencies)
    return fn, setattr

  def _recreate_variable(self, proto):
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables with
    # a lower priority such that a potential distributed variable_creator_scope
    # can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      saved_device = proto.device
      load_with_device = (
          self._save_options.experimental_variable_policy
          ._save_variable_devices() and config.get_soft_device_placement() and
          saved_device)
      if load_with_device:
        with ops.device(saved_device):
          return variables.Variable(
              shape=proto.shape,
              dtype=proto.dtype,
              name=name,
              trainable=trainable,
              synchronization=synchronization,
              aggregation=aggregation), setattr
      else:
        return variables.Variable(
            shape=proto.shape,
            dtype=proto.dtype,
            name=name,
            trainable=trainable,
            synchronization=synchronization,
            aggregation=aggregation), setattr

  def _get_tensor_from_fn(self, proto):
    outer_graph = self._concrete_functions[proto.concrete_function].graph
    captured_tensor = outer_graph.get_tensor_by_name(proto.name)
    return captured_tensor, setattr


def _call_attribute(instance, *args, **kwargs):
  return instance.__call__(*args, **kwargs)


@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. The object
  returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
  `.fit`, `.predict`, etc. methods). A few attributes and functions are still
  available: `.variables`, `.trainable_variables` and `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded with
  the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches)`: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the SavedModel
    and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a checkpoint
    saved from `tf.compat.v1.Saver`.

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the last
  `tf.saved_model.save` file operation.
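
  For example, a minimal polling loop (a sketch; `saved_model_dir` stands in
  for the producer's output directory):

  ```python
  import os
  import time

  while not os.path.exists(os.path.join(saved_model_dir, "saved_model.pb")):
    time.sleep(1)
  imported = tf.saved_model.load(saved_model_dir)
  ```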

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to the trackable objects, functions, and debug info with
    which it was saved.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  if isinstance(export_dir, os.PathLike):
    export_dir = os.fspath(export_dir)
  result = load_partial(export_dir, None, tags, options)["root"]
  return result


@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the names of the attributes connecting the
  objects.

  *Example 1*

  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True

  *Example 2*

  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> new_variable.numpy()
  5.0

  **Loading under different distribution strategies**

  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental, so use with care.

  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems and therefore
    # requires a byteswap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = Loader(object_graph_proto, saved_model_proto, export_dir,
                        ckpt_options, options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.")
      root = loader.get(0)
      root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from TensorFlow 1.x or Estimator "
                       "(any version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}