xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/extension_type.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""User-defined ExtensionType classes."""
16
17import abc
18import typing
19import typing_extensions
20
21from tensorflow.python.framework import composite_tensor
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import extension_type_field
24from tensorflow.python.framework import immutable_dict
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import type_spec
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import composite_tensor_ops
31from tensorflow.python.ops import gen_math_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.saved_model import nested_structure_coder
34from tensorflow.python.util import nest
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util import tf_inspect
37from tensorflow.python.util.tf_export import tf_export
38
# Attribute used to keep track of when we're inside a user-defined constructor
# (in which case the fields of `self` may be modified).  Set/removed directly
# via `self.__dict__` to bypass the immutability checks in `__setattr__`.
_IN_CONSTRUCTOR = '_tf_extension_type_in_constructor'

# Attribute names that may be mutated on an ExtensionType even though they are
# not declared fields (see `ExtensionType.__setattr__` / `__delattr__` /
# `__getattr__`).
_MUTABLE_KERAS_PROPERTIES = [
    # Keras uses _keras_mask property to pass the mask around
    '_keras_mask',
]
47
48
49# ==============================================================================
50# Utility functions
51# ==============================================================================
52def _create_object_from_type_and_dict(cls, obj_dict):
53  """Creates an object, bypassing the constructor.
54
55  Creates an object of type `cls`, whose `__dict__` is updated to contain
56  `obj_dict`.
57
58  Args:
59    cls: The type of the new object.
60    obj_dict: A `Mapping` that should be used to initialize the new object's
61      `__dict__`.
62
63  Returns:
64    An object of type `cls`.
65  """
66  value = object.__new__(cls)
67  value.__dict__.update(obj_dict)
68  return value
69
70
71# ==============================================================================
72# Metaclass for tf.ExtensionType
73# ==============================================================================
class ExtensionTypeMetaclass(abc.ABCMeta):
  """Metaclass for tf.ExtensionType types."""

  def __init__(cls, name, bases, namespace):
    # Only user-defined classes get transformed; classes that are part of the
    # ExtensionType framework itself opt out by defining
    # `_tf_extension_type_do_not_transform_this_class = True` in their own
    # class body.  The flag is read from the class namespace (not via
    # attribute lookup), so it is *not* inherited by subclasses.
    is_framework_class = namespace.get(
        '_tf_extension_type_do_not_transform_this_class', False)
    if not is_framework_class:
      _check_field_annotations(cls)
      _add_extension_type_constructor(cls)
      _add_type_spec(cls)
    super().__init__(name, bases, namespace)
89
90
91# ==============================================================================
92# Base class for user-defined types
93# ==============================================================================
@tf_export('experimental.ExtensionType')
class ExtensionType(
    composite_tensor.CompositeTensor, metaclass=ExtensionTypeMetaclass):
  """Base class for TensorFlow `ExtensionType` classes.

  Tensorflow `ExtensionType` classes are specialized Python classes that can be
  used transparently with TensorFlow -- e.g., they can be used with ops
  such as `tf.cond` or `tf.while_loop` and used as inputs or outputs for
  `tf.function` and Keras layers.

  New `ExtensionType` classes are defined by creating a subclass of
  `tf.ExtensionType` that
  contains type annotations for all instance variables.  The following type
  annotations are supported:

  Type                 | Example
  -------------------- | --------------------------------------------
  Python integers      | `i: int`
  Python floats        | `f: float`
  Python strings       | `s: str`
  Python booleans      | `b: bool`
  Python None          | `n: None`
  Tensors              | `t: tf.Tensor`
  Composite Tensors    | `rt: tf.RaggedTensor`
  Extension Types      | `m: MyMaskedTensor`
  Tensor shapes        | `shape: tf.TensorShape`
  Tensor dtypes        | `dtype: tf.DType`
  Type unions          | `length: typing.Union[int, float]`
  Tuples               | `params: typing.Tuple[int, float, int, int]`
  Tuples w/ Ellipsis   | `lengths: typing.Tuple[int, ...]`
  Mappings             | `tags: typing.Mapping[str, str]`

  Fields annotated with `typing.Mapping` will be stored using an immutable
  mapping type.

  ExtensionType values are immutable -- i.e., once constructed, you can not
  modify or delete any of their instance members.

  ### Examples

  >>> class MaskedTensor(ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> class Toy(ExtensionType):
  ...   name: str
  ...   price: ops.Tensor
  ...   features: typing.Mapping[str, tf.Tensor]

  >>> class ToyStore(ExtensionType):
  ...   name: str
  ...   toys: typing.Tuple[Toy, ...]
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, *args, **kwargs):
    # Note: user subclasses get a generated (or wrapped) constructor from
    # `ExtensionTypeMetaclass`; this method is only reached when
    # `ExtensionType` itself is instantiated directly.
    if type(self) is ExtensionType:  # pylint: disable=unidiomatic-typecheck
      raise AssertionError('Cannot create an instance of ExtensionType '
                           'because ExtensionType is an abstract base class.')

  # This class variable is used to cache the return value for
  # _tf_extension_type_fields.
  _tf_extension_type_cached_fields = None

  @classmethod
  def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
    """An ordered list describing the fields of this ExtensionType.

    Returns:
      A list of `ExtensionTypeField` objects.  Forward references are resolved
      if possible, or left unresolved otherwise.
    """
    # The cache is checked via `cls.__dict__` (not `getattr`) so that a
    # subclass never reuses a parent class's cached fields.
    if '_tf_extension_type_cached_fields' in cls.__dict__:  # do not inherit.
      return cls._tf_extension_type_cached_fields

    try:
      # Using include_extras=False will replace all Annotated[T, ...] with T.
      # The typing_extensions module is used since this is only supported in
      # Python 3.9.
      type_hints = typing_extensions.get_type_hints(cls, include_extras=False)
      ok_to_cache = True  # all forward references have been resolved.
    except (NameError, AttributeError):
      # Unresolved forward reference -- gather type hints manually.
      # * NameError comes from an annotation like `Foo` where class
      #   `Foo` hasn't been defined yet.
      # * AttributeError comes from an annotation like `foo.Bar`, where
      #   the module `foo` exists but `Bar` hasn't been defined yet.
      # Note: If a user attempts to instantiate a `ExtensionType` type that
      # still has unresolved forward references (e.g., because of a typo or a
      # missing import), then the constructor will raise an exception.
      type_hints = {}
      for base in reversed(cls.__mro__):
        type_hints.update(base.__dict__.get('__annotations__', {}))
      ok_to_cache = False

    fields = []
    for (name, value_type) in type_hints.items():
      # A class-level attribute with the same name as the field (if any) is
      # used as the field's default value.
      default = getattr(cls, name,
                        extension_type_field.ExtensionTypeField.NO_DEFAULT)
      fields.append(
          extension_type_field.ExtensionTypeField(name, value_type, default))
    fields = tuple(fields)

    if ok_to_cache:
      cls._tf_extension_type_cached_fields = fields

    return fields

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    """Returns True if `name` is a declared field of this ExtensionType."""
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    """Converts the values in `self.__dict__` to their declared field types."""
    extension_type_field.convert_fields(self._tf_extension_type_fields(),
                                        self.__dict__)

  def __repr__(self):
    fields = ', '.join([
        f'{field.name}={getattr(self, field.name)!r}'
        for field in self._tf_extension_type_fields()
    ])
    return f'{type(self).__qualname__}({fields})'

  def __setattr__(self, name, value):
    # Mutation is only allowed for the Keras property allowlist, or for
    # declared fields while a user-defined constructor is running (signaled
    # by the `_IN_CONSTRUCTOR` attribute).
    if (name in _MUTABLE_KERAS_PROPERTIES or
        (hasattr(self, _IN_CONSTRUCTOR) and
         self._tf_extension_type_has_field(name))):
      self.__dict__[name] = value
    else:
      raise AttributeError(f'Cannot mutate attribute `{name}` '
                           f'outside the custom constructor of ExtensionType.')

  def __delattr__(self, name):
    # Deletion follows the same rules as `__setattr__` above.
    if (name in _MUTABLE_KERAS_PROPERTIES or
        (hasattr(self, _IN_CONSTRUCTOR) and
         self._tf_extension_type_has_field(name))):
      del self.__dict__[name]
    else:
      raise AttributeError(f'Cannot mutate attribute `{name}` '
                           f'outside the custom constructor of ExtensionType.')

  def __getattr__(self, name):
    if name in _MUTABLE_KERAS_PROPERTIES:
      return object.__getattribute__(self, name)
    # If the value is packed (fields stored in a single variant tensor), then
    # field lookups go through an unpacked copy.
    if '_tf_extension_type_packed_variant' in self.__dict__:
      # Note: it's *not* ok to cache the results of unpack() here.  In
      # particular, it would be nice if we could do something like
      # `self.__dict__.update(unpack(self).__dict__)`, but that (potentially)
      # violates an invariant required by the `cond` operation.  E.g., if we had
      # `tf.cond(lambda: x.foo, lambda: x.bar)`, then tensor `x.bar` used in the
      # "else" branch would be created by an op in the "then" branch (when
      # looking up `x.foo`); and that's not allowed.
      return getattr(unpack(self), name)

    raise AttributeError(
        f'{type(self).__name__!r} object has no attribute {name!r}')

  def __eq__(self, other):
    # Returns the Python value `False` for type/spec mismatches; otherwise
    # returns a scalar boolean from elementwise comparison of all flattened
    # tensor components (which may be symbolic in graph mode).
    if type(self) is not type(other):
      return False

    if self._type_spec != other._type_spec:
      return False

    self_tensors = nest.flatten(self, expand_composites=True)
    other_tensors = nest.flatten(other, expand_composites=True)
    if len(self_tensors) != len(other_tensors):
      return False
    conditions = []
    for t1, t2 in zip(self_tensors, other_tensors):
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(
                  array_ops.shape(t1),
                  array_ops.shape(t2),
                  incompatible_shape_error=False)))
      # Explicitly check shape (values that have different shapes but broadcast
      # to the same value are considered non-equal).
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(t1, t2, incompatible_shape_error=False)))
    return math_ops.reduce_all(array_ops.stack(conditions))

  def __ne__(self, other):
    # `__eq__` may return either a Python bool or a Tensor; negate accordingly.
    eq = self.__eq__(other)
    if isinstance(eq, ops.Tensor):
      return math_ops.logical_not(eq)
    else:
      return not eq

  def __validate__(self):
    """Perform post-construction validation."""

  # This instance variable is used to cache the value for the _type_spec
  # property.
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      assert not is_packed(self)  # Packed version always caches TypeSpec.
      # Write through `__dict__` to bypass the immutability check in
      # `__setattr__`.
      self.__dict__[
          '_tf_extension_type_cached_type_spec'] = self.Spec.from_value(self)
    return self._tf_extension_type_cached_type_spec
