# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for ExtensionTypes (aka Composite Tensors)."""

from tensorflow.core.protobuf import composite_tensor_variant_pb2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_composite_tensor_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import nest


def _serialized_type_spec_metadata(type_spec):
  """Returns a serialized `CompositeTensorVariantMetadata` for `type_spec`.

  Both the encode and decode ops receive the `TypeSpec` through their
  `metadata` attribute; building it in one place keeps the two consistent.

  Args:
    type_spec: The `TypeSpec` to embed in the metadata proto.

  Returns:
    The serialized metadata proto (the result of `SerializeToString()`).
  """
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)
  return metadata.SerializeToString()


def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  Args:
    value: The `ExtensionType` value to encode.
    type_spec: Information about the value's type that should be included in the
      encoding.  Defaults to `value`'s own type spec.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")

  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` "
                     f"{value!r}.")

  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      # The components are the flattened leaf tensors of `value`.
      components=nest.flatten(value, expand_composites=True),
      metadata=_serialized_type_spec_metadata(type_spec),
      name=name)


def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value. This is used to determine
      the number and types of the component tensors that comprise the decoded
      value. Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with a
      scalar shape `()`.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  encoded.shape.assert_is_compatible_with(())

  # The decode op needs the dtype of every leaf component up front; they are
  # read from the flattened `type_spec`.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=_serialized_type_spec_metadata(type_spec),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)


@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient for `CompositeTensorVariantFromComponents`.

  Decodes the variant-valued gradient back into per-component gradients, using
  the forward op's `metadata` and `Tcomponents` attributes.
  """
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad,
      metadata=op.get_attr("metadata"),
      Tcomponents=op.get_attr("Tcomponents"))


@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient for `CompositeTensorVariantToComponents`.

  Re-encodes the per-component gradients into a single variant gradient, using
  the forward op's `metadata` attribute.
  """
  assert len(grad) == len(op.outputs)
  # `components` is `op.outputs`, but with any tensors for which we're
  # taking the gradient replaced by the corresponding value from `grad`.
  # NOTE(review): components with no incoming gradient are filled with the
  # forward outputs rather than zeros — presumably so the encoded variant has a
  # complete set of components; confirm this is intended.
  components = [
      output if output_grad is None else output_grad
      for output, output_grad in zip(op.outputs, grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=op.get_attr("metadata"))