# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TensorSpec class."""

from typing import Type

import numpy as np

from tensorflow.core.function import trace_type
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export


class DenseSpec(type_spec.TypeSpec):
  """Describes a dense object with shape, dtype, and name."""

  __slots__ = ["_shape", "_dtype", "_name"]

  _component_specs = property(lambda self: self)

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Creates a TensorSpec.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    self._shape = tensor_shape.TensorShape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """Returns the `TensorShape` that represents the shape of the tensor."""
    return self._shape

  @property
  def dtype(self):
    """Returns the `dtype` of elements in the tensor."""
    return self._dtype

  @property
  def name(self):
    """Returns the (optionally provided) name of the described tensor."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    return (isinstance(spec_or_value, (DenseSpec, self.value_type)) and
            self._dtype.is_compatible_with(spec_or_value.dtype) and
            self._shape.is_compatible_with(spec_or_value.shape))

  def __repr__(self):
    return "{}(shape={}, dtype={}, name={})".format(
        type(self).__name__, self.shape, repr(self.dtype), repr(self.name))

  def __hash__(self):
    return hash((self._shape, self.dtype))

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (type(self) is type(other) and self._shape == other._shape and
            self._dtype == other._dtype and self._name == other._name)

  def __ne__(self, other):
    return not self == other

  def _serialize(self):
    return (self._shape, self._dtype, self._name)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self.value_type


@tf_export("TensorSpec")
@type_spec.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec,
                 trace_type.Serializable):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
  by some TensorFlow APIs.
  """

  __slots__ = []

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TensorSpecProto]:
    """Returns the type of proto associated with TensorSpec serialization."""
    return struct_pb2.TensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.TensorSpecProto) -> "TensorSpec":
    """Returns a TensorSpec instance based on the serialized proto."""
    return TensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.TensorSpecProto:
    """Returns a proto representation of the TensorSpec instance."""
    return struct_pb2.TensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        name=self.name)

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor.

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`. Defaults to `spec.name`.
    """
    return cls(spec.shape, spec.dtype, name or spec.name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`. Defaults to `tensor.op.name`.

    Returns:
      A `TensorSpec` that describes `tensor`.
    """
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    elif isinstance(tensor, ops.Tensor):
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    else:
      raise ValueError(
          f"`tensor` should be a tf.Tensor, but got type {type(tensor)}.")

  @property
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    return ops.Tensor

  def _to_components(self, value):
    try:
      value = ops.convert_to_tensor(value, self._dtype)
    except (TypeError, ValueError):
      raise ValueError(f"Value {value} is not convertible to a tensor with "
                       f"dtype {self._dtype} and shape {self._shape}.")
    if not value.shape.is_compatible_with(self._shape):
      raise ValueError(f"Value {value} is not convertible to a tensor with "
                       f"dtype {self._dtype} and shape {self._shape}.")
    return value

  def _from_components(self, components):
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    tensor_list[0].set_shape(self._shape)
    return tensor_list[0]

  def _to_batchable_tensor_list(self, value, batched=False):
    if batched and self._shape.merge_with(value.shape).ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    return TensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return TensorSpec(self._shape[1:], self._dtype)

  @property
  def _flat_tensor_specs(self):
    return [self]

  def _to_tensor_list(self, value):
    return [self._to_components(value)]

  def _to_batched_tensor_list(self, value):
    return self._to_tensor_list(value)

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TensorSpec":
    """Returns a version of `TensorSpec` with the name removed."""
    if self.name is None:
      return self
    else:
      return TensorSpec(self.shape, self.dtype)

trace_type.register_serializable(TensorSpec)


# TODO(b/133606651): Should is_compatible_with check min/max bounds?
@type_spec.register("tf.BoundedTensorSpec")
class BoundedTensorSpec(TensorSpec, trace_type.Serializable):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an
        invalid numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None:
      raise ValueError("`minimum` can not be None.")
    if maximum is None:
      raise ValueError("`maximum` can not be None.")

    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(f"`minimum` {minimum} is not compatible with shape "
                       f"{self.shape}. Original error: {exception!r}.")

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(f"`maximum` {maximum} is not compatible with shape "
                       f"{self.shape}. Original error: {exception!r}.")

    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.BoundedTensorSpecProto]:
    """Returns the type of proto associated with BoundedTensorSpec serialization."""
    return struct_pb2.BoundedTensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.BoundedTensorSpecProto) -> "BoundedTensorSpec":
    """Returns a BoundedTensorSpec instance based on the serialized proto."""
    return BoundedTensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.BoundedTensorSpecProto:
    """Returns a proto representation of the BoundedTensorSpec instance."""
    return struct_pb2.BoundedTensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        minimum=tensor_util.make_tensor_proto(self._minimum),
        maximum=tensor_util.make_tensor_proto(self._maximum),
        name=self.name)

  @classmethod
  def from_spec(cls, spec):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to
    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
    `spec.dtype.min` and `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    minimum = getattr(spec, "minimum", dtype.min)
    maximum = getattr(spec, "maximum", dtype.max)
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __eq__(self, other):
    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
    return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and
            np.allclose(self.maximum, other.maximum))

  def __hash__(self):
    return hash((self._shape, self.dtype))

  def __reduce__(self):
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)

trace_type.register_serializable(BoundedTensorSpec)
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)

# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor, lambda tensor: TensorSpec(tensor.shape, tensor.dtype))

type_spec.register_type_spec_from_value_converter(
    np.ndarray, lambda array: TensorSpec(array.shape, array.dtype))
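

# A minimal usage sketch (editorial illustration; not part of the original
# module). It shows how `TensorSpec.is_compatible_with` relates a spec with a
# partially known shape to fully known shapes, and how
# `BoundedTensorSpec.from_spec` falls back to the dtype-wide bounds when the
# source spec carries none. The `__main__` guard keeps the sketch from running
# when TensorFlow imports this module.
if __name__ == "__main__":
  # A spec whose leading (batch) dimension is unknown.
  batched_rows = TensorSpec(shape=[None, 3], dtype=dtypes.float32)

  # Compatible: the unknown dimension matches any size; known dims must agree.
  assert batched_rows.is_compatible_with(
      TensorSpec(shape=[8, 3], dtype=dtypes.float32))
  # Incompatible: the second dimension differs (3 vs. 4).
  assert not batched_rows.is_compatible_with(
      TensorSpec(shape=[8, 4], dtype=dtypes.float32))

  # With a plain TensorSpec as input, the bounds default to the dtype's range.
  bounded = BoundedTensorSpec.from_spec(
      TensorSpec(shape=[2], dtype=dtypes.int32, name="x"))
  assert bounded.minimum == dtypes.int32.min
  assert bounded.maximum == dtypes.int32.max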