302
303
def pack(value):
  """Returns a copy of `value` with fields packed in a single Variant.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.
  """
  if is_packed(value):
    return value

  packed_spec = value._type_spec._tf_extension_type_with_packed(True)  # pylint: disable=protected-access
  try:
    variant = composite_tensor_ops.composite_tensor_to_variants(value)
  except nested_structure_coder.NotEncodableError as e:
    # `_TypeSpecCodec.can_encode` only returns False when the named type has
    # not been registered.  The default error message would merely say that
    # there is no encoder for the object, so raise a more actionable error
    # that tells the user how to register the type.
    raise ValueError('ExtensionTypes must have a __name__ field in order '
                     'to be packed.') from e

  packed_fields = {
      '_tf_extension_type_cached_type_spec': packed_spec,
      '_tf_extension_type_packed_variant': variant,
  }
  return _create_object_from_type_and_dict(type(value), packed_fields)
332
333
def unpack(value):
  """Returns a copy of `value` with individual fields stored in __dict__.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.
  """
  if not is_packed(value):
    return value

  # pylint: disable=protected-access
  packed_variant = value._tf_extension_type_packed_variant
  cached_spec = value._tf_extension_type_cached_type_spec
  unpacked_spec = cached_spec._tf_extension_type_with_packed(False)
  return composite_tensor_ops.composite_tensor_from_variant(
      packed_variant, unpacked_spec)
