# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metadata about fields for user-defined ExtensionType classes."""

import collections
import collections.abc
import enum
import typing

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import immutable_dict
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.util import type_annotations

# These names may not be used as the name for an ExtensionType field (to
# prevent name clashes).  All names beginning with `'_tf_extension_type'` are
# also reserved.
RESERVED_FIELD_NAMES = [
    'self',
    # Name of the nested TypeSpec class.
    'Spec',
    # Names defined by the CompositeTensor base class.
    '_type_spec',
    '_shape_invariant_to_type_spec',
    '_consumers',
    # Names defined by the TypeSpec base class.
    'value_type',
    'is_compatible_with',
    'most_specific_compatible_type',
    '_with_tensor_ranks_only',
    '_to_components',
    '_from_components',
    '_component_specs',
    '_to_tensor_list',
    '_from_tensor_list',
    '_from_compatible_tensor_list',
    '_flat_tensor_specs',
    '_serialize',
    '_deserialize',
    '_to_legacy_output_types',
    '_to_legacy_output_shapes',
    '_to_legacy_output_classes',
    # Used by Keras
    '_keras_mask'
]


class Sentinel:
  """Sentinel value that's not equal (w/ `is`) to any user value.

  The `name` passed to the constructor is only used as the object's `repr`.
  """

  def __init__(self, name):
    self._name = name

  def __repr__(self):
    return self._name


_NoneType = type(None)


# ==============================================================================
# ExtensionTypeField
# ==============================================================================
class ExtensionTypeField(
    collections.namedtuple('ExtensionTypeField',
                           ['name', 'value_type', 'default'])):
  """Metadata about a single field in a `tf.ExtensionType` object."""

  # Marker stored in `default` for fields that have no default value.  It is
  # compared by identity (`is`), so `None` remains usable as a real default.
  NO_DEFAULT = Sentinel('ExtensionTypeField.NO_DEFAULT')

  def __new__(cls, name, value_type, default=NO_DEFAULT):
    """Constructs a new ExtensionTypeField containing metadata for a single field.

    Args:
      name: The name of the new field (`str`). May not be a reserved name.
      value_type: A python type expression constraining what values this field
        can take.
      default: The default value for the new field, or `NO_DEFAULT` if this
        field has no default value.

    Returns:
      A new `ExtensionTypeField`.  If a default was given, the stored default
      is the result of converting it with `_convert_value`.

    Raises:
      TypeError: If the type described by `value_type` is not currently
        supported by `tf.ExtensionType`.
      TypeError: If `default` is specified and its type does not match
        `value_type`.
    """
    # Forward references are tolerated here; only the rest of the annotation
    # is checked.
    try:
      validate_field_value_type(value_type, allow_forward_references=True)
    except TypeError as e:
      # Prefix the message with the field name so the user can locate it.
      raise TypeError(f'In field {name!r}: {e}') from e

    if default is not cls.NO_DEFAULT:
      default = _convert_value(default, value_type,
                               (f'default value for {name}',),
                               _ConversionContext.DEFAULT)
    return super().__new__(cls, name, value_type, default)

  @staticmethod
  def is_reserved_name(name):
    """Returns true if `name` is a reserved name.

    A name is reserved if it appears in `RESERVED_FIELD_NAMES`, or if its
    lower-cased form starts with `'_tf_extension_type'` (i.e. the prefix check
    is case-insensitive).
    """
    return name in RESERVED_FIELD_NAMES or name.lower().startswith(
        '_tf_extension_type')


