1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""User-defined ExtensionType classes.""" 16 17import abc 18import typing 19import typing_extensions 20 21from tensorflow.python.framework import composite_tensor 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import extension_type_field 24from tensorflow.python.framework import immutable_dict 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import tensor_spec 28from tensorflow.python.framework import type_spec 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import composite_tensor_ops 31from tensorflow.python.ops import gen_math_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.saved_model import nested_structure_coder 34from tensorflow.python.util import nest 35from tensorflow.python.util import tf_decorator 36from tensorflow.python.util import tf_inspect 37from tensorflow.python.util.tf_export import tf_export 38 39# Attribute used to keep track of when we're inside a user-defined constructor 40# (in which case the fields of `self` may be modified). 
# Sentinel attribute set on `self` while a user-defined constructor is running;
# `__setattr__`/`__delattr__` only permit field mutation while it is present.
_IN_CONSTRUCTOR = '_tf_extension_type_in_constructor'

_MUTABLE_KERAS_PROPERTIES = [
    # Keras uses _keras_mask property to pass the mask around
    '_keras_mask',
]


# ==============================================================================
# Utility functions
# ==============================================================================
def _create_object_from_type_and_dict(cls, obj_dict):
  """Creates an object, bypassing the constructor.

  Creates an object of type `cls`, whose `__dict__` is updated to contain
  `obj_dict`.

  Args:
    cls: The type of the new object.
    obj_dict: A `Mapping` that should be used to initialize the new object's
      `__dict__`.

  Returns:
    An object of type `cls`.
  """
  value = object.__new__(cls)
  value.__dict__.update(obj_dict)
  return value


# ==============================================================================
# Metaclass for tf.ExtensionType
# ==============================================================================
class ExtensionTypeMetaclass(abc.ABCMeta):
  """Metaclass for tf.ExtensionType types."""

  def __init__(cls, name, bases, namespace):
    # Don't transform base classes that are part of the framework -- only
    # transform user classes.  We identify classes that are part of the
    # framework by setting '_tf_extension_type_do_not_transform_this_class=True'
    # in the class definition.  (Note: we check for this in the class namespace,
    # so it is *not* inherited.)
    if not namespace.get('_tf_extension_type_do_not_transform_this_class',
                         False):
      _check_field_annotations(cls)
      _add_extension_type_constructor(cls)
      _add_type_spec(cls)
    super(ExtensionTypeMetaclass, cls).__init__(name, bases, namespace)


# ==============================================================================
# Base class for user-defined types
# ==============================================================================
@tf_export('experimental.ExtensionType')
class ExtensionType(
    composite_tensor.CompositeTensor, metaclass=ExtensionTypeMetaclass):
  """Base class for TensorFlow `ExtensionType` classes.

  Tensorflow `ExtensionType` classes are specialized Python classes that can be
  used transparently with TensorFlow -- e.g., they can be used with ops
  such as `tf.cond` or `tf.while_loop` and used as inputs or outputs for
  `tf.function` and Keras layers.

  New `ExtensionType` classes are defined by creating a subclass of
  `tf.ExtensionType` that
  contains type annotations for all instance variables.  The following type
  annotations are supported:

  Type                 | Example
  -------------------- | --------------------------------------------
  Python integers      | `i: int`
  Python floats        | `f: float`
  Python strings       | `s: str`
  Python booleans      | `b: bool`
  Python None          | `n: None`
  Tensors              | `t: tf.Tensor`
  Composite Tensors    | `rt: tf.RaggedTensor`
  Extension Types      | `m: MyMaskedTensor`
  Tensor shapes        | `shape: tf.TensorShape`
  Tensor dtypes        | `dtype: tf.DType`
  Type unions          | `length: typing.Union[int, float]`
  Tuples               | `params: typing.Tuple[int, float, int, int]`
  Tuples w/ Ellipsis   | `lengths: typing.Tuple[int, ...]`
  Mappings             | `tags: typing.Mapping[str, str]`

  Fields annotated with `typing.Mapping` will be stored using an immutable
  mapping type.

  ExtensionType values are immutable -- i.e., once constructed, you can not
  modify or delete any of their instance members.

  ### Examples

  >>> class MaskedTensor(ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> class Toy(ExtensionType):
  ...   name: str
  ...   price: ops.Tensor
  ...   features: typing.Mapping[str, tf.Tensor]

  >>> class ToyStore(ExtensionType):
  ...   name: str
  ...   toys: typing.Tuple[Toy, ...]
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, *args, **kwargs):
    # Only reject direct instantiation of the abstract base; subclasses get a
    # generated or user-defined constructor from the metaclass.
    if type(self) is ExtensionType:  # pylint: disable=unidiomatic-typecheck
      raise AssertionError('Cannot create an instance of ExtensionType '
                           'because ExtensionType is an abstract base class.')

  # This class variable is used to cache the return value for
  # _tf_extension_type_fields.
  _tf_extension_type_cached_fields = None

  @classmethod
  def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
    """An ordered list describing the fields of this ExtensionType.

    Returns:
      A list of `ExtensionTypeField` objects.  Forward references are resolved
      if possible, or left unresolved otherwise.
    """
    if '_tf_extension_type_cached_fields' in cls.__dict__:  # do not inherit.
      return cls._tf_extension_type_cached_fields

    try:
      # Using include_extras=False will replace all Annotated[T, ...] with T.
      # The typing_extensions module is used since this is only supported in
      # Python 3.9.
      type_hints = typing_extensions.get_type_hints(cls, include_extras=False)
      ok_to_cache = True  # all forward references have been resolved.
    except (NameError, AttributeError):
      # Unresolved forward reference -- gather type hints manually.
      # * NameError comes from an annotation like `Foo` where class
      #   `Foo` hasn't been defined yet.
      # * AttributeError comes from an annotation like `foo.Bar`, where
      #   the module `foo` exists but `Bar` hasn't been defined yet.
      # Note: If a user attempts to instantiate a `ExtensionType` type that
      # still has unresolved forward references (e.g., because of a typo or a
      # missing import), then the constructor will raise an exception.
      type_hints = {}
      for base in reversed(cls.__mro__):
        type_hints.update(base.__dict__.get('__annotations__', {}))
      ok_to_cache = False

    fields = []
    for (name, value_type) in type_hints.items():
      default = getattr(cls, name,
                        extension_type_field.ExtensionTypeField.NO_DEFAULT)
      fields.append(
          extension_type_field.ExtensionTypeField(name, value_type, default))
    fields = tuple(fields)

    if ok_to_cache:
      cls._tf_extension_type_cached_fields = fields

    return fields

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    # True iff `name` is one of this type's declared (annotated) fields.
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    # Coerce raw constructor arguments in `self.__dict__` to the declared
    # field types (e.g. convert lists to Tensors, dicts to ImmutableDicts).
    extension_type_field.convert_fields(self._tf_extension_type_fields(),
                                        self.__dict__)

  def __repr__(self):
    fields = ', '.join([
        f'{field.name}={getattr(self, field.name)!r}'
        for field in self._tf_extension_type_fields()
    ])
    return f'{type(self).__qualname__}({fields})'

  def __setattr__(self, name, value):
    # Fields may only be set inside a user-defined constructor (while the
    # _IN_CONSTRUCTOR sentinel is present); Keras-internal properties are
    # always mutable.
    if (name in _MUTABLE_KERAS_PROPERTIES or
        (hasattr(self, _IN_CONSTRUCTOR) and
         self._tf_extension_type_has_field(name))):
      self.__dict__[name] = value
    else:
      raise AttributeError(f'Cannot mutate attribute `{name}` '
                           f'outside the custom constructor of ExtensionType.')

  def __delattr__(self, name):
    # Same mutability rules as `__setattr__`.
    if (name in _MUTABLE_KERAS_PROPERTIES or
        (hasattr(self, _IN_CONSTRUCTOR) and
         self._tf_extension_type_has_field(name))):
      del self.__dict__[name]
    else:
      raise AttributeError(f'Cannot mutate attribute `{name}` '
                           f'outside the custom constructor of ExtensionType.')

  def __getattr__(self, name):
    # Only called when normal attribute lookup fails -- i.e., for packed
    # values (whose fields live in a single variant) or truly missing names.
    if name in _MUTABLE_KERAS_PROPERTIES:
      return object.__getattribute__(self, name)
    if '_tf_extension_type_packed_variant' in self.__dict__:
      # Note: it's *not* ok to cache the results of unpack() here.  In
      # particular, it would be nice if we could do something like
      # `self.__dict__.update(unpack(self).__dict__)`, but that (potentially)
      # violates an invariant required by the `cond` operation.  E.g., if we had
      # `tf.cond(lambda: x.foo, lambda: x.bar)`, then tensor `x.bar` used in the
      # "else" branch would be created by an op in the "then" branch (when
      # looking up `x.foo`); and that's not allowed.
      return getattr(unpack(self), name)

    raise AttributeError(
        f'{type(self).__name__!r} object has no attribute {name!r}')

  def __eq__(self, other):
    if type(self) is not type(other):
      return False

    if self._type_spec != other._type_spec:
      return False

    self_tensors = nest.flatten(self, expand_composites=True)
    other_tensors = nest.flatten(other, expand_composites=True)
    if len(self_tensors) != len(other_tensors):
      return False
    conditions = []
    for t1, t2 in zip(self_tensors, other_tensors):
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(
                  array_ops.shape(t1),
                  array_ops.shape(t2),
                  incompatible_shape_error=False)))
      # Explicitly check shape (values that have different shapes but broadcast
      # to the same value are considered non-equal).
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(t1, t2, incompatible_shape_error=False)))
    return math_ops.reduce_all(array_ops.stack(conditions))

  def __ne__(self, other):
    # `__eq__` may return either a Python bool (static mismatch) or a boolean
    # Tensor (elementwise graph comparison); negate with the matching op.
    eq = self.__eq__(other)
    if isinstance(eq, ops.Tensor):
      return math_ops.logical_not(eq)
    else:
      return not eq

  def __validate__(self):
    """Perform post-construction validation."""

  # This instance variable is used to cache the value for the _type_spec
  # property.
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      assert not is_packed(self)  # Packed version always caches TypeSpec.
      # Write via __dict__ to bypass the immutability check in __setattr__.
      self.__dict__[
          '_tf_extension_type_cached_type_spec'] = self.Spec.from_value(self)
    return self._tf_extension_type_cached_type_spec