351
352
def is_packed(value):
  """Returns true if `value`'s fields are packed in a single Variant.

  Args:
    value: An `ExtensionType` object.

  Returns:
    True if `value` stores its fields in a single packed variant tensor.

  Raises:
    ValueError: If `value` is not an `ExtensionType`.
  """
  if not isinstance(value, ExtensionType):
    # Fix: the original implicitly-concatenated literals were missing a space,
    # producing "...ExtensionType,got an instance...".
    raise ValueError('Expected `value` to be an object of type ExtensionType, '
                     f'got an instance of {type(value)}.')
  return '_tf_extension_type_packed_variant' in value.__dict__
359
360
361# ==============================================================================
362# Base class for the tf.ExtensionType TypeSpecs
363# ==============================================================================
364# TODO(b/184565242) Support customizing type relaxation for tracing.
365# TODO(b/184565242) Support conversion to/from FullType.
366# TODO(b/195884675) Support batch and unbatch.
367
368
class ExtensionTypeSpec(type_spec.TypeSpec):
  """Base class for tf.ExtensionType TypeSpec."""

  def _serialize(self):  # TypeSpec API.
    """Returns a serializable representation of this spec's static data."""
    # Use a tuple of (name, value) pairs, to ensure we preserve field ordering.
    fields = [f.name for f in self._tf_extension_type_fields()]
    if self._tf_extension_type_is_packed:
      fields.append('_tf_extension_type_is_packed')
    return tuple(
        (f, _change_nested_mappings_to(self.__dict__[f], dict)) for f in fields)

  @classmethod
  def _deserialize(cls, state):  # TypeSpec API.
    """Reconstructs a spec from the output of `_serialize`."""
    # Restore immutable mappings (which `_serialize` converted to `dict`).
    state = _change_nested_mappings_to(state, immutable_dict.ImmutableDict)
    return _create_object_from_type_and_dict(cls, state)

  def __reduce__(self):
    # Use value_type instead of spec_type, as spec_type is a nested class.
    # Pickle support of nested class requires Pickle protocol version 4, which
    # is not enabled by default until py 3.8.
    #
    # https://www.python.org/dev/peps/pep-3154/#serializing-more-lookupable-objects
    # https://docs.python.org/3/library/pickle.html#pickle.DEFAULT_PROTOCOL
    return _deserialize_for_reduce, (self.value_type, self._serialize())

  def _to_components(self, value):  # TypeSpec API.
    """Returns the tensor components of `value`, in spec field order."""
    if self._tf_extension_type_is_packed:
      return value._tf_extension_type_packed_variant  # pylint: disable=protected-access

    tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor)
    # Retrieve fields by the order of spec dict to preserve field ordering. This
    # is needed as nest.flatten would sort dictionary entries by key.
    value_tuple = tuple(value.__dict__[key] for key in self.__dict__)
    return tuple(
        x for x in nest.flatten(value_tuple)
        if isinstance(x, tensor_or_composite))

  def _from_components(self, components):  # TypeSpec API.
    """Rebuilds an extension type value from `components` and this spec."""
    if self._tf_extension_type_is_packed:
      return _create_object_from_type_and_dict(
          self.value_type, {
              '_tf_extension_type_cached_type_spec': self,
              '_tf_extension_type_packed_variant': components
          })

    # Splice each component back into the position of the corresponding
    # TypeSpec in the flattened spec structure; non-spec values are static
    # data carried by the spec itself.
    spec_tuple = tuple(self.__dict__.values())
    components_iter = iter(components)
    flat = [
        next(components_iter) if isinstance(x, type_spec.TypeSpec) else x
        for x in nest.flatten(spec_tuple)
    ]
    if list(components_iter):
      raise ValueError(
          'Cannot build an ExtensionType instance from components '
          'because more components are provided than the number expected '
          'by the type spec.')
    value_tuple = nest.pack_sequence_as(spec_tuple, flat)
    fields = dict(zip(self.__dict__.keys(), value_tuple))

    # Build the new value.  Bypass the constructor (__init__), in case the user
    # who defined the ExtensionType used a custom constructor.
    return _create_object_from_type_and_dict(self.value_type, fields)

  @property
  def _component_specs(self):  # TypeSpec API.
    """Returns the nested TypeSpecs contained in this spec."""
    if self._tf_extension_type_is_packed:
      return tensor_spec.TensorSpec((), dtypes.variant)

    components = []

    def push_if_type_spec(x):
      if isinstance(x, type_spec.TypeSpec):
        components.append(x)

    nest.map_structure(push_if_type_spec, tuple(self.__dict__.values()))
    return tuple(components)

  @classmethod
  def from_value(cls, value):
    """Builds (or returns the cached) TypeSpec for `value`."""
    cached_spec = getattr(value, '_tf_extension_type_cached_type_spec', None)
    if cached_spec is not None:
      return cached_spec

    # Replace each tensor / composite field with its TypeSpec; the fields
    # cache is value-level data and is excluded from the spec.
    value_fields = value.__dict__
    spec_fields = nest.map_structure(_replace_tensor_with_spec, value_fields)
    spec_fields.pop('_tf_extension_type_cached_fields', None)
    return _create_object_from_type_and_dict(cls, spec_fields)

  def __setattr__(self, name, value):
    # Mutation is only allowed for declared fields while a user-defined
    # constructor is running (signaled by the `_IN_CONSTRUCTOR` attribute).
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError(
          f'Cannot mutate attribute `{name}` '
          f'outside the custom constructor of ExtensionTypeSpec.')

  def __delattr__(self, name):
    # Deletion follows the same rules as `__setattr__` above.
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError(
          f'Cannot mutate attribute `{name}` '
          f'outside the custom constructor of ExtensionTypeSpec.')

  def __validate__(self):
    """Perform post-construction validation."""

  @classmethod
  def _tf_extension_type_fields(cls):
    # Field declarations live on the value type; the spec shares them.
    return cls.value_type._tf_extension_type_fields()  # pylint: disable=protected-access

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    """Returns True if `name` is a declared field of the value type."""
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    """Converts values in `self.__dict__` to spec-appropriate field types."""
    extension_type_field.convert_fields_for_spec(
        self._tf_extension_type_fields(), self.__dict__)

  def __repr__(self):
    fields = ', '.join([f'{k}={v!r}' for (k, v) in self._serialize()])
    return f'{type(self).__qualname__}({fields})'

  # True if the described value stores its fields in a single packed variant.
  _tf_extension_type_is_packed = False

  def _tf_extension_type_with_packed(self, value):
    """Returns a copy of this `TypeSpec` with `packed=value`.

    Args:
      value: A boolean value.

    Returns:
      A copy of `self` with `_tf_extension_type_is_packed=value`.
    """
    copy = _create_object_from_type_and_dict(type(self), self.__dict__)
    copy.__dict__['_tf_extension_type_is_packed'] = value
    return copy
