# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for the checkpoint view."""

import os

from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.checkpoint import checkpoint_view
from tensorflow.python.eager import test
from tensorflow.python.trackable import autotrackable


class CheckpointViewTest(test.TestCase):
  """Exercises `CheckpointView` against small saved object graphs."""

  def _save_checkpoint(self, root):
    """Saves `root` in a `Checkpoint` under the test temp dir.

    Args:
      root: The trackable object passed as `root=` to `Checkpoint`.

    Returns:
      Whatever `Checkpoint.save` returns for the "root_ckpt" prefix; the tests
      hand this value straight to `CheckpointView`.
    """
    checkpoint = trackable_utils.Checkpoint(root=root)
    prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
    return checkpoint.save(prefix)

  def test_children(self):
    """The first child listed for node 0 is named "leaf" and has id 1."""
    root = autotrackable.AutoTrackable()
    root.leaf = autotrackable.AutoTrackable()
    save_path = self._save_checkpoint(root)

    view = checkpoint_view.CheckpointView(save_path)
    children_of_root = view.children(0)
    # Only the first (name, node_id) pair is inspected.
    child_name, child_id = next(iter(children_of_root.items()))

    self.assertEqual("leaf", child_name)
    self.assertEqual(1, child_id)

  def test_all_nodes(self):
    """`descendants()` yields three entries, starting with ids 0 and 1."""
    root = autotrackable.AutoTrackable()
    root.leaf = autotrackable.AutoTrackable()
    save_path = self._save_checkpoint(root)

    node_ids = checkpoint_view.CheckpointView(save_path).descendants()

    self.assertEqual(3, len(node_ids))
    # The first two entries are expected to equal their own positions.
    for position in (0, 1):
      self.assertEqual(position, node_ids[position])

  def test_match(self):
    """`match()` maps saved node ids onto the objects of a similar graph."""
    # Graph that gets written to disk. Attributes are attached in the same
    # order as before, since the expected node ids below may depend on it.
    saved_root = autotrackable.AutoTrackable()
    saved_leaf1 = autotrackable.AutoTrackable()
    saved_root.leaf1 = saved_leaf1
    saved_leaf2 = autotrackable.AutoTrackable()
    saved_root.leaf2 = saved_leaf2
    saved_leaf1.leaf3 = autotrackable.AutoTrackable()
    saved_leaf1.leaf4 = autotrackable.AutoTrackable()
    saved_leaf2.leaf5 = autotrackable.AutoTrackable()
    save_path = self._save_checkpoint(saved_root)

    # Graph to match against: same attribute names, but without `leaf4`.
    target_root = autotrackable.AutoTrackable()
    target_leaf1 = autotrackable.AutoTrackable()
    target_root.leaf1 = target_leaf1
    target_leaf2 = autotrackable.AutoTrackable()
    target_root.leaf2 = target_leaf2
    target_leaf3 = autotrackable.AutoTrackable()
    target_leaf1.leaf3 = target_leaf3
    target_leaf5 = autotrackable.AutoTrackable()
    target_leaf2.leaf5 = target_leaf5

    matched = checkpoint_view.CheckpointView(save_path).match(target_root)

    # Only objects that exist in the target graph are expected in the result.
    expected = {
        0: target_root,
        1: target_leaf1,
        2: target_leaf2,
        4: target_leaf3,
        6: target_leaf5,
    }
    self.assertDictEqual(matched, expected)

  def test_match_overlapping_nodes(self):
    """A node saved under two names matches once and logs a warning."""
    # `a` and `b` refer to one shared object in the saved graph.
    saved_root = autotrackable.AutoTrackable()
    shared_child = autotrackable.AutoTrackable()
    saved_root.a = shared_child
    saved_root.b = shared_child
    save_path = self._save_checkpoint(saved_root)

    # In the target graph `a` and `b` are two distinct objects.
    target_root = autotrackable.AutoTrackable()
    target_a = autotrackable.AutoTrackable()
    target_root.a = target_a
    target_root.b = autotrackable.AutoTrackable()

    with self.assertLogs(level="WARNING") as captured:
      view = checkpoint_view.CheckpointView(save_path)
      matched = view.match(target_root)

    expected = {
        0: target_root,
        1: target_a,
        # Only the first element at the same position will be matched.
    }
    self.assertDictEqual(matched, expected)

    expected_message = (
        "Inconsistent references when matching the checkpoint into this object"
        " graph.")
    self.assertIn(expected_message, captured.output[0])


if __name__ == "__main__":
  test.main()