"""Manages a Checkpoint View."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections

from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.checkpoint import trackable_view
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.CheckpointView", v1=[])
class CheckpointView(object):
  """Gathers and serializes a checkpoint view.

  This is useful for loading specific portions of a module from a checkpoint
  and for comparing two modules by matching their components.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> ckpt = tf.train.Checkpoint(root)
  >>> save_path = ckpt.save('/tmp/tf_ckpts')
  >>> checkpoint_view = tf.train.CheckpointView(save_path)

  Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary
  of all children directly linked to the checkpoint root.

  >>> for name, node_id in checkpoint_view.children(0).items():
  ...   print(f"- name: '{name}', node_id: {node_id}")
  - name: 'a_var', node_id: 1
  - name: 'b_var', node_id: 2
  - name: 'vars', node_id: 3
  - name: 'leaf', node_id: 4
  - name: 'root', node_id: 0
  - name: 'save_counter', node_id: 5

  """

  def __init__(self, save_path):
    """Configure the checkpoint view.

    Reads the serialized object graph stored in the checkpoint under
    `base.OBJECT_GRAPH_PROTO_KEY` and parses it into a `TrackableObjectGraph`
    proto, which all other methods query.

    Args:
      save_path: The path to the checkpoint.

    Raises:
      ValueError: If the save_path does not lead to a TF2 checkpoint.
    """

    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
      # Name-based (TF1) checkpoints have no object graph entry; surface a
      # descriptive error instead of the raw lookup failure.
      raise ValueError(
          f"The specified checkpoint \"{save_path}\" does not appear to be "
          "object-based (saved with TF2) since it is missing the key "
          f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
          "TF1 name-based saver and does not contain an object dependency graph."
      ) from not_found_error
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    self._object_graph_proto = object_graph_proto

  def children(self, node_id):
    """Returns the direct children of a node in the checkpoint object graph.

    Args:
      node_id: Id of the node whose children are returned. `0` is the
        checkpoint root.

    Returns:
      Dictionary mapping each child's local name to its node_id.
    """
    return {
        child.local_name: child.node_id
        for child in self._object_graph_proto.nodes[node_id].children
    }

  def descendants(self):
    """Returns the node ids of every node reachable from the checkpoint root.

    Returns:
      A list of node ids in breadth-first discovery order, starting with the
      root (`0`). Each id appears once, even when the object graph contains
      several paths to the same node.
    """
    all_nodes = [0]
    # `seen` mirrors `all_nodes` so that membership checks are O(1) instead of
    # a linear scan of the list; `all_nodes` alone preserves discovery order.
    seen = {0}
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      for child in self._object_graph_proto.nodes[node_id].children:
        if child.node_id not in seen:
          seen.add(child.node_id)
          all_nodes.append(child.node_id)
          to_visit.append(child.node_id)
    return all_nodes

  def match(self, obj):
    """Returns all matching trackables between CheckpointView and Trackable.

    Matching trackables are trackables reached from the root by the same
    sequence of child names in both the checkpoint and the object graph of
    `obj`.

    Args:
      obj: `Trackable` root.

    Returns:
      Dictionary containing all overlapping trackables that maps `node_id` to
      `Trackable`.

    Raises:
      ValueError: If `obj` is not a `Trackable`.

    Example usage:

    >>> class SimpleModule(tf.Module):
    ...   def __init__(self, name=None):
    ...     super().__init__(name=name)
    ...     self.a_var = tf.Variable(5.0)
    ...     self.b_var = tf.Variable(4.0)
    ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

    >>> root = SimpleModule(name="root")
    >>> leaf = root.leaf = SimpleModule(name="leaf")
    >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
    >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
    >>> ckpt = tf.train.Checkpoint(root)
    >>> save_path = ckpt.save('/tmp/tf_ckpts')
    >>> checkpoint_view = tf.train.CheckpointView(save_path)

    >>> root2 = SimpleModule(name="root")
    >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
    >>> leaf2.leaf3 = tf.Variable(6.0)
    >>> leaf2.leaf4 = tf.Variable(7.0)

    Match the checkpoint against `root2`. The checkpoint's `leaf` subtree has
    no counterpart in `root2` (which names it `leaf2`), so only the root, its
    variables and the `vars` list are matched.

    >>> checkpoint_view_match = checkpoint_view.match(root2).items()
    >>> for item in checkpoint_view_match:
    ...   print(item)
    (0, ...)
    (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
    (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
    (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
    (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
    (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

    """
    if not isinstance(obj, base.Trackable):
      raise ValueError(f"Expected a Trackable, got {obj} of type {type(obj)}.")

    # Root node is always matched.
    overlapping_nodes = {0: obj}

    # Breadth-first walk over already-matched (checkpoint node_id, trackable)
    # pairs.
    to_visit = collections.deque([(0, obj)])
    # Checkpoint node ids whose children have already been expanded.
    visited = set()
    view = trackable_view.TrackableView(obj)
    while to_visit:
      current_node_id, current_trackable = to_visit.popleft()
      trackable_children = view.children(current_trackable)
      for child_name, child_node_id in self.children(current_node_id).items():
        # Skip already-expanded nodes and edges leading back to the root,
        # which is matched to `obj` above.
        if child_node_id in visited or child_node_id == 0:
          continue
        if child_name not in trackable_children:
          continue
        child_trackable = trackable_children[child_name]
        current_assignment = overlapping_nodes.get(child_node_id)
        if current_assignment is None:
          overlapping_nodes[child_node_id] = child_trackable
          to_visit.append((child_node_id, child_trackable))
        elif current_assignment is not child_trackable:
          # The node was already mapped via another path (the dependency DAG
          # is not a tree); only a warning is issued when the two paths
          # disagree, and the first mapping is kept.
          logging.warning(
              "Inconsistent references when matching the checkpoint into "
              "this object graph. The referenced objects are: "
              f"({current_assignment} and "
              f"{child_trackable}).")
      visited.add(current_node_id)
    return overlapping_nodes