xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/extension_type_field.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Metadata about fields for user-defined ExtensionType classes."""
16
17import collections
18import collections.abc
19import enum
20import typing
21
22from tensorflow.python.framework import composite_tensor
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import immutable_dict
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import type_spec
28from tensorflow.python.util import type_annotations
29
# These names may not be used as the name for an ExtensionType field (to
# prevent name clashes).  All names beginning with `'_tf_extension_type'` are
# also reserved; `ExtensionTypeField.is_reserved_name` checks that prefix
# case-insensitively.
RESERVED_FIELD_NAMES = [
    'self',
    # Name of the nested TypeSpec class.
    'Spec',
    # Names defined by the CompositeTensor base class.
    '_type_spec',
    '_shape_invariant_to_type_spec',
    '_consumers',
    # Names defined by the TypeSpec base class.
    'value_type',
    'is_compatible_with',
    'most_specific_compatible_type',
    '_with_tensor_ranks_only',
    '_to_components',
    '_from_components',
    '_component_specs',
    '_to_tensor_list',
    '_from_tensor_list',
    '_from_compatible_tensor_list',
    '_flat_tensor_specs',
    '_serialize',
    '_deserialize',
    '_to_legacy_output_types',
    '_to_legacy_output_shapes',
    '_to_legacy_output_classes',
    # Used by Keras
    '_keras_mask'
]
61
62
class Sentinel(object):
  """Unique marker object, distinguishable from any user value via `is`.

  The class defines no equality override, so an instance only ever compares
  identical/equal to itself.
  """

  def __init__(self, name):
    # Text returned verbatim by `repr()` for this marker.
    self._name = name

  def __repr__(self):
    return self._name
72
# The type of `None`.  A `None` annotation is treated as this type when
# validating and converting field values in this module.
_NoneType = type(None)
74
75
76# ==============================================================================
77# ExtensionTypeField
78# ==============================================================================
class ExtensionTypeField(
    collections.namedtuple('ExtensionTypeField',
                           ['name', 'value_type', 'default'])):
  """Describes one field of a `tf.ExtensionType`: name, type and default."""

  # Marker stored in `default` when a field has no default value.
  NO_DEFAULT = Sentinel('ExtensionTypeField.NO_DEFAULT')

  def __new__(cls, name, value_type, default=NO_DEFAULT):
    """Builds the metadata record for a single field.

    Args:
      name: The field's name (`str`).  May not be a reserved name.
      value_type: A python type expression restricting the values the field
        may hold.  Forward references are accepted here.
      default: The field's default value, or `NO_DEFAULT` when the field has
        none.  A supplied default is converted to `value_type`.

    Returns:
      A new `ExtensionTypeField`.

    Raises:
      TypeError: If `value_type` describes a type that `tf.ExtensionType`
        does not currently support.
      TypeError: If `default` is given and cannot be converted to
        `value_type`.
    """
    try:
      validate_field_value_type(value_type, allow_forward_references=True)
    except TypeError as e:
      # Prefix the message with the field name so the user can locate it.
      raise TypeError(f'In field {name!r}: {e}') from e

    has_default = default is not cls.NO_DEFAULT
    if has_default:
      default_path = (f'default value for {name}',)
      default = _convert_value(default, value_type, default_path,
                               _ConversionContext.DEFAULT)
    return super().__new__(cls, name, value_type, default)

  @staticmethod
  def is_reserved_name(name):
    """Returns true if `name` may not be used as a field name."""
    if name in RESERVED_FIELD_NAMES:
      return True
    # The reserved prefix is matched case-insensitively.
    return name.lower().startswith('_tf_extension_type')
122
123
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
  """Verifies that `value_type` uses only supported type annotations.

  Generic tuples, unions and mappings are checked recursively.

  Args:
    value_type: The type annotation to check.
    in_mapping_key: True if `value_type` is nested in the key of a mapping.
    allow_forward_references: If false, then raise an exception if
      `value_type` contains a forward reference (i.e., a string literal).

  Raises:
    TypeError: If `value_type` contains an unsupported type annotation.
  """
  # Forward references can not be inspected; accept or reject them outright.
  if isinstance(value_type, str) or type_annotations.is_forward_ref(value_type):
    if not allow_forward_references:
      raise TypeError(f'Unresolved forward reference {value_type!r}')
    return

  # Plain python types (and DType) are always accepted.
  simple_types = (int, float, str, bytes, bool, None, _NoneType, dtypes.DType)
  if value_type in simple_types:
    return

  # Tensor-like types are accepted, except inside a mapping key.
  is_tensor_like = (
      value_type in (ops.Tensor, tensor_shape.TensorShape) or
      (isinstance(value_type, type) and
       issubclass(value_type, composite_tensor.CompositeTensor)))
  if is_tensor_like:
    if in_mapping_key:
      raise TypeError(f"Mapping had a key '{value_type.__name__}' with type "
                      f"'{type(value_type).__name__}'")
    return

  is_tuple = type_annotations.is_generic_tuple(value_type)
  if is_tuple or type_annotations.is_generic_union(value_type):
    type_args = type_annotations.get_generic_type_args(value_type)
    # `Tuple[X, ...]`: only `X` is a type; the Ellipsis is skipped.
    is_variadic_tuple = (
        len(type_args) == 2 and type_args[1] is Ellipsis and
        type_annotations.is_generic_tuple(value_type))
    args_to_check = (type_args[0],) if is_variadic_tuple else type_args
    for arg in args_to_check:
      validate_field_value_type(arg, in_mapping_key, allow_forward_references)
    return

  if type_annotations.is_generic_mapping(value_type):
    key_type, item_type = type_annotations.get_generic_type_args(value_type)
    validate_field_value_type(key_type, True, allow_forward_references)
    validate_field_value_type(item_type, in_mapping_key,
                              allow_forward_references)
    return

  # Anything else is unsupported; classes are reported by name.
  if isinstance(value_type, type):
    raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
  raise TypeError(f'Unsupported type annotation {value_type!r}')
172
173
174# ==============================================================================
175# Type-checking & conversion for ExtensionTypeField values
176# ==============================================================================
177
178
class _ConversionContext(enum.Enum):
  """Identifies which kind of value a conversion call is handling.

  Passed through `_convert_fields`, `_convert_value` and their helpers.
  """
  # Converting an ExtensionType field.
  VALUE = 1
  # Converting an ExtensionType.Spec field.
  SPEC = 2
  # Converting a default value for __init__.
  DEFAULT = 3
187
188
def convert_fields(fields, field_values):
  """Validates and converts the entries of `field_values` (in place).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  It must hold
      exactly one entry per field, i.e. `set(field_values.keys())` must equal
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields, field_values, _ConversionContext.VALUE)
205
206
def convert_fields_for_spec(fields, field_values):
  """Validates and converts field values for a TypeSpec (in place).

  Works like `convert_fields`, but tensor-like fields hold specs instead of
  values: when a field's `value_type` is `tf.Tensor` or a `CompositeTensor`
  subclass, the matching entry in `field_values` must be a `TypeSpec`
  describing such a value rather than the value itself.

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  It must hold
      exactly one entry per field, i.e. `set(field_values.keys())` must equal
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields, field_values, _ConversionContext.SPEC)
229
230
def _convert_fields(fields, field_values, context):
  """Validates and converts the entries of `field_values` (in place).

  `field_values` is only updated after every field has been converted, so a
  failed conversion leaves it untouched.

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  It must hold
      exactly one entry per field, i.e. `set(field_values.keys())` must equal
      `set([f.name for f in fields])`.
    context: _ConversionContext, indicates what kind of value we are converting.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  if len(fields) != len(field_values):
    _report_field_mismatches(fields, field_values)

  results = {}
  for spec in fields:
    key = spec.name
    if key not in field_values:
      _report_field_mismatches(fields, field_values)
    results[key] = _convert_value(field_values[key], spec.value_type, (key,),
                                  context)
  field_values.update(results)