def pack(value):
  """Returns a copy of `value` with fields packed in a single Variant.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.
  """
  if is_packed(value):
    return value

  spec = value._type_spec._tf_extension_type_with_packed(True)  # pylint: disable=protected-access
  try:
    variant = composite_tensor_ops.composite_tensor_to_variants(value)
  except nested_structure_coder.NotEncodableError as e:
    # Note: the only time `_TypeSpecCodec.can_encode` returns False is if the
    # named type is not registered.  The default error message would simply
    # tell the user that there is no encoder for the object, so we provide
    # a more useful message letting them know how to register the type.
    raise ValueError('ExtensionTypes must have a __name__ field in order '
                     'to be packed.') from e

  return _create_object_from_type_and_dict(
      type(value), {
          '_tf_extension_type_cached_type_spec': spec,
          '_tf_extension_type_packed_variant': variant,
      })


def unpack(value):
  """Returns a copy of `value` with individual fields stored in __dict__.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.
  """
  if not is_packed(value):
    return value

  # pylint: disable=protected-access
  variant = value._tf_extension_type_packed_variant
  spec = value._tf_extension_type_cached_type_spec
  spec = spec._tf_extension_type_with_packed(False)
  return composite_tensor_ops.composite_tensor_from_variant(variant, spec)


def is_packed(value):
  """Returns true if `value`'s fields are packed in a single Variant."""
  if not isinstance(value, ExtensionType):
    raise ValueError(f'Expected `value` to be an object of type ExtensionType,'
                     f'got an instance of {type(value)}.')
  return '_tf_extension_type_packed_variant' in value.__dict__


