# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the registration functions.

For integration tests that use save and load functions, see
registration_saving_test.py.
"""

from absl.testing import parameterized

from tensorflow.python.eager import test
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base


# The classes below are registered at import time. `test_registration` checks
# the name each registration resolves to.


@registration.register_serializable()
class RegisteredClass(base.Trackable):
  """Registered with no arguments; expected to resolve to "Custom.<class>"."""


@registration.register_serializable(name="Subclass")
class RegisteredSubclass(RegisteredClass):
  """Subclass registered under its own explicit name."""


@registration.register_serializable(package="testing")
class CustomPackage(base.Trackable):
  """Registered with an explicit package and the default (class) name."""


@registration.register_serializable(package="testing", name="name")
class CustomPackageAndName(base.Trackable):
  """Registered with both an explicit package and an explicit name."""


class SerializableRegistrationTest(test.TestCase, parameterized.TestCase):
  """Tests for `register_serializable` and the class-name lookups."""

  @parameterized.parameters([
      (RegisteredClass, "Custom.RegisteredClass"),
      (RegisteredSubclass, "Custom.Subclass"),
      (CustomPackage, "testing.CustomPackage"),
      (CustomPackageAndName, "testing.name"),
  ])
  def test_registration(self, expected_cls, expected_name):
    """Instances map to their registered name, and the name maps back."""
    obj = expected_cls()
    self.assertEqual(registration.get_registered_class_name(obj), expected_name)
    self.assertIs(
        registration.get_registered_class(expected_name), expected_cls)

  def test_get_invalid_name(self):
    """Looking up a name that was never registered returns None."""
    self.assertIsNone(registration.get_registered_class("invalid name"))

  def test_get_unregistered_class(self):
    """An instance of an unregistered Trackable has no registered name."""

    class NotRegistered(base.Trackable):
      pass

    # Use an instance, like every other lookup in this file. Passing the class
    # object itself would not exercise the "unregistered instance" path.
    no_register = NotRegistered()
    self.assertIsNone(registration.get_registered_class_name(no_register))

  def test_duplicate_registration(self):
    """A class may have several names, but a name may be used only once."""

    @registration.register_serializable()
    class Duplicate(base.Trackable):
      pass

    dup = Duplicate()
    self.assertEqual(
        registration.get_registered_class_name(dup), "Custom.Duplicate")
    # Registrations with different names are ok.
    registration.register_serializable(package="duplicate")(Duplicate)
    # Registrations are checked in reverse order, so the newest name wins.
    self.assertEqual(
        registration.get_registered_class_name(dup), "duplicate.Duplicate")
    # Both names should resolve to the same class.
    self.assertIs(
        registration.get_registered_class("Custom.Duplicate"), Duplicate)
    self.assertIs(
        registration.get_registered_class("duplicate.Duplicate"), Duplicate)

    # Registering under a name that is already taken ("testing.CustomPackage"
    # belongs to the module-level `CustomPackage`) fails.
    with self.assertRaisesRegex(ValueError, "already been registered"):
      registration.register_serializable(
          package="testing", name="CustomPackage")(
              Duplicate)

  def test_register_non_class_fails(self):
    """Only classes, not instances, can be registered."""
    obj = RegisteredClass()
    with self.assertRaisesRegex(TypeError, "must be a class"):
      registration.register_serializable()(obj)

  def test_register_bad_predicate_fails(self):
    """A non-callable predicate is rejected."""
    with self.assertRaisesRegex(TypeError, "must be callable"):
      registration.register_serializable(predicate=0)(RegisteredClass)

  def test_predicate(self):
    """Custom predicates control which instances match a registration."""

    class Predicate(base.Trackable):

      def __init__(self, register_this):
        self.register_this = register_this

    # Only instances constructed with register_this=True match.
    registration.register_serializable(
        name="RegisterThisOnlyTrue",
        predicate=lambda x: isinstance(x, Predicate) and x.register_this)(
            Predicate)

    a = Predicate(True)
    b = Predicate(False)
    self.assertEqual(
        registration.get_registered_class_name(a),
        "Custom.RegisterThisOnlyTrue")
    self.assertIsNone(registration.get_registered_class_name(b))

    # A later, broader predicate matches both instances and takes precedence
    # for `a` as well.
    registration.register_serializable(
        name="RegisterAllPredicate",
        predicate=lambda x: isinstance(x, Predicate))(
            Predicate)

    self.assertEqual(
        registration.get_registered_class_name(a),
        "Custom.RegisterAllPredicate")
    self.assertEqual(
        registration.get_registered_class_name(b),
        "Custom.RegisterAllPredicate")


class CheckpointSaverRegistrationTest(test.TestCase):
  """Tests for `register_checkpoint_saver` and the saver lookups."""

  def test_invalid_registration(self):
    """Each kind of invalid argument is rejected with a specific error."""
    # Non-string package.
    with self.assertRaisesRegex(TypeError, "must be string"):
      registration.register_checkpoint_saver(
          package=None,
          name="test",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)
    # Non-string name.
    with self.assertRaisesRegex(TypeError, "must be string"):
      registration.register_checkpoint_saver(
          name=None,
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)
    # Invalid character in the name.
    with self.assertRaisesRegex(ValueError,
                                "Invalid registered checkpoint saver."):
      registration.register_checkpoint_saver(
          package="package",
          name="t/est",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)
    # Invalid character in the package. This previously duplicated the name
    # case above verbatim.
    # NOTE(review): relies on the package being validated together with the
    # name; confirm against registration.py if this starts failing.
    with self.assertRaisesRegex(ValueError,
                                "Invalid registered checkpoint saver."):
      registration.register_checkpoint_saver(
          package="pack/age",
          name="test",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)
    # Non-callable predicate.
    with self.assertRaisesRegex(
        TypeError,
        "The predicate registered to a checkpoint saver must be callable"
    ):
      registration.register_checkpoint_saver(
          name="test",
          predicate=None,
          save_fn=lambda: None,
          restore_fn=lambda: None)
    # Non-callable save_fn.
    with self.assertRaisesRegex(TypeError, "The save_fn must be callable"):
      registration.register_checkpoint_saver(
          name="test",
          predicate=lambda: None,
          save_fn=None,
          restore_fn=lambda: None)
    # Non-callable restore_fn.
    with self.assertRaisesRegex(TypeError, "The restore_fn must be callable"):
      registration.register_checkpoint_saver(
          name="test",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=None)

  def test_registration(self):
    """A registered saver is found by predicate and validated on restore."""
    registration.register_checkpoint_saver(
        package="Testing",
        name="test_predicate",
        predicate=lambda x: hasattr(x, "check_attr"),
        save_fn=lambda: "save",
        restore_fn=lambda: "restore")
    x = base.Trackable()
    # The predicate does not match until the attribute is set.
    self.assertIsNone(registration.get_registered_saver_name(x))

    x.check_attr = 1
    saver_name = registration.get_registered_saver_name(x)
    self.assertEqual(saver_name, "Testing.test_predicate")

    self.assertEqual(registration.get_save_function(saver_name)(), "save")
    self.assertEqual(registration.get_restore_function(saver_name)(), "restore")

    registration.validate_restore_function(x, "Testing.test_predicate")
    # An unknown saver name is rejected.
    with self.assertRaisesRegex(ValueError, "saver cannot be found"):
      registration.validate_restore_function(x, "Invalid.name")
    # An object that the saver's predicate rejects cannot use it.
    x2 = base.Trackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      registration.validate_restore_function(x2, "Testing.test_predicate")


if __name__ == "__main__":
  test.main()