# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for checkpoint/save_util_v1.py."""

from tensorflow.python.checkpoint import graph_view
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.eager import test
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import object_identity


class TrackableWithRegisteredSaver(autotrackable.AutoTrackable):
  """Marker trackable matched by the "RegisteredSaver" checkpoint saver.

  The registration below selects instances of this class via an `isinstance`
  predicate, so these objects are expected to be routed to the registered
  saver rather than serialized as ordinary saveables.
  """
  pass


# Register a no-op checkpoint saver for `TrackableWithRegisteredSaver`. The
# save function writes nothing (returns an empty list) and the restore function
# does nothing; the tests below only inspect how objects are grouped, not what
# is written to disk.
registration.register_checkpoint_saver(
    name="RegisteredSaver",
    predicate=lambda x: isinstance(x, TrackableWithRegisteredSaver),
    save_fn=lambda trackables, file_prefix: [],
    restore_fn=lambda trackables, merged_prefix: None)


class SerializationTest(test.TestCase):
  """Tests for `save_util_v1.serialize_gathered_objects`."""

  def test_serialize_gathered_objects(self):
    """Variables become saveables; registered objects go to their saver."""
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()
    # Only the first and fourth return values are checked here.
    named_saveable_objects, _, _, registered_savers = (
        save_util_v1.serialize_gathered_objects(
            graph_view.ObjectGraphView(root)))

    # The variable is the only ordinary saveable; the registered object must
    # not appear among them.
    self.assertLen(named_saveable_objects, 1)
    self.assertIs(named_saveable_objects[0].op, root.v)
    # Registered objects are grouped under "Custom.<saver name>" and keyed by
    # the attribute name they were tracked under on `root`.
    self.assertDictEqual(
        {"Custom.RegisteredSaver": {"registered": root.registered}},
        registered_savers)

  def test_serialize_gathered_objects_with_map(self):
    """An `object_map` substitutes copies for the original objects."""
    root = autotrackable.AutoTrackable()
    root.v = variables.Variable(1.0)
    root.registered = TrackableWithRegisteredSaver()

    # Map each original object to a replacement; the serialized output is
    # expected to reference the replacements instead of the originals.
    copy_of_registered = TrackableWithRegisteredSaver()
    copy_of_v = variables.Variable(1.0)
    object_map = object_identity.ObjectIdentityDictionary()
    object_map[root.registered] = copy_of_registered
    object_map[root.v] = copy_of_v

    named_saveable_objects, _, _, registered_savers = (
        save_util_v1.serialize_gathered_objects(
            graph_view.ObjectGraphView(root), object_map))

    # The saveable wraps the mapped copy, not the original variable.
    self.assertLen(named_saveable_objects, 1)
    self.assertIsNot(named_saveable_objects[0].op, root.v)
    self.assertIs(named_saveable_objects[0].op, copy_of_v)

    # The registered saver receives the mapped copy under the original
    # attribute name.
    ret_value = registered_savers["Custom.RegisteredSaver"]["registered"]
    self.assertIsNot(root.registered, ret_value)
    self.assertIs(copy_of_registered, ret_value)


if __name__ == "__main__":
  test.main()