# ==============================================================================
# Base class for the tf.ExtensionType TypeSpecs
# ==============================================================================
# TODO(b/184565242) Support customizing type relaxation for tracing.
# TODO(b/184565242) Support conversion to/from FullType.
# TODO(b/195884675) Support batch and unbatch.


class ExtensionTypeSpec(type_spec.TypeSpec):
  """Base class for tf.ExtensionType TypeSpec."""

  def _serialize(self):  # TypeSpec API.
    # Use a tuple of (name, value) pairs, to ensure we preserve field ordering.
    fields = [f.name for f in self._tf_extension_type_fields()]
    if self._tf_extension_type_is_packed:
      fields.append('_tf_extension_type_is_packed')
    return tuple(
        (f, _change_nested_mappings_to(self.__dict__[f], dict)) for f in fields)

  @classmethod
  def _deserialize(cls, state):  # TypeSpec API.
    state = _change_nested_mappings_to(state, immutable_dict.ImmutableDict)
    return _create_object_from_type_and_dict(cls, state)

  def __reduce__(self):
    # Use value_type instead of spec_type, as spec_type is a nested class.
    # Pickle support of nested class requires Pickle protocol version 4, which
    # is not enabled by default until py 3.8.
    #
    # https://www.python.org/dev/peps/pep-3154/#serializing-more-lookupable-objects
    # https://docs.python.org/3/library/pickle.html#pickle.DEFAULT_PROTOCOL
    return _deserialize_for_reduce, (self.value_type, self._serialize())

  def _to_components(self, value):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return value._tf_extension_type_packed_variant  # pylint: disable=protected-access

    tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor)
    # Retrieve fields by the order of spec dict to preserve field ordering. This
    # is needed as nest.flatten would sort dictionary entries by key.
    value_tuple = tuple(value.__dict__[key] for key in self.__dict__)
    return tuple(
        x for x in nest.flatten(value_tuple)
        if isinstance(x, tensor_or_composite))

  def _from_components(self, components):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return _create_object_from_type_and_dict(
          self.value_type, {
              '_tf_extension_type_cached_type_spec': self,
              '_tf_extension_type_packed_variant': components
          })

    spec_tuple = tuple(self.__dict__.values())
    components_iter = iter(components)
    # Splice each component back into the position of its TypeSpec; all other
    # (static) leaf values are carried over from the spec itself.
    flat = [
        next(components_iter) if isinstance(x, type_spec.TypeSpec) else x
        for x in nest.flatten(spec_tuple)
    ]
    if list(components_iter):
      raise ValueError(
          'Cannot build an ExtensionType instance from components '
          'because more components are provided than the number expected '
          'by the type spec.')
    value_tuple = nest.pack_sequence_as(spec_tuple, flat)
    fields = dict(zip(self.__dict__.keys(), value_tuple))

    # Build the new value.  Bypass the constructor (__init__), in case the user
    # who defined the ExtensionType used a custom constructor.
    return _create_object_from_type_and_dict(self.value_type, fields)

  @property
  def _component_specs(self):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return tensor_spec.TensorSpec((), dtypes.variant)

    components = []

    def push_if_type_spec(x):
      if isinstance(x, type_spec.TypeSpec):
        components.append(x)

    nest.map_structure(push_if_type_spec, tuple(self.__dict__.values()))
    return tuple(components)

  @classmethod
  def from_value(cls, value):
    """Builds a TypeSpec describing `value` (reusing a cached spec if set)."""
    cached_spec = getattr(value, '_tf_extension_type_cached_type_spec', None)
    if cached_spec is not None:
      return cached_spec

    value_fields = value.__dict__
    spec_fields = nest.map_structure(_replace_tensor_with_spec, value_fields)
    spec_fields.pop('_tf_extension_type_cached_fields', None)
    return _create_object_from_type_and_dict(cls, spec_fields)

  def __setattr__(self, name, value):
    # Like ExtensionType, specs are immutable outside their constructor.
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError(
          f'Cannot mutate attribute `{name}` '
          f'outside the custom constructor of ExtensionTypeSpec.')

  def __delattr__(self, name):
    # Same mutability rules as `__setattr__`.
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError(
          f'Cannot mutate attribute `{name}` '
          f'outside the custom constructor of ExtensionTypeSpec.')

  def __validate__(self):
    """Perform post-construction validation."""

  @classmethod
  def _tf_extension_type_fields(cls):
    # The spec's fields are defined by its value type.
    return cls.value_type._tf_extension_type_fields()  # pylint: disable=protected-access

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    extension_type_field.convert_fields_for_spec(
        self._tf_extension_type_fields(), self.__dict__)

  def __repr__(self):
    fields = ', '.join([f'{k}={v!r}' for (k, v) in self._serialize()])
    return f'{type(self).__qualname__}({fields})'

  # True if the described value stores its fields packed in a single Variant.
  _tf_extension_type_is_packed = False

  def _tf_extension_type_with_packed(self, value):
    """Returns a copy of this `TypeSpec` with `packed=value`.

    Args:
      value: A boolean value.

    Returns:
      A copy of `self` with `_tf_extension_type_is_packed=value`.
    """
    copy = _create_object_from_type_and_dict(type(self), self.__dict__)
    copy.__dict__['_tf_extension_type_is_packed'] = value
    return copy