508
509
@tf_export('experimental.ExtensionTypeBatchEncoder')
class ExtensionTypeBatchEncoder(type_spec.TypeSpecBatchEncoder):
  """Class used to encode and decode extension type values for batching.

  In order to be batched and unbatched by APIs such as `tf.data.Dataset`,
  `tf.keras`, and `tf.map_fn`, extension type values must be encoded as a list
  of `tf.Tensor`s, where stacking, unstacking, or concatenating these encoded
  tensors and then decoding the result must be equivalent to stacking,
  unstacking, or concatenating the original values. `ExtensionTypeBatchEncoder`s
  are responsible for implementing this encoding.

  The default `ExtensionTypeBatchEncoder` that is used by
  `BatchableExtensionType` assumes that extension type values can be stacked,
  unstacked, or concatenated by simply stacking, unstacking, or concatenating
  every nested `Tensor`, `ExtensionType`, `CompositeTensor`, and `TensorShape`
  field.

  Extension types where this is not the case will need to override
  `__batch_encoder__` with a custom encoder that overrides the `batch`,
  `unbatch`, `encode`, and `decode` methods. E.g.:

  >>> class CustomBatchEncoder(ExtensionTypeBatchEncoder):
  ...   pass # Override batch(), unbatch(), encode(), and decode().

  >>> class CustomType(BatchableExtensionType):
  ...   x: tf.Tensor
  ...   y: tf.Tensor
  ...   shape: tf.TensorShape
  ...   __batch_encoder__ = CustomBatchEncoder()

  For example, `tf.RaggedTensor` and `tf.SparseTensor` both use custom batch
  encodings which define ops to "box" and "unbox" individual values into
  `tf.variant` tensors.
  """

  def batch(self, spec, batch_size):
    """Returns the TypeSpec representing a batch of values described by `spec`.

    The default definition returns a `TypeSpec` that is equal to `spec`, except
    that an outer axis with size `batch_size` is added to every nested
    `TypeSpec` and `TensorShape` field.  Subclasses may override this default
    definition, when necessary.

    Args:
      spec: The `TypeSpec` for an individual value.
      batch_size: An `int` indicating the number of values that are batched
        together, or `None` if the batch size is not known.

    Returns:
      A `TypeSpec` for a batch of values.
    """

    def batch_field(f):
      # Recurse into nested batchable specs; prepend an outer dimension to
      # static shapes; leave all other static data unchanged.
      if isinstance(f, type_spec.BatchableTypeSpec):
        return f.__batch_encoder__.batch(f, batch_size)
      elif isinstance(f, tensor_shape.TensorShape):
        return [batch_size] + f
      else:
        return f

    fields = tuple(spec.__dict__.items())
    batched_fields = nest.map_structure(batch_field, fields)
    return _create_object_from_type_and_dict(type(spec), batched_fields)

  def unbatch(self, spec):
    """Returns the TypeSpec for a single unbatched element in `spec`.

    The default definition returns a `TypeSpec` that is equal to `spec`, except
    that the outermost axis is removed from every nested `TypeSpec`, and
    `TensorShape` field.  Subclasses may override this default definition, when
    necessary.

    Args:
      spec: The `TypeSpec` for a batch of values.

    Returns:
      A `TypeSpec` for an individual value.
    """

    def unbatch_field(f):
      # Inverse of `batch_field` above: recurse into nested batchable specs
      # and strip the outermost dimension from static shapes.
      if isinstance(f, type_spec.BatchableTypeSpec):
        return f.__batch_encoder__.unbatch(f)
      elif isinstance(f, tensor_shape.TensorShape):
        return f[1:]
      else:
        return f

    fields = tuple(spec.__dict__.items())
    unbatched_fields = nest.map_structure(unbatch_field, fields)
    return _create_object_from_type_and_dict(type(spec), unbatched_fields)

  def encode(self, spec, value, minimum_rank=0):
    """Encodes `value` as a nest of batchable Tensors or CompositeTensors.

    The default definition returns a flat tuple of all the `Tensor`s,
    `CompositeTensor`s, and `ExtensionType`s from a depth-first traversal of
    `value`'s fields. Subclasses may override this default definition, when
    necessary.

    Args:
      spec: The TypeSpec of the value to encode.
      value: A value compatible with `spec`.
      minimum_rank: The minimum rank for the returned Tensors, CompositeTensors,
        and ExtensionType values.  This can be used to ensure that the encoded
        values can be unbatched this number of times.   If `minimum_rank>0`,
        then `t.shape[:minimum_rank]` must be compatible for all values `t`
        returned by `encode`.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.Tensor`s, batchable
      `tf.CompositeTensor`s, or `tf.ExtensionType`s.  Stacking, unstacking, or
      concatenating these encoded values and then decoding the result must be
      equivalent to stacking, unstacking, or concatenating the original values.
    """
    return spec._to_components(value)  # pylint: disable=protected-access

  def decode(self, spec, encoded_value):
    """Decodes `value` from a batchable tensor encoding.

    See `encode` for a description of the default encoding.  Subclasses may
    override this default definition, when necessary.

    Args:
      spec: The TypeSpec for the result value.  If encoded values with spec `s`
        were batched, then `spec` should be `s.batch(batch_size)`; or if encoded
        values with spec `s` were unbatched, then `spec` should be
        `s.unbatch()`.
      encoded_value: A nest of values returned by `encode`; or a nest of
        values that was formed by stacking, unstacking, or concatenating the
        corresponding elements of values returned by `encode`.

    Returns:
      A value compatible with `type_spec`.
    """
    return spec._from_components(encoded_value)  # pylint: disable=protected-access

  def encoding_specs(self, spec):
    """Returns a list of `TensorSpec`(s) describing the encoding for `spec`.

    See `encode` for a description of the default encoding.  Subclasses may
    override this default definition, when necessary.

    Args:
      spec: The TypeSpec whose encoding should be described.

    Returns:
      A nest (as defined by `tf.nest) of `tf.TypeSpec`, describing the values
      that are returned by `self.encode(spec, ...)`.  All TypeSpecs in this
      nest must be batchable.
    """
    return spec._component_specs  # pylint: disable=protected-access
