xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/registration/registration.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Serialization Registration for SavedModel.
16
17revived_types registration will be migrated to this infrastructure.
18
19See the Advanced saving section in go/savedmodel-configurability.
20This API is approved for TF internal use only.
21"""
22import collections
23import re
24
25from tensorflow.python.util import tf_inspect
26
27
# Only allow valid file/directory characters: ASCII letters, digits, '.', '_'
# and '-'. (For checkpoint savers, the registered package and name are used to
# generate a unique file prefix.)
# NOTE(review): `$` also matches just before a trailing newline, so this
# pattern should be applied with `fullmatch` (not `match`) to reject names
# such as "Custom.name\n".
_VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$")
30
31
class _PredicateRegistry(object):
  """Registry with predicate-based lookup.

  See the documentation for `register_checkpoint_saver` and
  `register_serializable` for reasons why predicates are required over a
  class-based registry.

  Since this class is used for global registries, each object must be registered
  to unique names (an error is raised if there are naming conflicts). The lookup
  searches the predicates in reverse order, so that later-registered predicates
  are executed first.
  """
  __slots__ = ("_registry_name", "_registered_map", "_registered_predicates",
               "_registered_names")

  def __init__(self, name):
    """Creates an empty registry.

    Args:
      name: Human-readable name of the registry, used in error messages.
    """
    self._registry_name = name
    # Maps registered name -> object
    self._registered_map = {}
    # Maps registered name -> predicate
    self._registered_predicates = {}
    # Stores names in the order of registration. `get_registered_name` walks
    # this list backwards so that later registrations take precedence.
    self._registered_names = []

  @property
  def name(self):
    """The human-readable name of this registry."""
    return self._registry_name

  def _not_registered_error(self, registered_name):
    """Builds the LookupError reported for an unknown registered name."""
    return LookupError(f"The {self.name} registry does not have name "
                       f"'{registered_name}' registered.")

  def register(self, package, name, predicate, candidate):
    """Registers a candidate object under the package, name and predicate.

    Args:
      package: String package. Combined with `name` as "<package>.<name>" to
        form the registered name.
      name: String name within `package`.
      predicate: Callable that takes one object and returns whether
        `candidate` applies to it.
      candidate: The object to register.

    Raises:
      TypeError: If `package` or `name` is not a string, or `predicate` is not
        callable.
      ValueError: If the registered name contains characters not allowed by
        `_VALID_REGISTERED_NAME`, or is already registered.
    """
    if not isinstance(package, str) or not isinstance(name, str):
      raise TypeError(
          f"The package and name registered to a {self.name} must be strings, "
          f"got: package={type(package)}, name={type(name)}")
    if not callable(predicate):
      raise TypeError(
          f"The predicate registered to a {self.name} must be callable, "
          f"got: {type(predicate)}")
    registered_name = package + "." + name
    # `fullmatch` rather than `match`: with `match`, the pattern's `$` also
    # matches just before a trailing newline, so "Custom.name\n" would be
    # accepted even though "\n" is not a valid file/directory character.
    if not _VALID_REGISTERED_NAME.fullmatch(registered_name):
      raise ValueError(
          f"Invalid registered {self.name}. Please check that the package and "
          f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': "
          f"(package='{package}', name='{name}')")
    if registered_name in self._registered_map:
      raise ValueError(
          f"The name '{registered_name}' has already been registered to a "
          f"{self.name}. Found: {self._registered_map[registered_name]}")

    self._registered_map[registered_name] = candidate
    self._registered_predicates[registered_name] = predicate
    self._registered_names.append(registered_name)

  def lookup(self, obj):
    """Looks up the registered object using the predicate.

    Args:
      obj: Object to pass to each of the registered predicates to look up the
        registered object.
    Returns:
      The object registered with the first passing predicate.
    Raises:
      LookupError if the object does not match any of the predicate functions.
    """
    return self._registered_map[self.get_registered_name(obj)]

  def name_lookup(self, registered_name):
    """Looks up the registered object using the registered name.

    Raises:
      LookupError: If nothing is registered under `registered_name`.
    """
    try:
      return self._registered_map[registered_name]
    except KeyError:
      # The KeyError carries no information beyond this message.
      raise self._not_registered_error(registered_name) from None

  def get_registered_name(self, obj):
    """Returns the registered name whose predicate accepts `obj`.

    Predicates are tried from the most recently registered to the oldest.

    Raises:
      LookupError: If no registered predicate accepts `obj`.
    """
    for registered_name in reversed(self._registered_names):
      predicate = self._registered_predicates[registered_name]
      if predicate(obj):
        return registered_name
    raise LookupError(f"Could not find matching {self.name} for {type(obj)}.")

  def get_predicate(self, registered_name):
    """Returns the predicate registered under `registered_name`.

    Raises:
      LookupError: If nothing is registered under `registered_name`.
    """
    try:
      return self._registered_predicates[registered_name]
    except KeyError:
      raise self._not_registered_error(registered_name) from None

  def get_registrations(self):
    """Returns the mapping of registered name -> predicate.

    This is the registry's own dict, not a copy; treat it as read-only.
    """
    return self._registered_predicates
122
# Module-level registries. `_class_registry` maps registered names to
# serializable classes (filled by `register_serializable`); `_saver_registry`
# maps registered names to (save_fn, restore_fn, strict_predicate_restore)
# tuples (filled by `register_checkpoint_saver`).
_class_registry = _PredicateRegistry("serializable class")
_saver_registry = _PredicateRegistry("checkpoint saver")
125
126
def get_registered_class_name(obj):
  """Returns the registered serializable name matching `obj`, or None.

  Args:
    obj: Object checked against the registered serializable-class predicates.

  Returns:
    The "<package>.<name>" string of the most recently registered class whose
    predicate accepts `obj`, or None when no predicate accepts it.
  """
  try:
    registered_name = _class_registry.get_registered_name(obj)
  except LookupError:
    registered_name = None
  return registered_name
132
133
def get_registered_class(registered_name):
  """Returns the class registered under `registered_name`, or None if absent."""
  try:
    registered_class = _class_registry.name_lookup(registered_name)
  except LookupError:
    return None
  return registered_class
139
140
def register_serializable(package="Custom", name=None, predicate=None):
  """Decorator for registering a serializable class.

  THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.

  Registered classes will be saved with a name generated by combining the
  `package` and `name` arguments. When loading a SavedModel, modules saved with
  this registered name will be created using the `_deserialize_from_proto`
  method.

  By default, the registration applies to every instance of the registered
  class, including instances of its subclasses (the default predicate is
  `isinstance(obj, cls)`). Use the `predicate` argument to customize which
  objects are matched, e.g. to restrict the registration to exact instances:

  ```python
  class A(tf.Module):
    pass

  register_serializable(
      package="Example", predicate=lambda obj: type(obj) is A)(A)
  ```

  When `name` is None, the returned decorator may be applied to several
  classes; each class then gets its own default predicate.

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.
    predicate: An optional function that takes a single Trackable argument, and
      determines whether that object should be serialized with this `package`
      and `name`. The default predicate checks whether the object is an
      instance of the registered class (subclasses included). Predicates are
      executed in the reverse order that they are added (later registrations
      are checked first).

  Returns:
    A decorator that registers the decorated class with the passed names and
    predicate.
  """
  def decorator(arg):
    """Registers class `arg` with the serialization framework.

    Raises:
      TypeError: If `arg` is not a class, or `package`/`name` are not strings.
      ValueError: If the registered name is invalid or already registered.
    """
    if not tf_inspect.isclass(arg):
      raise TypeError("Registered serializable must be a class: {}".format(arg))

    class_name = name if name is not None else arg.__name__
    # Pick the predicate in a per-call local instead of rebinding the enclosing
    # `predicate` (via `nonlocal`). Rebinding would make the default predicate
    # built for the first decorated class leak into every later class that the
    # same decorator is applied to.
    class_predicate = (
        predicate if predicate is not None
        else (lambda x: isinstance(x, arg)))
    _class_registry.register(package, class_name, class_predicate, arg)
    return arg

  return decorator
190
191
# Named record with fields (name, predicate, save_fn, restore_fn).
# NOTE(review): `RegisteredSaver`, `_REGISTERED_SAVERS` and
# `_REGISTERED_SAVER_NAMES` are not read or written anywhere else in this file
# as shown; savers are stored in `_saver_registry` instead. They look like
# legacy state kept for external importers -- confirm before removing.
RegisteredSaver = collections.namedtuple(
    "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"])
_REGISTERED_SAVERS = {}
_REGISTERED_SAVER_NAMES = []  # Stores names in the order of registration
196
197
def register_checkpoint_saver(package="Custom",
                              name=None,
                              predicate=None,
                              save_fn=None,
                              restore_fn=None,
                              strict_predicate_restore=True):
  """Registers functions which checkpoint & restore objects with custom steps.

  If you have a class that requires complicated coordination between multiple
  objects when checkpointing, then you will need to register a custom saver
  and restore function. An example of this is a custom Variable class that
  splits the variable across different objects and devices, and needs to write
  checkpoints that are compatible with different configurations of devices.

  The registered save and restore functions are used in checkpoints and
  SavedModel.

  Please make sure you are familiar with the concepts in the [Checkpointing
  guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
  V2 checkpoint format:

  * io_ops.SaveV2
  * io_ops.MergeV2Checkpoints
  * io_ops.RestoreV2

  **Predicate**

  The predicate is a filter that will run on every `Trackable` object connected
  to the root object. This function determines whether a `Trackable` should use
  the registered functions.

  Example: `lambda x: isinstance(x, CustomClass)`

  **Custom save function**

  This is how checkpoint saving works normally:
  1. Gather all of the Trackables with saveable values.
  2. For each Trackable, gather all of the saveable tensors.
  3. Save checkpoint shards (grouping tensors by device) with SaveV2
  4. Merge the shards with MergeV2Checkpoints. This combines all of the shards'
     metadata, and renames them to follow the standard shard pattern.

  When a saver is registered, Trackables that pass the registered `predicate`
  are automatically marked as having saveable values. Next, the custom save
  function replaces steps 2 and 3 of the saving process. Finally, the shards
  returned by the custom save function are merged with the other shards.

  The save function takes in a dictionary of `Trackables` and a `file_prefix`
  string. The function should save checkpoint shards using the SaveV2 op, and
  return a list of the shard prefixes. SaveV2 is currently required to work
  correctly, because the code merges all of the returned shards, and the
  `restore_fn` will only be given the prefix of the merged checkpoint. If you
  need to be able to save and restore from unmerged shards, please file a
  feature request.

  Specification and example of the save function:

  ```
  def save_fn(trackables, file_prefix):
    # trackables: A dictionary mapping unique string identifiers to trackables
    # file_prefix: A unique file prefix generated using the registered name.
    ...
    # Gather the tensors to save.
    ...
    io_ops.SaveV2(file_prefix, tensor_names, shapes_and_slices, tensors)
    return file_prefix  # Returns a tensor or a list of string tensors
  ```

  The save function is executed before the unregistered save ops.

  **Custom restore function**

  Normal checkpoint restore behavior:
  1. Gather all of the Trackables that have saveable values.
  2. For each Trackable, get the names of the desired tensors to extract from
     the checkpoint.
  3. Use RestoreV2 to read the saved values, and pass the restored tensors to
     the corresponding Trackables.

  The custom restore function replaces steps 2 and 3.

  The restore function also takes a dictionary of `Trackables` and a
  `merged_prefix` string. The `merged_prefix` is different from the
  `file_prefix`, since it contains the renamed shard paths. To read from the
  merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.

  Specification:

  ```
  def restore_fn(trackables, merged_prefix):
    # trackables: A dictionary mapping unique string identifiers to Trackables
    # merged_prefix: File prefix of the merged shard names.

    restored_tensors = io_ops.restore_v2(
        merged_prefix, tensor_names, shapes_and_slices, dtypes)
    ...
    # Restore the checkpoint values for the given Trackables.
  ```

  The restore function is executed after the non-registered restore ops.

  Args:
    package: Optional, the package that this class belongs to.
    name: (Required) The name of this saver, which is saved to the checkpoint.
      When a checkpoint is restored, the name and package are used to find the
      matching restore function. The name and package are also used to
      generate a unique file prefix that is passed to the save_fn.
    predicate: (Required) A function that returns a boolean indicating whether a
      `Trackable` object should be checkpointed with this function. Predicates
      are executed in the reverse order that they are added (later registrations
      are checked first).
    save_fn: (Required) A function that takes a dictionary of trackables and a
      file prefix as the arguments, writes the checkpoint shards for the given
      Trackables, and returns the list of shard prefixes.
    restore_fn: (Required) A function that takes a dictionary of trackables and
      a file prefix as the arguments and restores the trackable values.
    strict_predicate_restore: If this is `True` (default), then an error will be
      raised if the predicate fails during checkpoint restoration. If this is
      `False`, checkpoint restoration will skip running the restore function.
      This value is generally set to `False` when the predicate does not pass on
      the Trackables after being saved/loaded from SavedModel.

  Raises:
    TypeError: If `save_fn` or `restore_fn` is not callable, if `package` or
      `name` is not a string (`name` defaults to None and must be provided), or
      if `predicate` is not callable.
    ValueError: If the package and name are already registered, or contain
      characters other than ASCII letters, digits, '.', '_' and '-'.
  """
  if not callable(save_fn):
    raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
  if not callable(restore_fn):
    raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")

  # The getters in this module (`get_save_function`, `get_restore_function`,
  # `get_strict_predicate_restore`) read this tuple by position, so the element
  # order is part of the registry's contract.
  saver_entry = (save_fn, restore_fn, strict_predicate_restore)
  _saver_registry.register(package, name, predicate, saver_entry)
329
330
def get_registered_saver_name(trackable):
  """Returns the name of the registered saver to use with `trackable`.

  Args:
    trackable: Object checked against the registered saver predicates.

  Returns:
    The "<package>.<name>" string of the most recently registered saver whose
    predicate accepts `trackable`, or None if no saver applies.
  """
  try:
    saver_name = _saver_registry.get_registered_name(trackable)
  except LookupError:
    saver_name = None
  return saver_name
337
338
def get_save_function(registered_name):
  """Returns the save function of the saver registered under `registered_name`.

  Raises:
    LookupError: If no saver is registered under that name.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  # Entries are (save_fn, restore_fn, strict_predicate_restore) tuples.
  return saver_entry[0]
342
343
def get_restore_function(registered_name):
  """Returns the restore function of the saver registered under the name.

  Raises:
    LookupError: If no saver is registered under `registered_name`.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  # Entries are (save_fn, restore_fn, strict_predicate_restore) tuples.
  return saver_entry[1]
347
348
def get_strict_predicate_restore(registered_name):
  """Returns the `strict_predicate_restore` flag of a registered saver.

  This is the value passed to `register_checkpoint_saver`. When it is `True`,
  a failing predicate during checkpoint restoration is an error. When it is
  `False`, restoration may skip the registered restore function instead. In
  other words, the restore can be ignored only when this returns `False`.

  Args:
    registered_name: String name of the registered saver.

  Returns:
    The boolean `strict_predicate_restore` value stored for the saver.

  Raises:
    LookupError: If no saver is registered under `registered_name`.
  """
  # Entries are (save_fn, restore_fn, strict_predicate_restore) tuples.
  return _saver_registry.name_lookup(registered_name)[2]
352
353
def validate_restore_function(trackable, registered_name):
  """Validates whether the trackable can be restored with the saver.

  When using a checkpoint saved with a registered saver, that same saver must
  also be registered when loading. The name of that saver is saved to the
  checkpoint and set in the `registered_name` arg.

  Args:
    trackable: A `Trackable` object.
    registered_name: String name of the expected registered saver. This argument
      should be set using the name saved in a checkpoint.

  Raises:
    ValueError if the saver could not be found, or if the predicate associated
      with the saver does not pass.
  """
  # `get_predicate` raises LookupError for unknown names, so one lookup both
  # checks that the saver is registered and fetches its predicate.
  try:
    predicate = _saver_registry.get_predicate(registered_name)
  except LookupError:
    raise ValueError(
        f"Error when restoring object {trackable} from checkpoint. This "
        "object was saved using a registered saver named "
        f"'{registered_name}', but this saver cannot be found in the "
        "current context.") from None
  if not predicate(trackable):
    raise ValueError(
        f"Object {trackable} was saved with the registered saver named "
        f"'{registered_name}'. However, this saver cannot be used to restore the "
        "object because the predicate does not pass.")
383