xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/tensor_spec.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A TensorSpec class."""
16
17from typing import Type
18
19import numpy as np
20
21from tensorflow.core.function import trace_type
22from tensorflow.core.protobuf import struct_pb2
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.framework import type_spec
29from tensorflow.python.util import _pywrap_utils
30from tensorflow.python.util.tf_export import tf_export
31
32
class DenseSpec(type_spec.TypeSpec):
  """A `TypeSpec` for dense values characterized by shape, dtype and name."""

  __slots__ = ["_shape", "_dtype", "_name"]

  @property
  def _component_specs(self):
    # A dense spec is its own (single) component spec.
    return self

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Initializes the spec from a shape, a dtype and an optional name.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    # Normalize both arguments so the stored fields always have canonical
    # types, regardless of what the caller passed in.
    self._shape = tensor_shape.TensorShape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """The `TensorShape` describing the shape of the tensor."""
    return self._shape

  @property
  def dtype(self):
    """The `dtype` of the tensor's elements."""
    return self._dtype

  @property
  def name(self):
    """The name given to the described tensor, or `None` if none was given."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    """Returns whether `spec_or_value` matches this spec's dtype and shape.

    Only `DenseSpec` instances and instances of `self.value_type` are
    considered; anything else is incompatible. The name is not consulted.
    """
    if not isinstance(spec_or_value, (DenseSpec, self.value_type)):
      return False
    if not self._dtype.is_compatible_with(spec_or_value.dtype):
      return False
    return self._shape.is_compatible_with(spec_or_value.shape)

  def __repr__(self):
    cls_name = type(self).__name__
    return (f"{cls_name}(shape={self.shape}, dtype={self.dtype!r}, "
            f"name={self.name!r})")

  def __hash__(self):
    # The name is left out of the hash; specs that compare equal still share
    # shape and dtype, so equal objects hash equally.
    return hash((self._shape, self.dtype))

  def __eq__(self, other):
    # pylint: disable=protected-access
    # Exact type match is required: a subclass instance never equals a
    # base-class instance, even with identical fields.
    if type(self) is not type(other):
      return False
    same_shape_and_dtype = (
        self._shape == other._shape and self._dtype == other._dtype)
    return same_shape_and_dtype and self._name == other._name

  def __ne__(self, other):
    return not (self == other)

  def _serialize(self):
    """Returns the `(shape, dtype, name)` tuple describing this spec."""
    return (self._shape, self._dtype, self._name)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self.value_type
102
103
@tf_export("TensorSpec")
@type_spec.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec,
                 trace_type.Serializable):
  """Describes a tf.Tensor.

  Holds the metadata (shape, dtype and optional name) used to describe the
  `tf.Tensor` objects that some TensorFlow APIs accept or return.
  """

  __slots__ = []

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TensorSpecProto]:
    """Returns the proto class used to serialize a TensorSpec."""
    return struct_pb2.TensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.TensorSpecProto) -> "TensorSpec":
    """Builds a TensorSpec from its serialized proto form."""
    shape = tensor_shape.TensorShape.experimental_from_proto(proto.shape)
    dtype = proto.dtype
    # A falsy (empty) proto name is mapped back to "no name".
    name = proto.name or None
    return TensorSpec(shape=shape, dtype=dtype, name=name)

  def experimental_as_proto(self) -> struct_pb2.TensorSpecProto:
    """Serializes this TensorSpec into a `TensorSpecProto`."""
    shape_proto = self.shape.experimental_as_proto()
    datatype = self.dtype.experimental_as_proto().datatype
    return struct_pb2.TensorSpecProto(
        shape=shape_proto, dtype=datatype, name=self.name)

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Compatibility means the dtypes are compatible and the shapes are
    compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return super().is_compatible_with(spec_or_tensor)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`.  Defaults to `spec.name`.
    """
    shape = spec.shape
    dtype = spec.dtype
    # A falsy `name` (None or "") falls back to the name carried by `spec`.
    chosen_name = name or spec.name
    return cls(shape, dtype, chosen_name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`.  Defaults to `tensor.op.name`.

    Returns:
      A `TensorSpec` that describes `tensor`.

    Raises:
      ValueError: If `tensor` is neither an eager tensor nor a `tf.Tensor`.
    """
    # Eager tensors are checked first; on that path `tensor.op` is never
    # read, so only the explicitly supplied `name` is used.
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    if isinstance(tensor, ops.Tensor):
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    raise ValueError(
        f"`tensor` should be a tf.Tensor, but got type {type(tensor)}.")

  @property
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    return ops.Tensor

  def _to_components(self, value):
    """Converts `value` to a tensor of this spec's dtype and checks its shape.

    Raises:
      ValueError: If the conversion fails or the resulting shape is not
        compatible with this spec's shape.
    """
    try:
      tensor = ops.convert_to_tensor(value, self._dtype)
    except (TypeError, ValueError):
      # Both conversion failure modes surface as the same ValueError.
      raise ValueError(f"Value {value} is not convertible to a tensor with "
                       f"dtype {self._dtype} and shape {self._shape}.")
    if tensor.shape.is_compatible_with(self._shape):
      return tensor
    raise ValueError(f"Value {tensor} is not convertible to a tensor with "
                     f"dtype {self._dtype} and shape {self._shape}.")

  def _from_components(self, components):
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    only_tensor = tensor_list[0]
    only_tensor.set_shape(self._shape)
    return only_tensor

  def _to_batchable_tensor_list(self, value, batched=False):
    if batched:
      # A batched value must keep at least one dimension to unbatch along.
      merged_shape = self._shape.merge_with(value.shape)
      if merged_shape.ndims == 0:
        raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    """Returns a spec with a leading `batch_size` dimension prepended."""
    leading_dim = tensor_shape.TensorShape([batch_size])
    batched_shape = leading_dim.concatenate(self._shape)
    return TensorSpec(batched_shape, self._dtype)

  def _unbatch(self):
    """Returns a spec with the leading dimension removed."""
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    element_shape = self._shape[1:]
    return TensorSpec(element_shape, self._dtype)

  @property
  def _flat_tensor_specs(self):
    return [self]

  def _to_tensor_list(self, value):
    component = self._to_components(value)
    return [component]

  def _to_batched_tensor_list(self, value):
    return self._to_tensor_list(value)

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TensorSpec":
    """Returns a version of `TensorSpec` with the name removed."""
    # Nothing to strip: the instance itself is returned unchanged.
    if self.name is None:
      return self
    return TensorSpec(self.shape, self.dtype)

trace_type.register_serializable(TensorSpec)
253
254
# TODO(b/133606651): Should is_compatible_with check min/max bounds?
@type_spec.register("tf.BoundedTensorSpec")
class BoundedTensorSpec(TensorSpec, trace_type.Serializable):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super().__init__(shape, dtype, name)

    bounds = (("minimum", minimum), ("maximum", maximum))

    # Both bounds are mandatory; reject missing ones before any shape work.
    for label, bound in bounds:
      if bound is None:
        raise ValueError(f"`{label}` can not be None.")

    # Each bound must broadcast against the spec's shape. The original error
    # is kept both in the message and as the explicit cause.
    for label, bound in bounds:
      try:
        bound_shape = np.shape(bound)
        common_shapes.broadcast_shape(
            tensor_shape.TensorShape(bound_shape), self.shape)
      except ValueError as exception:
        raise ValueError(
            f"`{label}` {bound} is not compatible with shape "
            f"{self.shape}. Original error: {exception!r}.") from exception

    # Store the bounds as read-only arrays of the spec's numpy dtype so that
    # callers cannot mutate them through the `minimum`/`maximum` properties.
    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.BoundedTensorSpecProto]:
    """Returns the type of proto associated with BoundedTensorSpec serialization."""
    return struct_pb2.BoundedTensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.BoundedTensorSpecProto) -> "BoundedTensorSpec":
    """Returns a BoundedTensorSpec instance based on the serialized proto."""
    return BoundedTensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        # A falsy (empty) proto name is mapped back to "no name".
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.BoundedTensorSpecProto:
    """Returns a proto representation of the BoundedTensorSpec instance."""
    return struct_pb2.BoundedTensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        minimum=tensor_util.make_tensor_proto(self._minimum),
        maximum=tensor_util.make_tensor_proto(self._maximum),
        name=self.name)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.

    If `spec` has `minimum` / `maximum` attributes (e.g. it is a
    `BoundedTensorSpec`), then the new spec's bounds are set to
    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
    `spec.dtype.min` and `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
      name: The name for the new `BoundedTensorSpec`.  Defaults to
        `spec.name`.

    Returns:
      A `BoundedTensorSpec` describing the same shape and dtype as `spec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    # Only consult the dtype limits when `spec` has no bound of its own.
    # NOTE(review): `dtype.min` / `dtype.max` are understood to raise for
    # dtypes without numeric limits (confirm against tf.dtypes.DType), so they
    # must not be evaluated eagerly as `getattr` defaults.
    minimum = spec.minimum if hasattr(spec, "minimum") else dtype.min
    maximum = spec.maximum if hasattr(spec, "maximum") else dtype.max
    # `name` mirrors `TensorSpec.from_spec`: a falsy value falls back to the
    # name carried by `spec`.
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum,
                             name or spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __eq__(self, other):
    # The base comparison also enforces an exact type match, so `other` is
    # only read for bounds once it is known to be a BoundedTensorSpec.
    tensor_spec_eq = super().__eq__(other)
    return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and
            np.allclose(self.maximum, other.maximum))

  def __hash__(self):
    # Defining `__eq__` resets the inherited `__hash__`, so it is restated
    # here. Bounds and name are left out; equal specs still hash equally.
    return hash((self._shape, self.dtype))

  def __reduce__(self):
    # Pickle support: rebuild through the constructor with all five fields.
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    """Returns the `(shape, dtype, minimum, maximum, name)` tuple."""
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
401
# Registrations that need both spec classes to be fully defined.
trace_type.register_serializable(BoundedTensorSpec)
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)

# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor,
    lambda t: TensorSpec(t.shape, t.dtype))

# NumPy arrays are described by a TensorSpec built from their shape and dtype.
type_spec.register_type_spec_from_value_converter(
    np.ndarray,
    lambda arr: TensorSpec(arr.shape, arr.dtype))
411