661
662
class BatchableExtensionTypeSpec(ExtensionTypeSpec,
                                 type_spec.BatchableTypeSpec):
  """Base class for TypeSpecs for BatchableExtensionTypes."""

  # Default encoder; subclasses may replace this with a custom
  # `ExtensionTypeBatchEncoder` to control how values are batched/unbatched.
  __batch_encoder__ = ExtensionTypeBatchEncoder()

  def _batch(self, batch_size):
    # Delegate to the encoder so subclasses can customize batching behavior.
    return self.__batch_encoder__.batch(self, batch_size)

  def _unbatch(self):
    return self.__batch_encoder__.unbatch(self)

  def _to_tensor_list(self, value):
    return type_spec.batchable_to_tensor_list(self, value)

  def _to_batched_tensor_list(self, value):
    # minimum_rank=1 ensures each returned tensor can be unbatched once
    # (see `ExtensionTypeBatchEncoder.encode`).
    return type_spec.batchable_to_tensor_list(self, value, minimum_rank=1)

  def _from_compatible_tensor_list(self, tensor_list):
    return type_spec.batchable_from_tensor_list(self, tensor_list)

  @property
  def _flat_tensor_specs(self):
    return type_spec.get_batchable_flat_tensor_specs(self)
687
688
@tf_export('experimental.BatchableExtensionType')
class BatchableExtensionType(ExtensionType):
  """An ExtensionType that can be batched and unbatched.

  `BatchableExtensionType`s can be used with APIs that require batching or
  unbatching, including `Keras`, `tf.data.Dataset`, and `tf.map_fn`.  E.g.:

  >>> class Vehicle(tf.experimental.BatchableExtensionType):
  ...   top_speed: tf.Tensor
  ...   mpg: tf.Tensor
  >>> batch = Vehicle([120, 150, 80], [30, 40, 12])
  >>> tf.map_fn(lambda vehicle: vehicle.top_speed * vehicle.mpg, batch,
  ...           fn_output_signature=tf.int32).numpy()
  array([3600, 6000,  960], dtype=int32)

  An `ExtensionTypeBatchEncoder` is used by these APIs to encode `ExtensionType`
  values. The default encoder assumes that values can be stacked, unstacked, or
  concatenated by simply stacking, unstacking, or concatenating every nested
  `Tensor`, `ExtensionType`, `CompositeTensor`, or `TensorShape` field.
  Extension types where this is not the case will need to override
  `__batch_encoder__` with a custom `ExtensionTypeBatchEncoder`.  See
  `tf.experimental.ExtensionTypeBatchEncoder` for more details.
  """
  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  # User subclasses of this class *are* transformed by the metaclass, since
  # the flag is read from each class's own namespace and is not inherited.
  _tf_extension_type_do_not_transform_this_class = True
