# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core TensorFlow types."""

import sys
import textwrap

from typing import Union

import numpy as np

from tensorflow.python.types import doc_typealias
from tensorflow.python.util.tf_export import tf_export

# `Protocol` and `runtime_checkable` are only available from the standard
# `typing` module starting with Python 3.8; older interpreters fall back to
# the `typing_extensions` backport.
# pylint:disable=g-import-not-at-top
if sys.version_info >= (3, 8):
  from typing import Protocol
  from typing import runtime_checkable
else:
  from typing_extensions import Protocol
  from typing_extensions import runtime_checkable
# pylint:enable=g-import-not-at-top

# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
# TODO(mdan): Add type annotations.


# TODO(b/178822082): Revisit this API when tf.types gets more resource.
@tf_export("__internal__.types.Tensor", v1=[])
class Tensor:
  """The base class of all dense Tensor objects.

  A dense tensor has a static data type (dtype), and may have a static rank and
  shape. Tensor objects are immutable. Mutable objects may be backed by a Tensor
  which holds the unique handle that identifies the mutable object.
  """

  @property
  def dtype(self):
    """The static data type of this tensor's elements.

    This base class returns `None`; subclasses are expected to override it.
    """

  @property
  def shape(self):
    """The static shape of this tensor, which may be only partially known.

    This base class returns `None`; subclasses are expected to override it.
    """


class Symbol(Tensor):
  """Symbolic "graph" Tensor.

  These objects represent the output of an op definition and do not carry a
  value.
  """


class Value(Tensor):
  """Tensor that can be associated with a value (aka "eager tensor").

  These objects represent the (usually future) output of executing an op
  immediately.
  """

  def numpy(self):
    """Returns the value associated with this tensor.

    This base class returns `None`; subclasses are expected to override it.
    """


@tf_export("types.experimental.Callable", v1=[])
class Callable:
  """Base class for TF callables like those created by tf.function.

  Note: Callables are conceptually very similar to `tf.Operation`: a
  `tf.Operation` is a kind of callable.
  """

  def __call__(self, *args, **kwargs):
    """Executes this callable.

    This behaves like a regular op - in eager mode, it immediately starts
    execution, returning results. In graph mode, it creates ops which return
    symbolic TensorFlow values (like `tf.Tensor`, `tf.data.Dataset`,
    etc.). For example, `tf.function` callables typically generate a
    `tf.raw_ops.PartitionedCall` op, but not always - the
    exact operations being generated are an internal implementation detail.

    Args:
      *args: positional arguments for this call
      **kwargs: keyword arguments for this call
    Returns:
      The execution results.
    """


@tf_export("types.experimental.ConcreteFunction", v1=[])
class ConcreteFunction(Callable):
  """Base class for graph functions.

  A `ConcreteFunction` encapsulates a single graph function definition and
  is differentiable under `tf.GradientTape` contexts.
  """


# TODO(mdan): Name just `types.Function`, for historic continuity?
@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(Callable):
  """Base class for polymorphic graph functions.

  Graph functions are Python callable objects that dispatch calls to a
  TensorFlow graph. Polymorphic graph functions can be backed by multiple TF
  graphs, and automatically select the appropriate specialization based on the
  type of input they were called with. They may also create specializations on
  the fly if necessary, for example by tracing.

  Also see `tf.function`.
  """

  def get_concrete_function(self, *args, **kwargs) -> ConcreteFunction:
    """Returns a `ConcreteFunction` specialized to input types.

    The arguments specified by `args` and `kwargs` follow normal function call
    rules. The returned `ConcreteFunction` has the same set of positional and
    keyword arguments as `self`, but their types are compatible to the types
    specified by `args` and `kwargs` (though not necessarily equal).

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.constant(1.0))
    >>> f_concrete = f.get_concrete_function(x=tf.constant(1.0))

    Unlike normal calls, `get_concrete_function` allows type specifiers instead
    of TensorFlow objects, so for example `tf.Tensor`s may be replaced with
    `tf.TensorSpec`s.

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))

    If the function definition allows only one specialization, `args` and
    `kwargs` may be omitted altogether.

    >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function()

    The returned `ConcreteFunction` can be called normally:

    >>> f_concrete(tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
    >>> f_concrete(x=tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A `ConcreteFunction`.
    """
    pass

  def experimental_get_compiler_ir(self, *args, **kwargs):
    """Returns compiler IR for the compiled function.

    This API is intended *only* for debugging as there are no guarantees on
    backwards compatibility of returned IR or the allowed values of `stage`.

    Args:
      *args: Arguments used for compilation; same arguments as used for calling
        the function. Need to be eager tensors.
      **kwargs: Keyword arguments used for compilation.

    Returns:
      Function callable with the following kwargs:
        - `stage` at which the compiler IR should be serialized. Allowed values
          are:
           - `hlo`: HLO output after conversion from TF
            (https://www.tensorflow.org/xla/operation_semantics).
           - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
             HLO module proto (a bytes object).
           - `optimized_hlo`: HLO after compiler optimizations.
           - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
             output is a serialized HLO module proto (a bytes object).
           - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
             Graphviz.
        - `device_name` can be either None, in which case the preferred device
          is used for compilation, or a device name. It can be a full device
          name, or a partial one, e.g., `/device:CPU:0`.

      For example, for

      ```python
      @tf.function(jit_compile=True)
      def f(x):
        return x + 1

      f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
      ```

      the output is:

      ```
      HloModule a_inference_f_13__.9

      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
        %constant.3 = f32[] constant(1)
        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                     f32[10,10]{1,0} %broadcast.4)
        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
      }
      ```

    Raises:
      ValueError: If an invalid `stage` is selected or if applied to a function
        which is not compiled (`jit_compile=True` is not set).
      TypeError: When called with input in graph mode.
    """
    pass


@runtime_checkable
class TensorProtocol(Protocol):
  """Protocol type for objects that can be converted to Tensor."""

  def __tf_tensor__(self, dtype=None, name=None):
    """Converts this object to a Tensor.

    Args:
      dtype: data type for the returned Tensor
      name: a name for the operations which create the Tensor
    Returns:
      A Tensor.
    """
    pass


# TODO(rahulkamat): Add missing types that are convertible to Tensor.
# Public type alias for annotating arguments that accept any value convertible
# to a Tensor; it is documented and exported below.
TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, bytes,
                   complex, tuple, list, np.ndarray, np.generic]
doc_typealias.document(
    obj=TensorLike,
    doc=textwrap.dedent("""\
      Union of all types that can be converted to a `tf.Tensor` by `tf.convert_to_tensor`.

      This definition may be used in user code. Additional types may be added
      in the future as more input types are supported.

      Example:

      ```
      def foo(x: TensorLike):
        pass
      ```

      This definition passes static type verification for:

      ```
      foo(tf.constant([1, 2, 3]))
      foo([1, 2, 3])
      foo(np.array([1, 2, 3]))
      ```
      """),
)
tf_export("types.experimental.TensorLike").export_constant(
    __name__, "TensorLike")