257
258
def _convert_value(value, expected_type, path,
                   context=_ConversionContext.VALUE):
  """Checks `value` against `expected_type` and converts it.

  Args:
    value: The value to type-check.
    expected_type: The expected type for the value.
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    A copy of `value`, converted to the expected type.

  Raises:
    TypeError: If `value` can not be converted to the expected type.
  """
  assert isinstance(path, tuple)

  # A bare `None` annotation means "the type of None".
  target = _NoneType if expected_type is None else expected_type

  if target is ops.Tensor:
    return _convert_tensor(value, path, context)

  if (isinstance(target, type) and
      issubclass(target, composite_tensor.CompositeTensor)):
    return _convert_composite_tensor(value, target, path, context)

  if target is tensor_shape.TensorShape:
    try:
      return tensor_shape.as_shape(value)
    except TypeError as e:
      raise TypeError(
          f'{"".join(path)}: expected tf.TensorShape, got {value!r}') from e

  if target is dtypes.DType:
    try:
      return dtypes.as_dtype(value)
    except TypeError as e:
      raise TypeError(
          f'{"".join(path)}: expected tf.DType, got {value!r}') from e

  # Plain python values are checked with `isinstance` and returned unchanged.
  if target in (int, float, bool, str, bytes, _NoneType):
    if isinstance(value, target):
      return value
    raise TypeError(f'{"".join(path)}: expected '
                    f'{target.__name__}, got {value!r}')

  # Generic annotations are handled by dedicated helpers.
  if type_annotations.is_generic_tuple(target):
    return _convert_tuple(value, target, path, context)
  if type_annotations.is_generic_mapping(target):
    return _convert_mapping(value, target, path, context)
  if type_annotations.is_generic_union(target):
    return _convert_union(value, target, path, context)

  raise TypeError(f'{"".join(path)}: Unsupported type annotation '
                  f'{target!r}')