715
716
717# For Pickle __reduce__ protocol:
718def _deserialize_for_reduce(value_type, serialization):
719  return value_type.Spec._deserialize(serialization)  # pylint: disable=protected-access
720
721
def _replace_tensor_with_spec(value):
  """Maps a tensor or spec-bearing value to its TypeSpec; else returns it."""
  if isinstance(value, ops.Tensor):
    # Note: we intentionally exclude `value.name` from the `TensorSpec`.
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  if hasattr(value, '_type_spec'):
    return value._type_spec  # pylint: disable=protected-access
  # Static (non-tensor) data is carried by the spec unchanged.
  return value
729
730
def _change_nested_mappings_to(value, new_type):
  """Recursively replace mappings with `new_type`."""
  if isinstance(value, (dict, immutable_dict.ImmutableDict)):
    converted_items = [(key, _change_nested_mappings_to(item, new_type))
                       for (key, item) in value.items()]
    return new_type(converted_items)
  if isinstance(value, tuple):
    return tuple(
        _change_nested_mappings_to(item, new_type) for item in value)
  # Scalars and all other values pass through untouched.
  return value
740
741
742# ==============================================================================
743# Helper methods for tf.ExtensionTypeMetaclass
744# ==============================================================================
745
746
def _check_field_annotations(cls):
  """Validates the field annotations for tf.ExtensionType subclass `cls`.

  Raises:
    ValueError: If a field uses a reserved name, if the nested `Spec` class
      is malformed, or if a non-method class attribute lacks a type
      annotation.
  """
  annotations = getattr(cls, '__annotations__', {})

  def raise_if_reserved(field_name):
    # Shared check: reserved names may not be used as fields.
    if extension_type_field.ExtensionTypeField.is_reserved_name(field_name):
      raise ValueError(f'The field annotations for {cls.__name__} are '
                       f"invalid. Field '{field_name}' is reserved.")

  # No attribute or annotation may use a reserved name; a nested `Spec`
  # class (if present) must directly subclass tf.TypeSpec (or object).
  for attr_name, attr_value in cls.__dict__.items():
    if attr_name == 'Spec':
      if not isinstance(attr_value, type):
        raise ValueError(f'{cls.__qualname__}.Spec must be a nested class; '
                         f'got {attr_value}.')
      if (attr_value.__bases__ != (type_spec.TypeSpec,) and
          attr_value.__bases__ != (object,)):
        raise ValueError(f'{cls.__qualname__}.Spec must be directly subclassed '
                         'from tf.TypeSpec.')
    else:
      raise_if_reserved(attr_name)
  for attr_name in annotations:
    raise_if_reserved(attr_name)

  # Every field must carry a type annotation.  Methods, dunders, abc
  # bookkeeping, and descriptors are exempt.
  for (attr_name, attr_value) in cls.__dict__.items():
    if not (attr_name in annotations or callable(attr_value) or
            attr_name.startswith('_abc_') or
            attr_name == '_tf_extension_type_fields' or
            attr_name.startswith('__') and attr_name.endswith('__') or
            isinstance(attr_value, (property, classmethod, staticmethod))):
      raise ValueError(
          f'The field annotations for {cls.__name__} are '
          f'invalid. Field {attr_name} is missing a type annotation.')
777
778
def _add_extension_type_constructor(cls):
  """Creates a constructor for a ExtensionType or ExtensionTypeSpec subclass.

  If `cls` defines its own `__init__`, that constructor is wrapped so that
  field conversion and validation run after it returns; otherwise a
  constructor is synthesized from the class's declared fields.
  """
  has_user_init = '__init__' in cls.__dict__
  if has_user_init:
    _wrap_user_constructor(cls)
  else:
    _build_extension_type_constructor(cls)
