"""Manages a graph of Trackable objects."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import weakref

from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import trackable_view
from tensorflow.python.trackable import base
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(trackable_view.TrackableView):
  """Gathers and serializes an object graph."""

  def __init__(self, root, attached_dependencies=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
      attached_dependencies: List of dependencies to attach to the root object.
        Used when saving a Checkpoint with a defined root object. To avoid
        reference cycles, this should use the WeakTrackableReference class.
    """
    trackable_view.TrackableView.__init__(self, root)
    # ObjectGraphView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> ObjectGraphView -> root
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))
    self._attached_dependencies = attached_dependencies

  def __deepcopy__(self, memo):
    """Deep-copies this view so that the copy's root is the copied root.

    By default, `copy.deepcopy` does not copy weak references. It hands back
    the same weakref, so a naive deep copy would leave the new view pointing
    at the *original* root. To fix this, the root itself is copied first. The
    memo is then seeded so that copying `_root_ref` yields a weak reference to
    that copy.

    The copied view holds only a weak reference to the copied root. The
    copied root stays alive only while something else holds a strong
    reference to it.

    Args:
      memo: The memo dictionary supplied by `copy.deepcopy`.

    Returns:
      A new instance of `type(self)` whose attributes are deep copies of this
      instance's attributes.
    """
    strong_root = self._root_ref()
    if strong_root is not None:
      strong_copy = copy.deepcopy(strong_root, memo)
      # Seed the memo so the attribute loop below maps `self._root_ref` to a
      # weak reference to the copied root instead of the original weakref.
      memo[id(self._root_ref)] = weakref.ref(strong_copy)
    # super() does not have a __deepcopy__, so the instance attributes are
    # copied by hand. `copied` is registered in the memo before recursing so
    # that any reference cycle back to this view resolves to the copy.
    copied = super().__new__(type(self))
    memo[id(self)] = copied
    for key, value in vars(self).items():
      setattr(copied, key, copy.deepcopy(value, memo))
    return copied

  def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns list of all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A `base.SaveType` value ('savedmodel' or 'checkpoint').
        Defaults to `base.SaveType.CHECKPOINT`.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      A list with one `base.TrackableReference` per child that the base class
      reports for `obj`. When `obj` is the root, this view's attached
      dependencies are appended after those children, as stored.
    """
    # Use the base-class `children` explicitly. `self.children` is defined in
    # terms of `list_children`, so calling it here would recurse forever.
    direct_children = super().children(obj, save_type, **kwargs)
    children = [base.TrackableReference(name, ref)
                for name, ref in direct_children.items()]

    # GraphView objects may define children of the root object that are not
    # actually attached, e.g. a Checkpoint object's save_counter.
    if obj is self.root and self._attached_dependencies:
      children.extend(self._attached_dependencies)
    return children

  def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: Accepted to mirror the base-class `children` signature, but
        NOT forwarded. `list_children` is called without it, so its own
        default `save_type` is used.
      **kwargs: kwargs to use when retrieving the object's children; passed
        through to `list_children`.

    Returns:
      A dictionary mapping each child's name to its trackable. If two children
      share a name, the one listed last wins.
    """
    # NOTE(review): `save_type` is deliberately left unforwarded here.
    # Subclasses may override `list_children` with a narrower signature (or a
    # different default `save_type`), so forwarding it could break them.
    # Confirm against all overrides before changing this.
    return {name: ref for name, ref in self.list_children(obj, **kwargs)}

  @property
  def attached_dependencies(self):
    """Returns list of dependencies that should be saved in the checkpoint.

    These dependencies are not tracked by root, but are in the checkpoint.
    This is defined when the user creates a Checkpoint with both root and kwargs
    set.

    Returns:
      The `attached_dependencies` value given at construction: a list of
      TrackableReferences, or None (the default) if none were supplied.
    """
    return self._attached_dependencies

  @property
  def root(self):
    """The root `Trackable`, dereferenced from the stored reference.

    Raises:
      AssertionError: If the weakly referenced root has already been
        garbage-collected. The check is an `assert`, so under `python -O` this
        property returns None instead of raising.
    """
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      assert derefed is not None
      return derefed
    else:
      # `__init__` always stores a weakref. This branch only matters if
      # `_root_ref` is assigned some other way (e.g. by a subclass).
      return self._root_ref

  def breadth_first_traversal(self):
    """Public entry point for `_breadth_first_traversal`.

    Dispatches through `self._breadth_first_traversal`, so subclass overrides
    of that hook take effect.
    """
    return self._breadth_first_traversal()

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
      Whatever the base class's `_descendants_with_paths` returns.
    """
    return super()._descendants_with_paths()

  def serialize_object_graph(self, saveables_cache=None):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Args:
      saveables_cache: An optional cache storing previously created
        SaveableObjects created for each Trackable. Maps Trackables to a
        dictionary of attribute names to Trackable.

    Returns:
      A tuple of (named_variables, object_graph_proto, feed_additions):
        named_variables: A dictionary mapping names to variable objects.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving.

    Raises:
      ValueError: If there are invalid characters in an optimizer's slot names.
    """
    # The helper returns a fourth value (registered savers), which this legacy
    # API discards to keep its three-element return shape.
    named_saveable_objects, object_graph_proto, feed_additions, _ = (
        save_util_v1.serialize_object_graph_with_registered_savers(
            self, saveables_cache))
    return named_saveable_objects, object_graph_proto, feed_additions

  def frozen_saveable_objects(self,
                              object_map=None,
                              to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    Args:
      object_map: Passed through unchanged to
        `save_util_v1.frozen_saveables_and_savers`.
      to_graph: Passed through unchanged to
        `save_util_v1.frozen_saveables_and_savers`.
      call_with_mapped_captures: Passed through unchanged to
        `save_util_v1.frozen_saveables_and_savers`.

    Returns:
      The first element of the value returned by
      `save_util_v1.frozen_saveables_and_savers`; the remaining elements are
      discarded.
    """
    return save_util_v1.frozen_saveables_and_savers(
        self, object_map, to_graph, call_with_mapped_captures)[0]