xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/graph_view.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Manages a graph of Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import copy
17import weakref
18
19from tensorflow.python.checkpoint import save_util_v1
20from tensorflow.python.checkpoint import trackable_view
21from tensorflow.python.trackable import base
22from tensorflow.python.util.tf_export import tf_export
23
24
@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(trackable_view.TrackableView):
  """View over the graph of `Trackable` objects reachable from a root.

  The view holds its root through a weak reference, can list the children of
  any object in the graph (optionally adding extra "attached" dependencies to
  the root), and delegates serialization of the graph to `save_util_v1`.
  """

  def __init__(self, root, attached_dependencies=None):
    """Set up the view for `root`.

    Args:
      root: The `Trackable` whose variables (and, recursively, those of its
        dependencies) should be saved. A `weakref.ref` to the object is also
        accepted and is stored as-is.
      attached_dependencies: Optional list of dependencies to report as
        children of the root even though the root does not track them. Used
        when saving a Checkpoint that has an explicit root object. These
        should use the WeakTrackableReference class to avoid reference
        cycles.
    """
    trackable_view.TrackableView.__init__(self, root)
    # Only ever keep a weak reference to the root. A strong one may close the
    # cycle
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> ObjectGraphView -> root
    if isinstance(root, weakref.ref):
      root_ref = root
    else:
      root_ref = weakref.ref(root)
    self._root_ref = root_ref
    self._attached_dependencies = attached_dependencies

  def __deepcopy__(self, memo):
    """Deep-copies the view, including the object behind the weak root ref.

    Args:
      memo: The memo dictionary supplied by `copy.deepcopy`.

    Returns:
      A new view of the same type whose attributes are deep copies of this
      view's attributes.
    """
    # deepcopy leaves weak references uncopied, which would make the copy
    # point at the original root. Instead, copy the referent first and
    # register a weak reference to that copy in `memo` under the id of the
    # current weak reference, so the attribute loop below picks it up.
    live_root = self._root_ref()
    if live_root is not None:
      root_copy = copy.deepcopy(live_root, memo)
      memo[id(self._root_ref)] = weakref.ref(root_copy)
    # The parent class offers no __deepcopy__ to delegate to, so build the
    # new instance by hand and copy every instance attribute across.
    clone = super().__new__(type(self))
    memo[id(self)] = clone
    for attr_name, attr_value in self.__dict__.items():
      setattr(clone, attr_name, copy.deepcopy(attr_value, memo))
    return clone

  def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Lists the child trackables of `obj` as references.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: Extra keyword arguments forwarded when retrieving the
        object's children.

    Returns:
      A list with one `TrackableReference` per child reported by the parent
      view, followed by the attached dependencies when `obj` is the root.
    """
    tracked = super().children(obj, save_type, **kwargs)
    references = [
        base.TrackableReference(child_name, child)
        for child_name, child in tracked.items()
    ]

    # A view may declare children of the root that the root itself does not
    # track, e.g. a Checkpoint object's save_counter.
    if obj is self.root and self._attached_dependencies:
      references.extend(self._attached_dependencies)
    return references

  def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the child trackables of `obj` keyed by name.

    Args:
      obj: A `Trackable` object.
      save_type: A string, can be 'savedmodel' or 'checkpoint'.
      **kwargs: Extra keyword arguments forwarded to `list_children`.

    Returns:
      Dictionary mapping each child's name to the child trackable.
    """
    # NOTE(review): `save_type` is accepted here but not forwarded to
    # `list_children`, which therefore uses its own default. Confirm whether
    # this is intentional before changing it.
    return {
        child_name: child
        for child_name, child in self.list_children(obj, **kwargs)
    }

  @property
  def attached_dependencies(self):
    """The extra dependencies saved alongside the root's own children.

    The root does not track these dependencies, yet they are part of the
    checkpoint. They are set when the user creates a Checkpoint with both root
    and kwargs.

    Returns:
      A list of TrackableReferences.
    """
    return self._attached_dependencies

  @property
  def root(self):
    """The root object, dereferenced if it is held via a weak reference."""
    root_ref = self._root_ref
    if not isinstance(root_ref, weakref.ref):
      return root_ref
    target = root_ref()
    # The referent is expected to still be alive whenever the view is used.
    assert target is not None
    return target

  def breadth_first_traversal(self):
    """Public entry point for `_breadth_first_traversal`."""
    return self._breadth_first_traversal()

  def _breadth_first_traversal(self):
    """Finds shortest paths to every dependency of `self.root`.

    Returns:
      Whatever the parent view's `_descendants_with_paths()` returns.
    """
    return super()._descendants_with_paths()

  def serialize_object_graph(self, saveables_cache=None):
    """Computes checkpoint keys for variables and serializes the object graph.

    A non-slot variable is keyed by a shortest path from the root saveable to
    the object that owns it (i.e. the one which called
    `Trackable._add_variable` to create it).

    A slot variable is keyed by a shortest path to the variable it is a slot
    for, a shortest path to its optimizer, and the slot name.

    Args:
      saveables_cache: An optional cache storing previously created
        SaveableObjects created for each Trackable. Maps Trackables to a
        dictionary of attribute names to Trackable.

    Returns:
      A tuple of (named_variables, object_graph_proto, feed_additions):
        named_variables: A dictionary mapping names to variable objects.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving.

    Raises:
      ValueError: If there are invalid characters in an optimizer's slot names.
    """
    serialized = save_util_v1.serialize_object_graph_with_registered_savers(
        self, saveables_cache)
    # The helper returns four values; the last one is not part of this
    # method's result.
    saveables, graph_proto, feeds, _ = serialized
    return saveables, graph_proto, feeds

  def frozen_saveable_objects(self,
                              object_map=None,
                              to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    The arguments are passed through unchanged to
    `save_util_v1.frozen_saveables_and_savers`.

    Returns:
      The first element of the helper's result.
    """
    saveables_and_savers = save_util_v1.frozen_saveables_and_savers(
        self, object_map, to_graph, call_with_mapped_captures)
    return saveables_and_savers[0]
168