def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
  """Checks that `value_type` contains only supported type annotations.

  Generic tuple, union and mapping annotations are checked recursively.

  Args:
    value_type: The type annotation to check.
    in_mapping_key: True if `value_type` is nested in the key of a mapping.
    allow_forward_references: If false, then raise an exception if a
      `value_type` contains a forward reference (i.e., a string literal).

  Raises:
    TypeError: If `value_type` contains an unsupported type annotation.
  """
  if isinstance(value_type, str) or type_annotations.is_forward_ref(value_type):
    if allow_forward_references:
      return
    raise TypeError(f'Unresolved forward reference {value_type!r}')

  if value_type in (int, float, str, bytes, bool, None, _NoneType,
                    dtypes.DType):
    return
  elif (value_type in (ops.Tensor, tensor_shape.TensorShape) or
        (isinstance(value_type, type) and
         issubclass(value_type, composite_tensor.CompositeTensor))):
    # Tensor, TensorShape and CompositeTensor types are accepted as values,
    # but are rejected when they appear inside a mapping key.
    if in_mapping_key:
      raise TypeError(f"Mapping had a key '{value_type.__name__}' with type "
                      f"'{type(value_type).__name__}'")
  elif (type_annotations.is_generic_tuple(value_type) or
        type_annotations.is_generic_union(value_type)):
    type_args = type_annotations.get_generic_type_args(value_type)
    if (len(type_args) == 2 and type_args[1] is Ellipsis and
        type_annotations.is_generic_tuple(value_type)):  # `Tuple[X, ...]`
      # Only the element type needs checking; `...` is not a type.
      validate_field_value_type(type_args[0], in_mapping_key,
                                allow_forward_references)
    else:
      for arg in type_args:
        validate_field_value_type(arg, in_mapping_key, allow_forward_references)
  elif type_annotations.is_generic_mapping(value_type):
    key_type, mapped_type = type_annotations.get_generic_type_args(value_type)
    # Keys are always validated with `in_mapping_key=True`; the mapped values
    # inherit the flag from the enclosing annotation.
    validate_field_value_type(key_type, True, allow_forward_references)
    validate_field_value_type(mapped_type, in_mapping_key,
                              allow_forward_references)
  elif isinstance(value_type, type):
    raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
  else:
    raise TypeError(f'Unsupported type annotation {value_type!r}')


# ==============================================================================
# Type-checking & conversion for ExtensionTypeField values
# ==============================================================================


class _ConversionContext(enum.Enum):
  """Enum to indicate what kind of value is being converted.

  Used by `_convert_fields` and `_convert_value` and their helper methods.
  """
  VALUE = 1  # Converting an ExtensionType field
  SPEC = 2  # Converting an ExtensionType.Spec field
  DEFAULT = 3  # Converting a default value for __init__


def convert_fields(fields, field_values):
  """Type-checks and converts each field in `field_values` (in place).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  Must contain an entry
      for each field.  I.e., `set(field_values.keys())` must be equal to
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields, field_values, context=_ConversionContext.VALUE)


def convert_fields_for_spec(fields, field_values):
  """Type-checks and converts field values for a TypeSpec (in place).

  This is similar to `convert_fields`, except that we expect a `TypeSpec` for
  tensor-like types.  In particular, if the `value_type` of a field is
  `tf.Tensor` or a `CompositeTensor` subclass, then the corresponding value in
  `field_values` is expected to contain a `TypeSpec` (rather than a value
  described by that `TypeSpec`).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  Must contain an entry
      for each field.  I.e., `set(field_values.keys())` must be equal to
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields, field_values, context=_ConversionContext.SPEC)