311
312
def _convert_tensor(value, path, context):
  """Converts `value` to a `Tensor` (or checks its spec in SPEC context)."""
  if context == _ConversionContext.SPEC:
    # Spec fields must hold a TypeSpec whose value type is exactly Tensor.
    is_tensor_spec = (
        isinstance(value, type_spec.TypeSpec) and
        value.value_type is ops.Tensor)
    if is_tensor_spec:
      return value
    raise TypeError(f'{"".join(path)}: expected a TensorSpec, got {value!r}')

  if isinstance(value, ops.Tensor):
    return value

  if context == _ConversionContext.DEFAULT:
    # TODO(edloper): Convert the value to a numpy array?  (Note: we can't just
    # use `np.array(value)`, since the default dtypes for TF and numpy are
    # different -- e.g., int->np.int64 but int->tf.int32.
    return value

  try:
    return ops.convert_to_tensor(value)
  except (ValueError, TypeError) as e:
    raise TypeError(f'{"".join(path)}: expected a Tensor, '
                    f'got {value!r}') from e
333
334
def _convert_composite_tensor(value, expected_type, path, context):
  """Checks that `value` is (or, for specs, describes) an `expected_type`."""
  if context == _ConversionContext.SPEC:
    # Spec fields must hold a TypeSpec for `expected_type` or a subclass.
    is_matching_spec = (
        isinstance(value, type_spec.TypeSpec) and
        issubclass(value.value_type, expected_type))
    if is_matching_spec:
      return value
    raise TypeError(f'{"".join(path)}: expected a TypeSpec for '
                    f'{expected_type.__name__}, got {value!r}')

  if isinstance(value, expected_type):
    return value
  raise TypeError(f'{"".join(path)}: expected {expected_type.__name__}, '
                  f'got {value!r}')
348
349
def _convert_tuple(value, expected_type, path, context):
  """Converts `value` to a tuple with type `expected_type`.

  Args:
    value: The value to convert.  Any sequence is accepted.
    expected_type: A generic tuple annotation, either `Tuple[X, ...]` or a
      fixed-length `Tuple[X, Y, ...]` form.
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    A `tuple` whose elements have each been converted with `_convert_value`.

  Raises:
    TypeError: If `value` is not a sequence, if its length does not match a
      fixed-length annotation, or if an element can not be converted.
  """
  if not isinstance(value, typing.Sequence):
    raise TypeError(f'{"".join(path)}: expected tuple, got {value!r}')
  element_types = type_annotations.get_generic_type_args(expected_type)

  # `Tuple[X, ...]`: every element shares the single element type.
  if len(element_types) == 2 and element_types[1] is Ellipsis:
    return tuple(
        _convert_value(v, element_types[0], path + (f'[{i}]',), context)
        for (i, v) in enumerate(value))

  # Fixed-length tuple: lengths must agree, then convert position by position.
  if len(value) != len(element_types):
    raise TypeError(f'{"".join(path)}: expected tuple with length '
                    f'{len(element_types)}, got {value!r}')
  return tuple(
      _convert_value(v, t, path + (f'[{i}]',), context)
      for (i, (v, t)) in enumerate(zip(value, element_types)))
368
369
def _convert_mapping(value, expected_type, path, context):
  """Converts `value` to an immutable mapping with type `expected_type`."""
  if not isinstance(value, typing.Mapping):
    raise TypeError(f'{"".join(path)}: expected mapping, got {value!r}')
  key_type, value_type = type_annotations.get_generic_type_args(expected_type)

  converted_items = []
  for k, v in value.items():
    # Keys and values are converted separately, each with its own error path.
    new_key = _convert_value(k, key_type, path + ('[<key>]',), context)
    new_value = _convert_value(v, value_type, path + (f'[{k!r}]',), context)
    converted_items.append((new_key, new_value))
  return immutable_dict.ImmutableDict(converted_items)
380
381
def _convert_union(value, expected_type, path, context):
  """Converts `value` using the first union member type that accepts it."""
  candidates = type_annotations.get_generic_type_args(expected_type)
  for candidate in candidates:
    try:
      converted = _convert_value(value, candidate, path, context)
    except TypeError:
      # This member type does not fit; try the next one.
      continue
    else:
      return converted
  raise TypeError(f'{"".join(path)}: expected {expected_type}, got {value!r}')
390
391
def _report_field_mismatches(fields, field_values):
  """Raises a ValueError describing how `field_values` differs from `fields`.

  Unexpected names are reported before missing ones.  Returns without raising
  when both hold the same set of names.
  """
  expected_names = {f.name for f in fields}
  actual_names = set(field_values)

  extra = actual_names.difference(expected_names)
  if extra:
    raise ValueError(f'Got unexpected fields: {extra}')

  missing = expected_names.difference(actual_names)
  if missing:
    raise ValueError(f'Missing required fields: {missing}')
402