xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/checkpoint_view.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Manages a Checkpoint View."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import collections
17
18from tensorflow.core.protobuf import trackable_object_graph_pb2
19from tensorflow.python.checkpoint import trackable_view
20from tensorflow.python.framework import errors_impl
21from tensorflow.python.platform import tf_logging as logging
22from tensorflow.python.trackable import base
23from tensorflow.python.training import py_checkpoint_reader
24from tensorflow.python.util.tf_export import tf_export
25
26
@tf_export("train.CheckpointView", v1=[])
class CheckpointView(object):
  """Gathers and serializes a checkpoint view.

  This is for loading specific portions of a module from a
  checkpoint, and be able to compare two modules by matching components.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> ckpt = tf.train.Checkpoint(root)
  >>> save_path = ckpt.save('/tmp/tf_ckpts')
  >>> checkpoint_view = tf.train.CheckpointView(save_path)

  Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary
  of all children directly linked to the checkpoint root.

  >>> for name, node_id in checkpoint_view.children(0).items():
  ...   print(f"- name: '{name}', node_id: {node_id}")
  - name: 'a_var', node_id: 1
  - name: 'b_var', node_id: 2
  - name: 'vars', node_id: 3
  - name: 'leaf', node_id: 4
  - name: 'root', node_id: 0
  - name: 'save_counter', node_id: 5

  """

  def __init__(self, save_path):
    """Configure the checkpoint view.

    Reads the serialized object graph stored in the checkpoint under
    `base.OBJECT_GRAPH_PROTO_KEY` and parses it into a `TrackableObjectGraph`
    proto that the other methods of this class query.

    Args:
      save_path: The path to the checkpoint.

    Raises:
      ValueError: If the save_path does not lead to a TF2 checkpoint.
    """

    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
      # A missing object-graph key is reported as a ValueError, chained to the
      # reader's original NotFoundError.
      raise ValueError(
          f"The specified checkpoint \"{save_path}\" does not appear to be "
          "object-based (saved with TF2) since it is missing the key "
          f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
          "TF1 name-based saver and does not contain an object dependency graph."
      ) from not_found_error
    object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    self._object_graph_proto = object_graph_proto

  def children(self, node_id):
    """Returns the direct children of a node in the checkpoint's object graph.

    Args:
      node_id: Id of the node whose children should be returned. It is used as
        an index into the `nodes` field of the parsed object graph proto.

    Returns:
      Dictionary mapping each child's local name to that child's node_id.
    """
    return {
        child.local_name: child.node_id
        for child in self._object_graph_proto.nodes[node_id].children
    }

  def descendants(self):
    """Returns the node_ids of all nodes reachable from the checkpoint root.

    The object graph is walked breadth-first starting at node 0.

    Returns:
      List of node_ids in the order they were first discovered. The root
      (node 0) is always the first element and every node appears once.
    """

    all_nodes = [0]
    # Mirror of `all_nodes` used only for O(1) membership checks; the list is
    # kept because the discovery order is part of the returned value.
    seen = {0}
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      obj = self._object_graph_proto.nodes[node_id]
      for child in obj.children:
        child_id = child.node_id
        if child_id not in seen:
          seen.add(child_id)
          all_nodes.append(child_id)
          to_visit.append(child_id)
    return all_nodes

  def match(self, obj):
    """Returns all matching trackables between CheckpointView and Trackable.

    Matching trackables represents trackables with the same name and position in
    graph.

    Args:
      obj: `Trackable` root.

    Returns:
      Dictionary containing all overlapping trackables that maps `node_id` to
      `Trackable`.

    Raises:
      ValueError: If `obj` is not a `Trackable`.

    Example usage:

    >>> class SimpleModule(tf.Module):
    ...   def __init__(self, name=None):
    ...     super().__init__(name=name)
    ...     self.a_var = tf.Variable(5.0)
    ...     self.b_var = tf.Variable(4.0)
    ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

    >>> root = SimpleModule(name="root")
    >>> leaf = root.leaf = SimpleModule(name="leaf")
    >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
    >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
    >>> ckpt = tf.train.Checkpoint(root)
    >>> save_path = ckpt.save('/tmp/tf_ckpts')
    >>> checkpoint_view = tf.train.CheckpointView(save_path)

    >>> root2 = SimpleModule(name="root")
    >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
    >>> leaf2.leaf3 = tf.Variable(6.0)
    >>> leaf2.leaf4 = tf.Variable(7.0)

    Pass the root of the second object graph to
    `tf.train.CheckpointView.match()` to get the checkpoint nodes that have a
    counterpart in it. In this example `root2` has a child named `leaf2` but
    none named `leaf`, so the checkpoint's `leaf` node is not part of the
    result.

    >>> checkpoint_view_match = checkpoint_view.match(root2).items()
    >>> for item in checkpoint_view_match:
    ...   print(item)
    (0, ...)
    (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
    (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
    (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
    (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
    (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

    """
    if not isinstance(obj, base.Trackable):
      raise ValueError(f"Expected a Trackable, got {obj} of type {type(obj)}.")

    # Root node is always matched.
    overlapping_nodes = {0: obj}

    # Queue of tuples of node_id and trackable.
    to_visit = collections.deque([(0, obj)])
    # Checkpoint node ids whose children have already been examined.
    visited = set()
    view = trackable_view.TrackableView(obj)
    while to_visit:
      current_node_id, current_trackable = to_visit.popleft()
      trackable_children = view.children(current_trackable)
      for child_name, child_node_id in self.children(current_node_id).items():
        # Node 0 is the root, which was matched up front; fully processed nodes
        # are not examined again.
        if child_node_id in visited or child_node_id == 0:
          continue
        if child_name not in trackable_children:
          # No child with this name on the Trackable side, so nothing to match.
          continue
        child_trackable = trackable_children[child_name]
        current_assignment = overlapping_nodes.get(child_node_id)
        if current_assignment is None:
          overlapping_nodes[child_node_id] = child_trackable
          to_visit.append((child_node_id, child_trackable))
        elif current_assignment is not child_trackable:
          # The object was already mapped for this checkpoint load, which
          # means we don't need to do anything besides check that the mapping
          # is consistent (if the dependency DAG is not a tree then there are
          # multiple paths to the same object).
          logging.warning(
              "Inconsistent references when matching the checkpoint into "
              "this object graph. The referenced objects are: "
              f"({current_assignment} and "
              f"{child_trackable}).")
      visited.add(current_node_id)
    return overlapping_nodes
201