@tf_export('experimental.ExtensionTypeBatchEncoder')
class ExtensionTypeBatchEncoder(type_spec.TypeSpecBatchEncoder):
  """Class used to encode and decode extension type values for batching.

  In order to be batched and unbatched by APIs such as `tf.data.Dataset`,
  `tf.keras`, and `tf.map_fn`, extension type values must be encoded as a list
  of `tf.Tensor`s, where stacking, unstacking, or concatenating these encoded
  tensors and then decoding the result must be equivalent to stacking,
  unstacking, or concatenating the original values. `ExtensionTypeBatchEncoder`s
  are responsible for implementing this encoding.

  The default `ExtensionTypeBatchEncoder` that is used by
  `BatchableExtensionType` assumes that extension type values can be stacked,
  unstacked, or concatenated by simply stacking, unstacking, or concatenating
  every nested `Tensor`, `ExtensionType`, `CompositeTensor`, and `TensorShape`
  field.

  Extension types where this is not the case will need to override
  `__batch_encoder__` with a custom encoder that overrides the `batch`,
  `unbatch`, `encode`, and `decode` methods. E.g.:

  >>> class CustomBatchEncoder(ExtensionTypeBatchEncoder):
  ...   pass # Override batch(), unbatch(), encode(), and decode().

  >>> class CustomType(BatchableExtensionType):
  ...   x: tf.Tensor
  ...   y: tf.Tensor
  ...   shape: tf.TensorShape
  ...   __batch_encoder__ = CustomBatchEncoder()

  For example, `tf.RaggedTensor` and `tf.SparseTensor` both use custom batch
  encodings which define ops to "box" and "unbox" individual values into
  `tf.variant` tensors.
  """

  def batch(self, spec, batch_size):
    """Returns the TypeSpec representing a batch of values described by `spec`.

    The default definition returns a `TypeSpec` that is equal to `spec`, except
    that an outer axis with size `batch_size` is added to every nested
    `TypeSpec` and `TensorShape` field.  Subclasses may override this default
    definition, when necessary.

    Args:
      spec: The `TypeSpec` for an individual value.
      batch_size: An `int` indicating the number of values that are batched
        together, or `None` if the batch size is not known.

    Returns:
      A `TypeSpec` for a batch of values.
    """

    def batch_field(f):
      if isinstance(f, type_spec.BatchableTypeSpec):
        # Delegate to the nested spec's own encoder.
        return f.__batch_encoder__.batch(f, batch_size)
      elif isinstance(f, tensor_shape.TensorShape):
        return [batch_size] + f
      else:
        return f

    fields = tuple(spec.__dict__.items())
    batched_fields = nest.map_structure(batch_field, fields)
    return _create_object_from_type_and_dict(type(spec), batched_fields)

  def unbatch(self, spec):
    """Returns the TypeSpec for a single unbatched element in `spec`.

    The default definition returns a `TypeSpec` that is equal to `spec`, except
    that the outermost axis is removed from every nested `TypeSpec`, and
    `TensorShape` field.  Subclasses may override this default definition, when
    necessary.

    Args:
      spec: The `TypeSpec` for a batch of values.

    Returns:
      A `TypeSpec` for an individual value.
    """

    def unbatch_field(f):
      if isinstance(f, type_spec.BatchableTypeSpec):
        return f.__batch_encoder__.unbatch(f)
      elif isinstance(f, tensor_shape.TensorShape):
        return f[1:]
      else:
        return f

    fields = tuple(spec.__dict__.items())
    unbatched_fields = nest.map_structure(unbatch_field, fields)
    return _create_object_from_type_and_dict(type(spec), unbatched_fields)

  def encode(self, spec, value, minimum_rank=0):
    """Encodes `value` as a nest of batchable Tensors or CompositeTensors.

    The default definition returns a flat tuple of all the `Tensor`s,
    `CompositeTensor`s, and `ExtensionType`s from a depth-first traversal of
    `value`'s fields. Subclasses may override this default definition, when
    necessary.

    Args:
      spec: The TypeSpec of the value to encode.
      value: A value compatible with `spec`.
      minimum_rank: The minimum rank for the returned Tensors, CompositeTensors,
        and ExtensionType values.  This can be used to ensure that the encoded
        values can be unbatched this number of times.   If `minimum_rank>0`,
        then `t.shape[:minimum_rank]` must be compatible for all values `t`
        returned by `encode`.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.Tensor`s, batchable
      `tf.CompositeTensor`s, or `tf.ExtensionType`s.  Stacking, unstacking, or
      concatenating these encoded values and then decoding the result must be
      equivalent to stacking, unstacking, or concatenating the original values.
    """
    return spec._to_components(value)  # pylint: disable=protected-access

  def decode(self, spec, encoded_value):
    """Decodes `value` from a batchable tensor encoding.

    See `encode` for a description of the default encoding.  Subclasses may
    override this default definition, when necessary.

    Args:
      spec: The TypeSpec for the result value.  If encoded values with spec `s`
        were batched, then `spec` should be `s.batch(batch_size)`; or if encoded
        values with spec `s` were unbatched, then `spec` should be
        `s.unbatch()`.
      encoded_value: A nest of values returned by `encode`; or a nest of
        values that was formed by stacking, unstacking, or concatenating the
        corresponding elements of values returned by `encode`.

    Returns:
      A value compatible with `type_spec`.
    """
    return spec._from_components(encoded_value)  # pylint: disable=protected-access

  def encoding_specs(self, spec):
    """Returns a list of `TensorSpec`(s) describing the encoding for `spec`.

    See `encode` for a description of the default encoding.  Subclasses may
    override this default definition, when necessary.

    Args:
      spec: The TypeSpec whose encoding should be described.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.TypeSpec`, describing the values
      that are returned by `self.encode(spec, ...)`.  All TypeSpecs in this
      nest must be batchable.
    """
    return spec._component_specs  # pylint: disable=protected-access


