xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/checkpoint_view_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for the checkpoint view."""
16
17import os
18
19from tensorflow.python.checkpoint import checkpoint as trackable_utils
20from tensorflow.python.checkpoint import checkpoint_view
21from tensorflow.python.eager import test
22from tensorflow.python.trackable import autotrackable
23
24
class CheckpointViewTest(test.TestCase):
  """Exercises `checkpoint_view.CheckpointView` on freshly saved checkpoints."""

  def _save_root(self, root):
    """Saves `root` through a `trackable_utils.Checkpoint`.

    Args:
      root: Trackable object passed as the checkpoint's `root`.

    Returns:
      The path returned by `Checkpoint.save` for the "root_ckpt" prefix in this
      test's temporary directory.
    """
    prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
    return trackable_utils.Checkpoint(root=root).save(prefix)

  def test_children(self):
    root = autotrackable.AutoTrackable()
    root.leaf = autotrackable.AutoTrackable()
    view = checkpoint_view.CheckpointView(self._save_root(root))
    # Only the first child recorded under the root node (id 0) is inspected.
    first_name, first_id = next(iter(view.children(0).items()))
    self.assertEqual("leaf", first_name)
    self.assertEqual(1, first_id)

  def test_all_nodes(self):
    root = autotrackable.AutoTrackable()
    root.leaf = autotrackable.AutoTrackable()
    view = checkpoint_view.CheckpointView(self._save_root(root))
    descendants = view.descendants()
    # root + leaf + one extra node (presumably the Checkpoint's save counter).
    self.assertLen(descendants, 3)
    self.assertEqual(0, descendants[0])
    self.assertEqual(1, descendants[1])

  def test_match(self):
    # Saved graph: root -> {leaf1 -> {leaf3, leaf4}, leaf2 -> {leaf5}}.
    saved_root = autotrackable.AutoTrackable()
    saved_leaf1 = saved_root.leaf1 = autotrackable.AutoTrackable()
    saved_leaf2 = saved_root.leaf2 = autotrackable.AutoTrackable()
    saved_leaf1.leaf3 = autotrackable.AutoTrackable()
    saved_leaf1.leaf4 = autotrackable.AutoTrackable()
    saved_leaf2.leaf5 = autotrackable.AutoTrackable()
    save_path = self._save_root(saved_root)

    # Graph to match is a subset of the saved one: it has no `leaf4`.
    new_root = autotrackable.AutoTrackable()
    new_leaf1 = new_root.leaf1 = autotrackable.AutoTrackable()
    new_leaf2 = new_root.leaf2 = autotrackable.AutoTrackable()
    new_leaf3 = new_leaf1.leaf3 = autotrackable.AutoTrackable()
    new_leaf5 = new_leaf2.leaf5 = autotrackable.AutoTrackable()
    matched = checkpoint_view.CheckpointView(save_path).match(new_root)

    # Keys are checkpoint node ids; unmatched saved nodes are simply absent.
    expected = dict([
        (0, new_root),
        (1, new_leaf1),
        (2, new_leaf2),
        (4, new_leaf3),
        (6, new_leaf5),
    ])
    self.assertDictEqual(matched, expected)

  def test_match_overlapping_nodes(self):
    # In the saved graph, `a` and `b` alias one and the same object.
    saved_root = autotrackable.AutoTrackable()
    saved_root.a = saved_root.b = autotrackable.AutoTrackable()
    save_path = self._save_root(saved_root)

    # In the new graph `a` and `b` are distinct objects, which disagrees with
    # the aliasing recorded in the checkpoint, so a warning is expected.
    new_root = autotrackable.AutoTrackable()
    new_a = new_root.a = autotrackable.AutoTrackable()
    new_root.b = autotrackable.AutoTrackable()
    with self.assertLogs(level="WARNING") as captured:
      matched = checkpoint_view.CheckpointView(save_path).match(new_root)

    # Only the first object found at the shared position is matched.
    self.assertDictEqual(matched, {0: new_root, 1: new_a})
    warning_text = (
        "Inconsistent references when matching the checkpoint into this object"
        " graph.")
    self.assertIn(warning_text, captured.output[0])
99
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner, which discovers the TestCase
  # classes defined in this module.
  test.main()
102