xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/registration/registration_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests the registration functions.
16
17For integration tests that use save and load functions, see
18registration_saving_test.py.
19"""
20
21from absl.testing import parameterized
22
23from tensorflow.python.eager import test
24from tensorflow.python.saved_model import registration
25from tensorflow.python.trackable import base
26
27
@registration.register_serializable()
class RegisteredClass(base.Trackable):
  """Trackable registered without an explicit package or name.

  The tests in this file expect it to resolve to "Custom.RegisteredClass".
  """
31
32
@registration.register_serializable(name="Subclass")
class RegisteredSubclass(RegisteredClass):
  """Subclass of `RegisteredClass` registered under the custom name "Subclass".

  The tests in this file expect it to resolve to "Custom.Subclass".
  """
36
37
@registration.register_serializable(package="testing")
class CustomPackage(base.Trackable):
  """Trackable registered under the "testing" package with its class name.

  The tests in this file expect it to resolve to "testing.CustomPackage".
  """
41
42
@registration.register_serializable(package="testing", name="name")
class CustomPackageAndName(base.Trackable):
  """Trackable registered with both a custom package and a custom name.

  The tests in this file expect it to resolve to "testing.name".
  """
46
47
class SerializableRegistrationTest(test.TestCase, parameterized.TestCase):
  """Tests `registration.register_serializable` and its lookup helpers."""

  @parameterized.parameters([
      (RegisteredClass, "Custom.RegisteredClass"),
      (RegisteredSubclass, "Custom.Subclass"),
      (CustomPackage, "testing.CustomPackage"),
      (CustomPackageAndName, "testing.name"),
  ])
  def test_registration(self, expected_cls, expected_name):
    """Checks the class <-> registered-name mapping in both directions.

    Args:
      expected_cls: A class registered at module level in this file.
      expected_name: The "package.name" string it should be registered under.
    """
    obj = expected_cls()
    self.assertEqual(registration.get_registered_class_name(obj), expected_name)
    self.assertIs(
        registration.get_registered_class(expected_name), expected_cls)

  def test_get_invalid_name(self):
    """Looking up a name that was never registered returns None."""
    self.assertIsNone(registration.get_registered_class("invalid name"))

  def test_get_unregistered_class(self):
    """An instance of a class that was never registered has no name."""

    class NotRegistered(base.Trackable):
      pass

    # Look up an *instance*, as every other call to
    # `get_registered_class_name` in this file does. Passing the class object
    # itself would not exercise the unregistered-instance case.
    no_register = NotRegistered()
    self.assertIsNone(registration.get_registered_class_name(no_register))

  def test_duplicate_registration(self):
    """A class may be registered under several names, but names are unique."""

    @registration.register_serializable()
    class Duplicate(base.Trackable):
      pass

    dup = Duplicate()
    self.assertEqual(
        registration.get_registered_class_name(dup), "Custom.Duplicate")
    # Registrations with different names are ok.
    registration.register_serializable(package="duplicate")(Duplicate)
    # Registrations are checked in reverse order, so the newest name wins.
    self.assertEqual(
        registration.get_registered_class_name(dup), "duplicate.Duplicate")
    # Both names should resolve to the same class.
    self.assertIs(
        registration.get_registered_class("Custom.Duplicate"), Duplicate)
    self.assertIs(
        registration.get_registered_class("duplicate.Duplicate"), Duplicate)

    # Registering under a name that is already taken fails.
    # "testing.CustomPackage" is claimed by the module-level `CustomPackage`.
    with self.assertRaisesRegex(ValueError, "already been registered"):
      registration.register_serializable(
          package="testing", name="CustomPackage")(
              Duplicate)

  def test_register_non_class_fails(self):
    """The decorator rejects an instance passed in place of a class."""
    obj = RegisteredClass()
    with self.assertRaisesRegex(TypeError, "must be a class"):
      registration.register_serializable()(obj)

  def test_register_bad_predicate_fails(self):
    """The decorator rejects a predicate that is not callable."""
    with self.assertRaisesRegex(TypeError, "must be callable"):
      registration.register_serializable(predicate=0)(RegisteredClass)

  def test_predicate(self):
    """A custom predicate decides which instances map to a registered name."""

    class Predicate(base.Trackable):

      def __init__(self, register_this):
        self.register_this = register_this

    # First registration only matches instances whose flag is truthy.
    registration.register_serializable(
        name="RegisterThisOnlyTrue",
        predicate=lambda x: isinstance(x, Predicate) and x.register_this)(
            Predicate)

    a = Predicate(True)
    b = Predicate(False)
    self.assertEqual(
        registration.get_registered_class_name(a),
        "Custom.RegisterThisOnlyTrue")
    self.assertIsNone(registration.get_registered_class_name(b))

    # Second registration matches every `Predicate` instance. Both objects are
    # expected to resolve to this newer name from now on.
    registration.register_serializable(
        name="RegisterAllPredicate",
        predicate=lambda x: isinstance(x, Predicate))(
            Predicate)

    self.assertEqual(
        registration.get_registered_class_name(a),
        "Custom.RegisterAllPredicate")
    self.assertEqual(
        registration.get_registered_class_name(b),
        "Custom.RegisterAllPredicate")
138
139
class CheckpointSaverRegistrationTest(test.TestCase):
  """Tests `registration.register_checkpoint_saver` and its lookup helpers."""

  def test_invalid_registration(self):
    """Each malformed argument is rejected with the matching error."""

    def noop():
      # Stand-in callable for arguments that are valid in a given case; it is
      # never expected to be invoked because registration should fail first.
      return None

    # `package` must be a string.
    with self.assertRaisesRegex(TypeError, "must be string"):
      registration.register_checkpoint_saver(
          package=None,
          name="test",
          predicate=noop,
          save_fn=noop,
          restore_fn=noop)
    # `name` must be a string.
    with self.assertRaisesRegex(TypeError, "must be string"):
      registration.register_checkpoint_saver(
          name=None,
          predicate=noop,
          save_fn=noop,
          restore_fn=noop)
    # `name` may not contain characters such as "/".
    with self.assertRaisesRegex(ValueError,
                                "Invalid registered checkpoint saver."):
      registration.register_checkpoint_saver(
          package="package",
          name="t/est",
          predicate=noop,
          save_fn=noop,
          restore_fn=noop)
    # `package` is validated the same way. This case previously duplicated the
    # invalid-name case above and so added no coverage.
    with self.assertRaisesRegex(ValueError,
                                "Invalid registered checkpoint saver."):
      registration.register_checkpoint_saver(
          package="pack/age",
          name="test",
          predicate=noop,
          save_fn=noop,
          restore_fn=noop)
    # `predicate` must be callable.
    with self.assertRaisesRegex(
        TypeError,
        "The predicate registered to a checkpoint saver must be callable"
    ):
      registration.register_checkpoint_saver(
          name="test",
          predicate=None,
          save_fn=noop,
          restore_fn=noop)
    # `save_fn` must be callable.
    with self.assertRaisesRegex(TypeError, "The save_fn must be callable"):
      registration.register_checkpoint_saver(
          name="test",
          predicate=noop,
          save_fn=None,
          restore_fn=noop)
    # `restore_fn` must be callable.
    with self.assertRaisesRegex(TypeError, "The restore_fn must be callable"):
      registration.register_checkpoint_saver(
          name="test",
          predicate=noop,
          save_fn=noop,
          restore_fn=None)

  def test_registration(self):
    """A registered saver is found by predicate and exposes its functions."""
    registration.register_checkpoint_saver(
        package="Testing",
        name="test_predicate",
        predicate=lambda x: hasattr(x, "check_attr"),
        save_fn=lambda: "save",
        restore_fn=lambda: "restore")
    x = base.Trackable()
    # Without `check_attr` the predicate does not match, so no saver is found.
    self.assertIsNone(registration.get_registered_saver_name(x))

    # Adding the attribute makes the predicate match.
    x.check_attr = 1
    saver_name = registration.get_registered_saver_name(x)
    self.assertEqual(saver_name, "Testing.test_predicate")

    # The lookup helpers hand back the exact callables that were registered.
    self.assertEqual(registration.get_save_function(saver_name)(), "save")
    self.assertEqual(registration.get_restore_function(saver_name)(), "restore")

    # Validation passes for a matching object, and fails for an unknown saver
    # name or for an object the predicate rejects.
    registration.validate_restore_function(x, "Testing.test_predicate")
    with self.assertRaisesRegex(ValueError, "saver cannot be found"):
      registration.validate_restore_function(x, "Invalid.name")
    x2 = base.Trackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      registration.validate_restore_function(x2, "Testing.test_predicate")
217
218
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
  test.main()
221