class BatchableExtensionTypeSpec(ExtensionTypeSpec,
                                 type_spec.BatchableTypeSpec):
  """Base class for TypeSpecs for BatchableExtensionTypes."""

  __batch_encoder__ = ExtensionTypeBatchEncoder()

  def _batch(self, batch_size):
    return self.__batch_encoder__.batch(self, batch_size)

  def _unbatch(self):
    return self.__batch_encoder__.unbatch(self)

  def _to_tensor_list(self, value):
    return type_spec.batchable_to_tensor_list(self, value)

  def _to_batched_tensor_list(self, value):
    return type_spec.batchable_to_tensor_list(self, value, minimum_rank=1)

  def _from_compatible_tensor_list(self, tensor_list):
    return type_spec.batchable_from_tensor_list(self, tensor_list)

  @property
  def _flat_tensor_specs(self):
    return type_spec.get_batchable_flat_tensor_specs(self)


@tf_export('experimental.BatchableExtensionType')
class BatchableExtensionType(ExtensionType):
  """An ExtensionType that can be batched and unbatched.

  `BatchableExtensionType`s can be used with APIs that require batching or
  unbatching, including `Keras`, `tf.data.Dataset`, and `tf.map_fn`.  E.g.:

  >>> class Vehicle(tf.experimental.BatchableExtensionType):
  ...   top_speed: tf.Tensor
  ...   mpg: tf.Tensor
  >>> batch = Vehicle([120, 150, 80], [30, 40, 12])
  >>> tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
  ...           fn_output_signature=tf.int32).numpy()
  array([3600, 6000, 960], dtype=int32)

  An `ExtensionTypeBatchEncoder` is used by these APIs to encode `ExtensionType`
  values. The default encoder assumes that values can be stacked, unstacked, or
  concatenated by simply stacking, unstacking, or concatenating every nested
  `Tensor`, `ExtensionType`, `CompositeTensor`, or `TensorShape` field.
  Extension types where this is not the case will need to override
  `__batch_encoder__` with a custom `ExtensionTypeBatchEncoder`.  See
  `tf.experimental.ExtensionTypeBatchEncoder` for more details.
  """
  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True