def _convert_fields(fields, field_values, context):
  """Type-checks and converts each field in `field_values` (in place).

  Converted values are collected in a temporary dict and written back to
  `field_values` only after every field has been converted, so `field_values`
  is left unmodified if any conversion raises.

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  Must contain an entry
      for each field.  I.e., `set(field_values.keys())` must be equal to
      `set([f.name for f in fields])`.
    context: _ConversionContext, indicates what kind of value we are converting.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  converted = {}
  # Cheap length check first; the helper works out which names are missing or
  # unexpected and raises accordingly.
  if len(fields) != len(field_values):
    _report_field_mismatches(fields, field_values)
  for field in fields:
    if field.name not in field_values:
      _report_field_mismatches(fields, field_values)
    field_value = field_values[field.name]
    converted[field.name] = _convert_value(field_value, field.value_type,
                                           (field.name,), context)
  field_values.update(converted)


def _convert_value(value, expected_type, path,
                   context=_ConversionContext.VALUE):
  """Type-checks and converts a value.

  Args:
    value: The value to type-check.
    expected_type: The expected type for the value.
    path: Tuple of `str` naming the value (used for exception messages).  The
      parts are concatenated without a separator when building a message.
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    A copy of `value`, converted to the expected type.

  Raises:
    TypeError: If `value` can not be converted to the expected type.
  """
  assert isinstance(path, tuple)

  # A bare `None` annotation means "the value must be None".
  if expected_type is None:
    expected_type = _NoneType

  if expected_type is ops.Tensor:
    return _convert_tensor(value, path, context)
  elif (isinstance(expected_type, type) and
        issubclass(expected_type, composite_tensor.CompositeTensor)):
    return _convert_composite_tensor(value, expected_type, path, context)
  elif expected_type is tensor_shape.TensorShape:
    try:
      return tensor_shape.as_shape(value)
    except TypeError as e:
      raise TypeError(
          f'{"".join(path)}: expected tf.TensorShape, got {value!r}') from e
  elif expected_type is dtypes.DType:
    try:
      return dtypes.as_dtype(value)
    except TypeError as e:
      raise TypeError(
          f'{"".join(path)}: expected tf.DType, got {value!r}') from e
  elif expected_type in (int, float, bool, str, bytes, _NoneType):
    # Plain `isinstance` check, no coercion.  Note that `bool` is a subclass of
    # `int`, so `True`/`False` are accepted where `int` is expected.
    if not isinstance(value, expected_type):
      raise TypeError(f'{"".join(path)}: expected '
                      f'{expected_type.__name__}, got {value!r}')
    return value
  elif type_annotations.is_generic_tuple(expected_type):
    return _convert_tuple(value, expected_type, path, context)
  elif type_annotations.is_generic_mapping(expected_type):
    return _convert_mapping(value, expected_type, path, context)
  elif type_annotations.is_generic_union(expected_type):
    return _convert_union(value, expected_type, path, context)
  else:
    raise TypeError(f'{"".join(path)}: Unsupported type annotation '
                    f'{expected_type!r}')


def _convert_tensor(value, path, context):
  """Converts `value` to a `Tensor`.

  Args:
    value: The value to convert.
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    In the `SPEC` context, `value` itself (which must be a `TypeSpec` whose
    `value_type` is `Tensor`).  In the `DEFAULT` context, a non-`Tensor`
    `value` is returned unchanged.  Otherwise, `value` if it is already a
    `Tensor`, else the result of `ops.convert_to_tensor(value)`.

  Raises:
    TypeError: If `value` is not an acceptable spec (in the `SPEC` context), or
      if `ops.convert_to_tensor` raises `ValueError` or `TypeError`.
  """
  if context == _ConversionContext.SPEC:
    if not (isinstance(value, type_spec.TypeSpec) and
            value.value_type is ops.Tensor):
      raise TypeError(f'{"".join(path)}: expected a TensorSpec, got {value!r}')
    return value

  if not isinstance(value, ops.Tensor):
    if context == _ConversionContext.DEFAULT:
      # TODO(edloper): Convert the value to a numpy array?  (Note: we can't just
      # use `np.array(value)`, since the default dtypes for TF and numpy are
      # different -- e.g., int->np.int64 but int->tf.int32.)
      return value
    try:
      value = ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
      raise TypeError(f'{"".join(path)}: expected a Tensor, '
                      f'got {value!r}') from e
  return value


def _convert_composite_tensor(value, expected_type, path, context):
  """Converts `value` to a value of type `expected_type`.

  No actual conversion is performed: `value` is checked and returned as-is.

  Args:
    value: The value to check.
    expected_type: A `CompositeTensor` subclass.
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    `value`, unchanged.

  Raises:
    TypeError: In the `SPEC` context, if `value` is not a `TypeSpec` whose
      `value_type` is a subclass of `expected_type`; otherwise, if `value` is
      not an instance of `expected_type`.
  """
  if context == _ConversionContext.SPEC:
    if not (isinstance(value, type_spec.TypeSpec) and
            issubclass(value.value_type, expected_type)):
      raise TypeError(f'{"".join(path)}: expected a TypeSpec for '
                      f'{expected_type.__name__}, got {value!r}')
    return value

  if not isinstance(value, expected_type):
    raise TypeError(f'{"".join(path)}: expected {expected_type.__name__}, '
                    f'got {value!r}')
  return value


def _convert_tuple(value, expected_type, path, context):
  """Converts `value` to a tuple with type `expected_type`.

  Args:
    value: The value to convert.  Any `typing.Sequence` instance is accepted.
    expected_type: A generic tuple annotation, either variable-length
      (`Tuple[X, ...]`) or fixed-length (`Tuple[X, Y, Z]`).
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    A new `tuple` whose elements have each been converted with
    `_convert_value`.

  Raises:
    TypeError: If `value` is not a sequence, if its length does not match a
      fixed-length annotation, or if any element fails conversion.
  """
  # NOTE(review): `str` and `bytes` are sequences too, so they pass this check
  # and are then converted element by element -- confirm this is intended.
  if not isinstance(value, typing.Sequence):
    raise TypeError(f'{"".join(path)}: expected tuple, got {value!r}')
  element_types = type_annotations.get_generic_type_args(expected_type)
  if len(element_types) == 2 and element_types[1] is Ellipsis:
    # Variable-length tuple: every element shares the same expected type.
    return tuple([
        _convert_value(v, element_types[0], path + (f'[{i}]',), context)
        for (i, v) in enumerate(value)
    ])
  else:
    if len(value) != len(element_types):
      raise TypeError(f'{"".join(path)}: expected tuple with length '
                      f'{len(element_types)}, got {value!r}')
    return tuple([
        _convert_value(v, t, path + (f'[{i}]',), context)
        for (i, (v, t)) in enumerate(zip(value, element_types))
    ])


def _convert_mapping(value, expected_type, path, context):
  """Converts `value` to a mapping with type `expected_type`.

  Args:
    value: The value to convert.  Must be a `typing.Mapping` instance.
    expected_type: A generic mapping annotation with a key type and a value
      type.
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    An `immutable_dict.ImmutableDict` built from the converted keys and values.

  Raises:
    TypeError: If `value` is not a mapping, or if any key or value fails
      conversion.
  """
  if not isinstance(value, typing.Mapping):
    raise TypeError(f'{"".join(path)}: expected mapping, got {value!r}')
  key_type, value_type = type_annotations.get_generic_type_args(expected_type)
  return immutable_dict.ImmutableDict([
      (_convert_value(k, key_type, path + ('[<key>]',), context),
       _convert_value(v, value_type, path + (f'[{k!r}]',), context))
      for (k, v) in value.items()
  ])


def _convert_union(value, expected_type, path, context):
  """Converts `value` to a value with any of the types in `expected_type`.

  The union's options are tried in order, and the first successful conversion
  is returned.

  Args:
    value: The value to convert.
    expected_type: A generic union annotation.
    path: Tuple of `str` naming the value (used for exception messages).
    context: _ConversionContext, indicates what kind of value we are converting.

  Returns:
    The result of `_convert_value` for the first option that accepts `value`.

  Raises:
    TypeError: If no option in the union accepts `value`.
  """
  for type_option in type_annotations.get_generic_type_args(expected_type):
    try:
      return _convert_value(value, type_option, path, context)
    except TypeError:
      # This option does not match; fall through to the next one.  A single
      # summary error is raised below if none match.
      pass
  raise TypeError(f'{"".join(path)}: expected {expected_type}, got {value!r}')


def _report_field_mismatches(fields, field_values):
  """Raises an exception with mismatches between fields and field_values.

  Unexpected names are reported in preference to missing ones.  If the two
  sets of names are equal, this function returns without raising.

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.

  Raises:
    ValueError: If `field_values` contains names that are not in `fields`, or
      is missing names that are in `fields`.
  """
  expected = {f.name for f in fields}
  actual = set(field_values)
  extra = actual - expected
  if extra:
    raise ValueError(f'Got unexpected fields: {extra}')
  missing = expected - actual
  if missing:
    raise ValueError(f'Missing required fields: {missing}')