# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Serialization Registration for SavedModel.

revived_types registration will be migrated to this infrastructure.

See the Advanced saving section in go/savedmodel-configurability.
This API is approved for TF internal use only.
"""
import collections
import re

from tensorflow.python.util import tf_inspect


# Only allow valid file/directory characters. Registered names are also used to
# generate file prefixes for registered checkpoint savers.
_VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$")


class _PredicateRegistry(object):
  """Registry with predicate-based lookup.

  See the documentation for `register_checkpoint_saver` and
  `register_serializable` for reasons why predicates are required over a
  class-based registry.

  Since this class is used for global registries, each object must be registered
  to unique names (an error is raised if there are naming conflicts). The lookup
  searches the predicates in reverse order, so that later-registered predicates
  are executed first.
  """
  __slots__ = ("_registry_name", "_registered_map", "_registered_predicates",
               "_registered_names")

  def __init__(self, name):
    self._registry_name = name
    # Maps registered name -> object
    self._registered_map = {}
    # Maps registered name -> predicate
    self._registered_predicates = {}
    # Stores names in the order of registration
    self._registered_names = []

  @property
  def name(self):
    """Human-readable name of this registry, used in error messages."""
    return self._registry_name

  def register(self, package, name, predicate, candidate):
    """Registers a candidate object under the package, name and predicate.

    The registered name is `package + "." + name`.

    Args:
      package: String package name.
      name: String name within the package.
      predicate: Callable taking a single object and returning whether
        `candidate` applies to it.
      candidate: The object to store under the registered name.

    Raises:
      TypeError: if `package` or `name` is not a string, or `predicate` is not
        callable.
      ValueError: if the registered name does not match
        `_VALID_REGISTERED_NAME`, or is already registered.
    """
    if not isinstance(package, str) or not isinstance(name, str):
      raise TypeError(
          f"The package and name registered to a {self.name} must be strings, "
          f"got: package={type(package)}, name={type(name)}")
    if not callable(predicate):
      raise TypeError(
          f"The predicate registered to a {self.name} must be callable, "
          f"got: {type(predicate)}")
    registered_name = package + "." + name
    if not _VALID_REGISTERED_NAME.match(registered_name):
      raise ValueError(
          f"Invalid registered {self.name}. Please check that the package and "
          f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': "
          f"(package='{package}', name='{name}')")
    if registered_name in self._registered_map:
      raise ValueError(
          f"The name '{registered_name}' has already been registered to a "
          f"{self.name}. Found: {self._registered_map[registered_name]}")

    self._registered_map[registered_name] = candidate
    self._registered_predicates[registered_name] = predicate
    self._registered_names.append(registered_name)

  def lookup(self, obj):
    """Looks up the registered object using the predicate.

    Args:
      obj: Object to pass to each of the registered predicates to look up the
        registered object.
    Returns:
      The object registered with the first passing predicate, where predicates
      are tried from the most recently registered to the oldest.
    Raises:
      LookupError if the object does not match any of the predicate functions.
    """
    return self._registered_map[self.get_registered_name(obj)]

  def name_lookup(self, registered_name):
    """Looks up the registered object using the registered name.

    Raises:
      LookupError: if `registered_name` has not been registered.
    """
    try:
      return self._registered_map[registered_name]
    except KeyError:
      raise LookupError(f"The {self.name} registry does not have name "
                        f"'{registered_name}' registered.") from None

  def get_registered_name(self, obj):
    """Returns the name of the most recently registered predicate passing `obj`.

    Raises:
      LookupError: if no registered predicate passes on `obj`.
    """
    for registered_name in reversed(self._registered_names):
      predicate = self._registered_predicates[registered_name]
      if predicate(obj):
        return registered_name
    raise LookupError(f"Could not find matching {self.name} for {type(obj)}.")

  def get_predicate(self, registered_name):
    """Returns the predicate registered under `registered_name`.

    Raises:
      LookupError: if `registered_name` has not been registered.
    """
    try:
      return self._registered_predicates[registered_name]
    except KeyError:
      raise LookupError(f"The {self.name} registry does not have name "
                        f"'{registered_name}' registered.") from None

  def get_registrations(self):
    """Returns the mapping of registered name -> predicate.

    Note: this is the registry's internal dict (not a copy); callers should not
    mutate it.
    """
    return self._registered_predicates


_class_registry = _PredicateRegistry("serializable class")
_saver_registry = _PredicateRegistry("checkpoint saver")


def get_registered_class_name(obj):
  """Returns the registered serializable-class name matching `obj`, or None."""
  try:
    return _class_registry.get_registered_name(obj)
  except LookupError:
    return None


def get_registered_class(registered_name):
  """Returns the class registered under `registered_name`, or None."""
  try:
    return _class_registry.name_lookup(registered_name)
  except LookupError:
    return None


def register_serializable(package="Custom", name=None, predicate=None):  # pylint: disable=unused-argument
  """Decorator for registering a serializable class.

  THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.

  Registered classes will be saved with a name generated by combining the
  `package` and `name` arguments. When loading a SavedModel, modules saved with
  this registered name will be created using the `_deserialize_from_proto`
  method.

  By default, the registration matches every object for which
  `isinstance(obj, registered_class)` is true, i.e. instances of the registered
  class and of its subclasses. To change which objects match, pass a custom
  `predicate`. For example, to match only direct instances of `A`:

  ```python
  class A(tf.Module):
    pass

  register_serializable(
      package="Example", predicate=lambda obj: type(obj) is A)(A)
  ```

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.
    predicate: An optional function that takes a single Trackable argument, and
      determines whether that object should be serialized with this `package`
      and `name`. The default predicate is `isinstance(obj, registered_class)`.
      Predicates are executed in the reverse order that they are added (later
      registrations are checked first).

  Returns:
    A decorator that registers the decorated class with the passed names and
    predicate, and returns the class unchanged. The decorator raises
    `TypeError` if the decorated object is not a class (or if `package`/`name`
    are not strings), and `ValueError` if the combined name is invalid or
    already registered.
  """
  def decorator(arg):
    """Registers a class with the serialization framework."""
    if not tf_inspect.isclass(arg):
      raise TypeError("Registered serializable must be a class: {}".format(arg))

    class_name = name if name is not None else arg.__name__
    # Resolve the default predicate per decorated class. Rebinding the enclosing
    # `predicate` (e.g. via `nonlocal`) would make a reused decorator register
    # every later class with the first class's predicate.
    if predicate is None:
      class_predicate = lambda x: isinstance(x, arg)
    else:
      class_predicate = predicate
    _class_registry.register(package, class_name, class_predicate, arg)
    return arg

  return decorator


# NOTE(review): `RegisteredSaver`, `_REGISTERED_SAVERS` and
# `_REGISTERED_SAVER_NAMES` are not referenced in this module (savers are stored
# in `_saver_registry`). They are kept in case other modules import them.
RegisteredSaver = collections.namedtuple(
    "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"])
_REGISTERED_SAVERS = {}
_REGISTERED_SAVER_NAMES = []  # Stores names in the order of registration


def register_checkpoint_saver(package="Custom",
                              name=None,
                              predicate=None,
                              save_fn=None,
                              restore_fn=None,
                              strict_predicate_restore=True):
  """Registers functions which checkpoints & restores objects with custom steps.

  If you have a class that requires complicated coordination between multiple
  objects when checkpointing, then you will need to register a custom saver
  and restore function. An example of this is a custom Variable class that
  splits the variable across different objects and devices, and needs to write
  checkpoints that are compatible with different configurations of devices.

  The registered save and restore functions are used in checkpoints and
  SavedModel.

  Please make sure you are familiar with the concepts in the [Checkpointing
  guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
  V2 checkpoint format:

  * io_ops.SaveV2
  * io_ops.MergeV2Checkpoints
  * io_ops.RestoreV2

  **Predicate**

  The predicate is a filter that will run on every `Trackable` object connected
  to the root object. This function determines whether a `Trackable` should use
  the registered functions.

  Example: `lambda x: isinstance(x, CustomClass)`

  **Custom save function**

  This is how checkpoint saving works normally:
  1. Gather all of the Trackables with saveable values.
  2. For each Trackable, gather all of the saveable tensors.
  3. Save checkpoint shards (grouping tensors by device) with SaveV2.
  4. Merge the shards with MergeV2Checkpoints. This combines all of the shards'
     metadata, and renames them to follow the standard shard pattern.

  When a saver is registered, Trackables that pass the registered `predicate`
  are automatically marked as having saveable values. Next, the custom save
  function replaces steps 2 and 3 of the saving process. Finally, the shards
  returned by the custom save function are merged with the other shards.

  The save function takes in a dictionary of `Trackables` and a `file_prefix`
  string. The function should save checkpoint shards using the SaveV2 op, and
  return a list of the shard prefixes. Using SaveV2 is currently required for
  this to work correctly, because the code merges all of the returned shards,
  and the `restore_fn` will only be given the prefix of the merged checkpoint.
  If you need to be able to save and restore from unmerged shards, please file
  a feature request.

  Specification and example of the save function:

  ```
  def save_fn(trackables, file_prefix):
    # trackables: A dictionary mapping unique string identifiers to trackables
    # file_prefix: A unique file prefix generated using the registered name.
    ...
    # Gather the tensors to save.
    ...
    io_ops.SaveV2(file_prefix, tensor_names, shapes_and_slices, tensors)
    return file_prefix  # Returns a tensor or a list of string tensors
  ```

  The save function is executed before the unregistered save ops.

  **Custom restore function**

  Normal checkpoint restore behavior:
  1. Gather all of the Trackables that have saveable values.
  2. For each Trackable, get the names of the desired tensors to extract from
     the checkpoint.
  3. Use RestoreV2 to read the saved values, and pass the restored tensors to
     the corresponding Trackables.

  The custom restore function replaces steps 2 and 3.

  The restore function also takes a dictionary of `Trackables` and a
  `merged_prefix` string. The `merged_prefix` is different from the
  `file_prefix`, since it contains the renamed shard paths. To read from the
  merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.

  Specification:

  ```
  def restore_fn(trackables, merged_prefix):
    # trackables: A dictionary mapping unique string identifiers to Trackables
    # merged_prefix: File prefix of the merged shard names.

    restored_tensors = io_ops.restore_v2(
        merged_prefix, tensor_names, shapes_and_slices, dtypes)
    ...
    # Restore the checkpoint values for the given Trackables.
  ```

  The restore function is executed after the non-registered restore ops.

  Args:
    package: Optional, the package that this class belongs to.
    name: (Required) The name of this saver, which is saved to the checkpoint.
      When a checkpoint is restored, the name and package are used to find the
      matching restore function. The name and package are also used to
      generate a unique file prefix that is passed to the save_fn.
    predicate: (Required) A function that returns a boolean indicating whether a
      `Trackable` object should be checkpointed with this function. Predicates
      are executed in the reverse order that they are added (later registrations
      are checked first).
    save_fn: (Required) A function that takes a dictionary of trackables and a
      file prefix as the arguments, writes the checkpoint shards for the given
      Trackables, and returns the list of shard prefixes.
    restore_fn: (Required) A function that takes a dictionary of trackables and
      a file prefix as the arguments and restores the trackable values.
    strict_predicate_restore: If this is `True` (default), then an error will be
      raised if the predicate fails during checkpoint restoration. If this is
      `False`, checkpoint restoration will skip running the restore function.
      This value is generally set to `False` when the predicate does not pass on
      the Trackables after being saved/loaded from SavedModel.

  Raises:
    TypeError: if `save_fn` or `restore_fn` is not callable, if `package` or
      `name` is not a string, or if `predicate` is not callable.
    ValueError: if the package and name are already registered, or do not form
      a valid registered name.
  """
  if not callable(save_fn):
    raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
  if not callable(restore_fn):
    raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")

  # The registry entry is a (save_fn, restore_fn, strict_predicate_restore)
  # tuple; `get_save_function`, `get_restore_function` and
  # `get_strict_predicate_restore` index into it.
  _saver_registry.register(package, name, predicate,
                           (save_fn, restore_fn, strict_predicate_restore))


def get_registered_saver_name(trackable):
  """Returns the name of the registered saver to use with Trackable, or None."""
  try:
    return _saver_registry.get_registered_name(trackable)
  except LookupError:
    return None


def get_save_function(registered_name):
  """Returns save function registered to name.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  return _saver_registry.name_lookup(registered_name)[0]


def get_restore_function(registered_name):
  """Returns restore function registered to name.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  return _saver_registry.name_lookup(registered_name)[1]


def get_strict_predicate_restore(registered_name):
  """Returns the `strict_predicate_restore` flag registered with the saver.

  `True` means a failing predicate during restore is an error; `False` means the
  registered restore function may be skipped when the predicate fails (see
  `register_checkpoint_saver`).

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  return _saver_registry.name_lookup(registered_name)[2]


def validate_restore_function(trackable, registered_name):
  """Validates whether the trackable can be restored with the saver.

  When using a checkpoint saved with a registered saver, that same saver must
  also be registered when loading. The name of that saver is saved to the
  checkpoint and set in the `registered_name` arg.

  Args:
    trackable: A `Trackable` object.
    registered_name: String name of the expected registered saver. This argument
      should be set using the name saved in a checkpoint.

  Raises:
    ValueError if the saver could not be found, or if the predicate associated
    with the saver does not pass.
  """
  try:
    _saver_registry.name_lookup(registered_name)
  except LookupError as e:
    raise ValueError(
        f"Error when restoring object {trackable} from checkpoint. This "
        "object was saved using a registered saver named "
        f"'{registered_name}', but this saver cannot be found in the "
        "current context.") from e
  if not _saver_registry.get_predicate(registered_name)(trackable):
    raise ValueError(
        f"Object {trackable} was saved with the registered saver named "
        f"'{registered_name}'. However, this saver cannot be used to restore the "
        "object because the predicate does not pass.")