xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/composite_tensor_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Operations for ExtensionTypes (aka Composite Tensors)."""
16
17from tensorflow.core.protobuf import composite_tensor_variant_pb2
18from tensorflow.python.framework import composite_tensor
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import gen_composite_tensor_ops
22from tensorflow.python.saved_model import nested_structure_coder
23from tensorflow.python.util import nest
24
25
def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Packs a composite tensor into a single scalar variant tensor.

  The flattened components of `value` are passed to the
  `CompositeTensorVariantFromComponents` op, together with a serialized
  `CompositeTensorVariantMetadata` proto that records `type_spec`.

  Args:
    value: The `CompositeTensor` (e.g. an `ExtensionType` value) to encode.
    type_spec: The type information to store in the encoding's metadata.
      When `None`, `value._type_spec` is used.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  is_composite = isinstance(value, composite_tensor.CompositeTensor)
  if not is_composite:
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")

  # Fall back to the value's own spec when the caller did not supply one.
  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access

  is_compatible = type_spec.is_compatible_with(value)
  if not is_compatible:
    raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` "
                     f"{value!r}.")

  # Record the spec inside the metadata proto that travels with the variant.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  encoded_spec = nested_structure_coder.encode_structure(type_spec)
  metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

  flat_components = nest.flatten(value, expand_composites=True)
  serialized_metadata = metadata.SerializeToString()
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=flat_components, metadata=serialized_metadata, name=name)
58
59
def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Rebuilds the composite tensor stored in a scalar variant tensor.

  The variant is split into its component tensors by the
  `CompositeTensorVariantToComponents` op, and those components are then
  packed back into the structure described by `type_spec`.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value.  It determines the number
      and dtypes of the component tensors that make up the decoded value, and
      must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with the
      scalar shape `()` (raised by `TensorShape.assert_is_compatible_with`).
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  # Only scalar variants are accepted.
  encoded.shape.assert_is_compatible_with(())

  # Serialize the expected spec so the op receives it as its metadata attr.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  encoded_spec = nested_structure_coder.encode_structure(type_spec)
  metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

  # One dtype per leaf of the fully expanded spec.
  flat_specs = nest.flatten(type_spec, expand_composites=True)
  component_dtypes = [leaf_spec.dtype for leaf_spec in flat_specs]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)
98
99
@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient function for `CompositeTensorVariantFromComponents`.

  Splits the incoming variant gradient into per-component tensors by running
  `CompositeTensorVariantToComponents` with the `metadata` and `Tcomponents`
  attrs read from the forward op.

  Args:
    op: The forward `CompositeTensorVariantFromComponents` operation.
    grad: The gradient with respect to the op's output.

  Returns:
    The outputs of `CompositeTensorVariantToComponents` applied to `grad`.
  """
  forward_metadata = op.get_attr("metadata")
  forward_dtypes = op.get_attr("Tcomponents")
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad, metadata=forward_metadata, Tcomponents=forward_dtypes)
106
107
@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient function for `CompositeTensorVariantToComponents`.

  Re-encodes the per-component gradients as a single variant by running
  `CompositeTensorVariantFromComponents` with the forward op's `metadata`
  attr.

  Args:
    op: The forward `CompositeTensorVariantToComponents` operation.
    *grad: One gradient per output of `op`; an entry may be `None`.

  Returns:
    The output of `CompositeTensorVariantFromComponents`.
  """
  assert len(grad) == len(op.outputs)
  # Use each supplied gradient as-is; where the gradient is `None`, fall back
  # to the matching forward output so every component slot is filled.
  components = [
      op.outputs[i] if component_grad is None else component_grad
      for i, component_grad in enumerate(grad)
  ]
  forward_metadata = op.get_attr("metadata")
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=forward_metadata)
118