# For Pickle __reduce__ protocol:
def _deserialize_for_reduce(value_type, serialization):
  # Rebuilds a spec from (value_type, serialized_state); see
  # ExtensionTypeSpec.__reduce__ for why value_type is used here.
  return value_type.Spec._deserialize(serialization)  # pylint: disable=protected-access


def _replace_tensor_with_spec(value):
  """Maps a field value to the TypeSpec that describes it (identity for static values)."""
  if isinstance(value, ops.Tensor):
    # Note: we intentionally exclude `value.name` from the `TensorSpec`.
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  if hasattr(value, '_type_spec'):
    return value._type_spec  # pylint: disable=protected-access
  return value


def _change_nested_mappings_to(value, new_type):
  """Recursively replace mappings with `new_type`."""
  if isinstance(value, (dict, immutable_dict.ImmutableDict)):
    return new_type([(k, _change_nested_mappings_to(v, new_type))
                     for (k, v) in value.items()])
  elif isinstance(value, tuple):
    return tuple(_change_nested_mappings_to(elt, new_type) for elt in value)
  else:
    return value


# ==============================================================================
# Helper methods for tf.ExtensionTypeMetaclass
# ==============================================================================


def _check_field_annotations(cls):
  """Validates the field annotations for tf.ExtensionType subclass `cls`."""
  annotations = getattr(cls, '__annotations__', {})

  # Check that no fields use reserved names.
  for name, value in cls.__dict__.items():
    if name == 'Spec':
      if not isinstance(value, type):
        raise ValueError(f'{cls.__qualname__}.Spec must be a nested class; '
                         f'got {value}.')
      if (value.__bases__ != (type_spec.TypeSpec,) and value.__bases__ !=
          (object,)):
        raise ValueError(f'{cls.__qualname__}.Spec must be directly subclassed '
                         'from tf.TypeSpec.')
    elif extension_type_field.ExtensionTypeField.is_reserved_name(name):
      raise ValueError(f'The field annotations for {cls.__name__} are '
                       f"invalid. Field '{name}' is reserved.")
  for name in annotations:
    if extension_type_field.ExtensionTypeField.is_reserved_name(name):
      raise ValueError(f'The field annotations for {cls.__name__} are '
                       f"invalid. Field '{name}' is reserved.")

  # Check that all fields have type annotations.
  for (key, value) in cls.__dict__.items():
    if not (key in annotations or callable(value) or key.startswith('_abc_') or
            key == '_tf_extension_type_fields' or
            key.startswith('__') and key.endswith('__') or
            isinstance(value, (property, classmethod, staticmethod))):
      raise ValueError(f'The field annotations for {cls.__name__} are '
                       f'invalid. Field {key} is missing a type annotation.')


def _add_extension_type_constructor(cls):
  """Creates a constructor for a ExtensionType or ExtensionTypeSpec subclass."""
  if '__init__' in cls.__dict__:
    _wrap_user_constructor(cls)
  else:
    _build_extension_type_constructor(cls)


def _wrap_user_constructor(cls):
  """Wraps a user-defined constructor for tf.ExtensionType subclass `cls`."""
  user_constructor = cls.__init__

  def wrapped_init(self, *args, **kwargs):
    # The _IN_CONSTRUCTOR sentinel unlocks __setattr__/__delattr__ for the
    # duration of the user constructor.
    self.__dict__[_IN_CONSTRUCTOR] = True
    user_constructor(self, *args, **kwargs)
    del self.__dict__[_IN_CONSTRUCTOR]

    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  cls.__init__ = tf_decorator.make_decorator(user_constructor, wrapped_init)


# Shorthand for the no-default sentinel used when building constructors.
_NO_DEFAULT = extension_type_field.ExtensionTypeField.NO_DEFAULT