785
786
def _wrap_user_constructor(cls):
  """Wraps a user-defined constructor for tf.ExtensionType subclass `cls`."""
  user_init = cls.__init__

  def wrapped_init(self, *args, **kwargs):
    # While the flag is set, assignments to the (otherwise immutable) fields
    # of `self` are permitted.
    self.__dict__[_IN_CONSTRUCTOR] = True
    user_init(self, *args, **kwargs)
    del self.__dict__[_IN_CONSTRUCTOR]

    # After the user constructor finishes, coerce fields to their declared
    # types and run user-supplied validation.
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  cls.__init__ = tf_decorator.make_decorator(user_init, wrapped_init)
800
801
802_NO_DEFAULT = extension_type_field.ExtensionTypeField.NO_DEFAULT
803
804
# TODO(b/184565242) Consider using the templating system from autograph here.
def _build_extension_type_constructor(cls):
  """Builds a constructor for tf.ExtensionType subclass `cls`.

  The synthesized `__init__` accepts one parameter per declared field (in
  declaration order), binds the arguments against that signature, converts
  the field values to their declared types, and calls `cls.__validate__`.

  Args:
    cls: The class whose `__init__` is being synthesized and attached.
  """
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

  # Mark any no-default fields that follow default fields as keyword_only
  # (a positional parameter without a default may not follow one that has a
  # default value).
  got_default = False
  keyword_only_start = len(fields)
  for i, field in enumerate(fields):
    if got_default and field.default is _NO_DEFAULT:
      keyword_only_start = i
      break
    if field.default is not _NO_DEFAULT:
      got_default = True

  params = []
  for i, field in enumerate(fields):
    kind = (
        tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
        if i < keyword_only_start else tf_inspect.Parameter.KEYWORD_ONLY)
    default = (
        tf_inspect.Parameter.empty
        if field.default is _NO_DEFAULT else field.default)
    params.append(
        tf_inspect.Parameter(
            field.name, kind, default=default, annotation=field.value_type))

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  __init__.__signature__ = tf_inspect.Signature(
      [
          tf_inspect.Parameter('self',
                               tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
      ] + params,
      return_annotation=cls)

  cls.__init__ = __init__
854
855
def _build_spec_constructor(cls):
  """Builds a constructor for ExtensionTypeSpec subclass `cls`."""
  # Spec constructors take one positional-or-keyword parameter per field,
  # with no default values.
  param_kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  params = [
      tf_inspect.Parameter(field.name, param_kind)
      for field in cls._tf_extension_type_fields()  # pylint: disable=protected-access
  ]

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools.
  __init__.__signature__ = tf_inspect.Signature(
      [tf_inspect.Parameter('self', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)]
      + params,
      return_annotation=cls)

  cls.__init__ = __init__
881
882
def _add_type_spec(cls):
  """Creates a nested TypeSpec class for tf.ExtensionType subclass `cls`.

  The generated spec is stored as `cls.Spec`.  Any user-supplied
  customizations from a user-defined nested `Spec` class are merged in, a
  constructor is attached, and (if `cls` declares an explicit `__name__`)
  the spec is registered with the TypeSpec registry.

  Raises:
    ValueError: If a user-supplied `Spec` customization uses a reserved name
      or shadows a field of `cls`, or if `__batch_encoder__` is defined on a
      non-batchable extension type.
  """
  spec_name = cls.__name__ + '.Spec'
  spec_qualname = cls.__qualname__ + '.Spec'

  # Set __module__ explicitly as a dynamic created class has module='abc'
  # by default.
  spec_dict = {'value_type': cls, '__module__': cls.__module__}

  # Copy user-supplied customizations into the TypeSpec.
  user_spec = cls.__dict__.get('Spec', None)
  if user_spec is not None:
    for (name, value) in user_spec.__dict__.items():
      if extension_type_field.ExtensionTypeField.is_reserved_name(name):
        raise ValueError(f'TypeSpec {spec_qualname} uses reserved '
                         f"name '{name}'.")
      if cls._tf_extension_type_has_field(name):  # pylint: disable=protected-access
        raise ValueError(f"TypeSpec {spec_qualname} defines a variable '{name}'"
                         f' which shadows a field in {cls.__qualname__}')
      # Skip class-housekeeping attributes; everything else is copied over.
      if name in ('__module__', '__dict__', '__weakref__'):
        continue

      spec_dict[name] = value

  # Pick the spec's base class, and propagate `__batch_encoder__` from `cls`
  # (only meaningful -- and only allowed -- for batchable extension types).
  if issubclass(cls, BatchableExtensionType):
    type_spec_base = BatchableExtensionTypeSpec
    if hasattr(cls,
               '__batch_encoder__') and '__batch_encoder__' not in spec_dict:
      spec_dict['__batch_encoder__'] = cls.__batch_encoder__
  else:
    type_spec_base = ExtensionTypeSpec
    if hasattr(cls, '__batch_encoder__') or '__batch_encoder__' in spec_dict:
      raise ValueError('__batch_encoder__ should only be defined for '
                       'BatchableExtensionType classes.')

  # Build the TypeSpec and store it as a nested class inside `cls`.
  spec = type(spec_name, (type_spec_base,), spec_dict)
  spec.__qualname__ = spec_qualname
  setattr(cls, 'Spec', spec)

  # Build a constructor for the TypeSpec class.
  if '__init__' in spec.__dict__:
    _wrap_user_constructor(spec)
  else:
    _build_spec_constructor(spec)

  # Drop `_type_spec` from the abstract-method set so that instances of
  # `cls` can be created.
  cls.__abstractmethods__ -= {'_type_spec'}

  # If the user included an explicit `__name__` attribute, then use that to
  # register the TypeSpec (so it can be used in SavedModel signatures).
  if '__name__' in cls.__dict__:
    type_spec.register(cls.__dict__['__name__'] + '.Spec')(spec)
935
936
937# ==============================================================================
938# Anonymous ExtensionType
939# ==============================================================================
class AnonymousExtensionType(ExtensionType):
  """Fallback used to decode `tf.ExtensionType` when the original type is unavailable.

  When a SavedModel is serialized, the signatures of any functions in the
  SavedModel can include `tf.ExtensionType` subclasses.  These subclasses
  are usually registered, so they can be restored when the SavedModel is
  loaded.  However, if a SavedModel is loaded without first registering the
  ExtensionType types in its signature, then the SavedModel will fall back
  to using the `AnonymousExtensionType` type instead.

  If necessary, `AnonymousExtensionType` objects can be converted to a
  concrete `tf.ExtensionType` subclass (and vice versa) using `reinterpret`.
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, **fields):
    """Creates an anonymous value whose fields are the given keyword args."""
    for field_name in fields:
      is_dunder = field_name.startswith('__') and field_name.endswith('__')
      if (extension_type_field.ExtensionTypeField.is_reserved_name(field_name)
          or is_dunder):
        raise ValueError(
            f'Reserved field name {field_name} was encountered '
            f'when trying to instantiate an AnonymousExtensionType.')
    converted = {
        name: _convert_anonymous_fields(value)
        for (name, value) in fields.items()
    }
    self.__dict__.update(converted)
    self._tf_extension_type_convert_fields()
    super().__init__()

  @classmethod
  def _tf_extension_type_fields(cls):
    # Fields are untyped (value_type=None), since there are no annotations.
    reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    return [
        extension_type_field.ExtensionTypeField(name, None)
        for name in cls.__dict__
        if not reserved(name)
    ]

  def __setattr__(self, name, value):
    raise AttributeError(f'Cannot set attribute `{name}`. '
                         f'AnonymousExtensionType instances are immutable.')

  def __delattr__(self, name):
    raise AttributeError(f'Cannot delete attribute `{name}`. '
                         f'AnonymousExtensionType instances are immutable.')

  def _tf_extension_type_convert_fields(self):
    reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    converted = {
        name: _convert_anonymous_fields(value)
        for (name, value) in self.__dict__.items()
        if not reserved(name)
    }
    self.__dict__.update(converted)

  def __repr__(self):
    reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    field_reprs = ', '.join(
        f'{name}={value!r}'
        for (name, value) in self.__dict__.items()
        if not reserved(name))
    return f'AnonymousExtensionType({field_reprs})'

  # Cache for `_type_spec` (written via __dict__ to bypass the immutable
  # `__setattr__` above).
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      self.__dict__['_tf_extension_type_cached_type_spec'] = (
          AnonymousExtensionTypeSpec.from_value(self))
    return self._tf_extension_type_cached_type_spec
1012
1013
@type_spec.register('tf.AnonymousExtensionType.Spec')
class AnonymousExtensionTypeSpec(ExtensionTypeSpec):
  """TypeSpec for AnonymousExtensionType."""

  def __init__(self, **fields):
    """Creates a spec whose fields are the given keyword args."""
    for field_name in fields:
      is_dunder = field_name.startswith('__') and field_name.endswith('__')
      if (extension_type_field.ExtensionTypeField.is_reserved_name(field_name)
          or is_dunder):
        raise ValueError(
            f'Reserved field name {field_name} was encountered '
            f'when trying to instantiate an AnonymousExtensionTypeSpec.')
    converted = {
        name: _convert_anonymous_fields(value, for_spec=True)
        for (name, value) in fields.items()
    }
    self.__dict__.update(converted)
    super().__init__()

  value_type = AnonymousExtensionType  # TypeSpec API.

  def _serialize(self):  # TypeSpec API.
    # Nested immutable mappings are converted back to plain dicts so the
    # serialization uses standard container types.
    reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    return tuple((name, _change_nested_mappings_to(value, dict))
                 for (name, value) in self.__dict__.items()
                 if not reserved(name))

  def __setattr__(self, name, value):
    raise AttributeError(f'Cannot set attribute `{name}`. '
                         f'AnonymousExtensionTypeSpec instances are immutable.')

  def __delattr__(self, name):
    raise AttributeError(f'Cannot delete attribute `{name}`. '
                         f'AnonymousExtensionTypeSpec instances are immutable.')
1045
1046
def _convert_anonymous_fields(value, for_spec=False):
  """Type-checks and converts `value` for inclusion in an AnonymousExtensionType.

  Plain scalars, dtypes, and shapes pass through unchanged; tuples and
  mappings are converted recursively (mappings become `ImmutableDict`s).
  Tensors and composite tensors are only allowed in values (`for_spec`
  False), while `TypeSpec`s are only allowed in specs (`for_spec` True).

  Raises:
    ValueError: If `value` is not of a supported type.
  """
  scalar_types = (int, float, bool, str, bytes, type(None), dtypes.DType,
                  tensor_shape.TensorShape)
  if isinstance(value, scalar_types):
    return value

  if isinstance(value, tuple):
    return tuple(_convert_anonymous_fields(item, for_spec) for item in value)

  if isinstance(value, typing.Mapping):
    converted_items = [(_convert_anonymous_fields(key, for_spec),
                        _convert_anonymous_fields(val, for_spec))
                       for (key, val) in value.items()]
    return immutable_dict.ImmutableDict(converted_items)

  if not for_spec and isinstance(
      value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return value

  if for_spec and isinstance(value, type_spec.TypeSpec):
    return value

  raise ValueError(f'Cannot convert anonymous fields from '
                   f'an unsupported `value` argument: {value!r}.')
1071
1072
1073# ==============================================================================
1074# reinterpret
1075# ==============================================================================
def reinterpret(value, new_type):
  """Converts a given `ExtensionType` to a new type with compatible fields.

  In particular, this can be used to convert a concrete subclass of
  `ExtensionType` to an `AnonymousExtensionType`, or vice versa.  When
  converting to a non-anonymous ExtensionType, field values are type-checked
  to ensure they are consistent with `new_type`'s type annotations, and
  validated with `new_type.__validate__`.

  Args:
    value: An instance of a subclass of `tf.ExtensionType`
    new_type: A subclass of `tf.ExtensionType`

  Returns:
    An instance of `new_type`, whose fields are copied from `value`.
  """
  if not isinstance(value, ExtensionType):
    raise ValueError(
        f'reinterpret expects `value` to be a tf.ExtensionType instance; '
        f'got {value!r}')
  is_extension_subclass = (
      isinstance(new_type, type) and issubclass(new_type, ExtensionType))
  if not is_extension_subclass:
    raise ValueError(
        f'reinterpret expects `new_type` to be a subclass of tf.ExtensionType; '
        f'got {new_type!r}')

  # Copy all non-reserved fields into a new `new_type` instance (bypassing
  # new_type's constructor), then convert and validate the fields.
  reserved = extension_type_field.ExtensionTypeField.is_reserved_name
  fields = [(name, field_value)
            for (name, field_value) in value.__dict__.items()
            if not reserved(name)]
  result = _create_object_from_type_and_dict(new_type, fields)
  result._tf_extension_type_convert_fields()  # pylint: disable=protected-access
  result.__validate__()
  return result
1109