# TODO(b/184565242) Consider using the templating system from autograph here.
def _build_extension_type_constructor(cls):
  """Builds a constructor for tf.ExtensionType subclass `cls`."""
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

  # Mark any no-default fields that follow default fields as keyword_only.
  got_default = False
  keyword_only_start = len(fields)
  for i in range(len(fields)):
    if got_default:
      if fields[i].default is _NO_DEFAULT:
        keyword_only_start = i
        break
    elif fields[i].default is not _NO_DEFAULT:
      got_default = True

  params = []
  for i, field in enumerate(fields):
    if i < keyword_only_start:
      kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
    else:
      kind = tf_inspect.Parameter.KEYWORD_ONLY
    if field.default is _NO_DEFAULT:
      default = tf_inspect.Parameter.empty
    else:
      default = field.default
    params.append(
        tf_inspect.Parameter(
            field.name, kind, default=default, annotation=field.value_type))

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  __init__.__signature__ = tf_inspect.Signature(
      [
          tf_inspect.Parameter('self',
                               tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
      ] + params,
      return_annotation=cls)

  cls.__init__ = __init__


def _build_spec_constructor(cls):
  """Builds a constructor for ExtensionTypeSpec subclass `cls`."""
  params = []
  kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  for field in cls._tf_extension_type_fields():  # pylint: disable=protected-access
    params.append(tf_inspect.Parameter(field.name, kind))

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools.
  __init__.__signature__ = tf_inspect.Signature(
      [
          tf_inspect.Parameter('self',
                               tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
      ] + params,
      return_annotation=cls)

  cls.__init__ = __init__


def _add_type_spec(cls):
  """Creates a nested TypeSpec class for tf.ExtensionType subclass `cls`."""
  spec_name = cls.__name__ + '.Spec'
  spec_qualname = cls.__qualname__ + '.Spec'

  # Set __module__ explicitly as a dynamically created class has module='abc'
  # by default.
  spec_dict = {'value_type': cls, '__module__': cls.__module__}

  # Copy user-supplied customizations into the TypeSpec.
  user_spec = cls.__dict__.get('Spec', None)
  if user_spec is not None:
    for (name, value) in user_spec.__dict__.items():
      if extension_type_field.ExtensionTypeField.is_reserved_name(name):
        raise ValueError(f'TypeSpec {spec_qualname} uses reserved '
                         f"name '{name}'.")
      if cls._tf_extension_type_has_field(name):  # pylint: disable=protected-access
        raise ValueError(f"TypeSpec {spec_qualname} defines a variable '{name}'"
                         f' which shadows a field in {cls.__qualname__}')
      if name in ('__module__', '__dict__', '__weakref__'):
        continue

      spec_dict[name] = value

  if issubclass(cls, BatchableExtensionType):
    type_spec_base = BatchableExtensionTypeSpec
    if hasattr(cls,
               '__batch_encoder__') and '__batch_encoder__' not in spec_dict:
      spec_dict['__batch_encoder__'] = cls.__batch_encoder__
  else:
    type_spec_base = ExtensionTypeSpec
    if hasattr(cls, '__batch_encoder__') or '__batch_encoder__' in spec_dict:
      raise ValueError('__batch_encoder__ should only be defined for '
                       'BatchableExtensionType classes.')

  # Build the TypeSpec and store it as a nested class inside `cls`.
  spec = type(spec_name, (type_spec_base,), spec_dict)
  spec.__qualname__ = spec_qualname
  setattr(cls, 'Spec', spec)

  # Build a constructor for the TypeSpec class.
  if '__init__' in spec.__dict__:
    _wrap_user_constructor(spec)
  else:
    _build_spec_constructor(spec)

  cls.__abstractmethods__ -= {'_type_spec'}

  # If the user included an explicit `__name__` attribute, then use that to
  # register the TypeSpec (so it can be used in SavedModel signatures).
933 if '__name__' in cls.__dict__: 934 type_spec.register(cls.__dict__['__name__'] + '.Spec')(spec) 935 936 937# ============================================================================== 938# Anonymous ExtensionType 939# ============================================================================== 940class AnonymousExtensionType(ExtensionType): 941 """Fallback used to decode `tf.ExtensionType` when the original type is unavailable. 942 943 When a SavedModel is serialized, the signatures of any functions in the 944 SavedModel can include `tf.ExtensionType` subclasses. These subclasses are 945 usually 946 registered, so they can be restored when the SavedModel is loaded. However, 947 if a SavedModel is loaded without first registering the ExtensionType types in 948 its 949 signature, then the SavedModel will fall back to using the 950 `AnonymousExtensionType` 951 type instead. 952 953 If necessary, `AnonymousExtensionType` objects can be converted to a concrete 954 `tf.ExtensionType` subclass (and vice versa) using `reinterpret`. 955 """ 956 957 # Let the metaclass know that it should *not* transform this class (since 958 # this class is part of the ExtensionType framework, and not a user class). 
959 _tf_extension_type_do_not_transform_this_class = True 960 961 def __init__(self, **fields): 962 for name in fields: 963 if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or 964 (name.startswith('__') and name.endswith('__'))): 965 raise ValueError( 966 f'Reserved field name {name} was encountered ' 967 f'when trying to instantiate an AnonymousExtensionType.') 968 fields = [(k, _convert_anonymous_fields(v)) for (k, v) in fields.items()] 969 self.__dict__.update(fields) 970 self._tf_extension_type_convert_fields() 971 super().__init__() 972 973 @classmethod 974 def _tf_extension_type_fields(cls): 975 return [ 976 extension_type_field.ExtensionTypeField(name, None) 977 for name in cls.__dict__ 978 if not extension_type_field.ExtensionTypeField.is_reserved_name(name) 979 ] 980 981 def __setattr__(self, name, value): 982 raise AttributeError(f'Cannot set attribute `{name}`. ' 983 f'AnonymousExtensionType instances are immutable.') 984 985 def __delattr__(self, name): 986 raise AttributeError(f'Cannot delete attribute `{name}`. ' 987 f'AnonymousExtensionType instances are immutable.') 988 989 def _tf_extension_type_convert_fields(self): 990 fields = [(k, _convert_anonymous_fields(v)) 991 for (k, v) in self.__dict__.items() 992 if not extension_type_field.ExtensionTypeField.is_reserved_name(k) 993 ] 994 self.__dict__.update(fields) 995 996 def __repr__(self): 997 fields = [ 998 f'{k}={v!r}' for (k, v) in self.__dict__.items() 999 if not extension_type_field.ExtensionTypeField.is_reserved_name(k) 1000 ] 1001 return f'AnonymousExtensionType({", ".join(fields)})' 1002 1003 _tf_extension_type_cached_type_spec = None 1004 1005 @property 1006 def _type_spec(self): # CompositeTensor API. 1007 # Note: the TypeSpec contains all static (non-tensor) data from `self`. 
1008 if self._tf_extension_type_cached_type_spec is None: 1009 spec = AnonymousExtensionTypeSpec.from_value(self) 1010 self.__dict__['_tf_extension_type_cached_type_spec'] = spec 1011 return self._tf_extension_type_cached_type_spec 1012 1013 1014@type_spec.register('tf.AnonymousExtensionType.Spec') 1015class AnonymousExtensionTypeSpec(ExtensionTypeSpec): 1016 """TypeSpec for AnonymousExtensionType.""" 1017 1018 def __init__(self, **fields): 1019 for name in fields: 1020 if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or 1021 (name.startswith('__') and name.endswith('__'))): 1022 raise ValueError( 1023 f'Reserved field name {name} was encountered ' 1024 f'when trying to instantiate an AnonymousExtensionTypeSpec.') 1025 fields = [(k, _convert_anonymous_fields(v, for_spec=True)) 1026 for (k, v) in fields.items()] 1027 self.__dict__.update(fields) 1028 super().__init__() 1029 1030 value_type = AnonymousExtensionType # TypeSpec API. 1031 1032 def _serialize(self): # TypeSpec API. 1033 return tuple( 1034 (name, _change_nested_mappings_to(value, dict)) 1035 for (name, value) in self.__dict__.items() 1036 if not extension_type_field.ExtensionTypeField.is_reserved_name(name)) 1037 1038 def __setattr__(self, name, value): 1039 raise AttributeError(f'Cannot set attribute `{name}`. ' 1040 f'AnonymousExtensionTypeSpec instances are immutable.') 1041 1042 def __delattr__(self, name): 1043 raise AttributeError(f'Cannot delete attribute `{name}`. 
' 1044 f'AnonymousExtensionTypeSpec instances are immutable.') 1045 1046 1047def _convert_anonymous_fields(value, for_spec=False): 1048 """Type-checks and converts `value` for inclusion in an AnonymousExtensionType.""" 1049 if isinstance(value, (int, float, bool, str, bytes, type(None), dtypes.DType, 1050 tensor_shape.TensorShape)): 1051 return value 1052 1053 if isinstance(value, tuple): 1054 return tuple(_convert_anonymous_fields(v, for_spec) for v in value) 1055 1056 if isinstance(value, typing.Mapping): 1057 return immutable_dict.ImmutableDict([ 1058 (_convert_anonymous_fields(k, for_spec), 1059 _convert_anonymous_fields(v, for_spec)) for (k, v) in value.items() 1060 ]) 1061 1062 if (isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)) and 1063 not for_spec): 1064 return value 1065 1066 if isinstance(value, type_spec.TypeSpec) and for_spec: 1067 return value 1068 1069 raise ValueError(f'Cannot convert anonymous fields from ' 1070 f'an unsupported `value` argument: {value!r}.') 1071 1072 1073# ============================================================================== 1074# reinterpret 1075# ============================================================================== 1076def reinterpret(value, new_type): 1077 """Converts a given `ExtensionType` to a new type with compatible fields. 1078 1079 In particular, this can be used to convert a concrete subclass of 1080 `ExtensionType` to an `AnonymousExtensionType`, or vice versa. When 1081 converting to a non-anonymous ExtensionType, field values are type-checked to 1082 ensure they are consistent with `new_type`'s type annotations, and validated 1083 with `new_type.__validate__`. 1084 1085 Args: 1086 value: An instance of a subclass of `tf.ExtensionType` 1087 new_type: A subclass of `tf.ExtensionType` 1088 1089 Returns: 1090 An instance of `new_type`, whose fields are copied from `value`. 
1091 """ 1092 if not isinstance(value, ExtensionType): 1093 raise ValueError( 1094 f'reinterpret expects `value` to be a tf.ExtensionType instance; ' 1095 f'got {value!r}') 1096 if not (isinstance(new_type, type) and issubclass(new_type, ExtensionType)): 1097 raise ValueError( 1098 f'reinterpret expects `new_type` to be a subclass of tf.ExtensionType; ' 1099 f'got {new_type!r}') 1100 1101 fields = [ 1102 item for item in value.__dict__.items() 1103 if not extension_type_field.ExtensionTypeField.is_reserved_name(item[0]) 1104 ] 1105 new_value = _create_object_from_type_and_dict(new_type, fields) 1106 new_value._tf_extension_type_convert_fields() # pylint: disable=protected-access 1107 new_value.__validate__() 1108 return new_value 1109