xref: /aosp_15_r20/external/tensorflow/tensorflow/python/types/core.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core TensorFlow types."""
16
17import sys
18import textwrap
19
20from typing import Union
21
22import numpy as np
23
24from tensorflow.python.types import doc_typealias
25from tensorflow.python.util.tf_export import tf_export
26
27# pylint:disable=g-import-not-at-top
28if sys.version_info >= (3, 8):
29  from typing import Protocol
30  from typing import runtime_checkable
31else:
32  from typing_extensions import Protocol
33  from typing_extensions import runtime_checkable
34# pylint:enable=g-import-not-at-top
35
36# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
37# TODO(mdan): Add type annotations.
38
39
40# TODO(b/178822082): Revisit this API when tf.types gets more resource.
@tf_export("__internal__.types.Tensor", v1=[])
class Tensor:
  """Base class shared by every dense Tensor object.

  Each dense tensor has a static data type (`dtype`), and it may also have a
  static rank and shape. Tensor objects are immutable; a mutable object may
  instead be backed by a Tensor holding the unique handle that identifies
  that mutable object.
  """

  @property
  def dtype(self):
    """The static data type of this tensor.

    The base implementation is a placeholder that returns `None`.
    """

  @property
  def shape(self):
    """The static shape of this tensor, if it has one.

    The base implementation is a placeholder that returns `None`.
    """
57
58
class Symbol(Tensor):
  """A symbolic ("graph") Tensor.

  A Symbol stands for the output of an op definition. It does not carry a
  concrete value.
  """
66
67
class Value(Tensor):
  """A Tensor that can be associated with a value (an "eager tensor").

  Instances represent the output of an op that is executed immediately. That
  output is usually a future, i.e. not necessarily available yet.
  """

  def numpy(self):
    """Accessor for the value associated with this tensor.

    The base implementation is a placeholder that returns `None`.
    """
77
78
@tf_export("types.experimental.Callable", v1=[])
class Callable:
  """Base class for TF callables, such as those created by `tf.function`.

  Note: Callables are conceptually very close to `tf.Operation`; a
  `tf.Operation` is itself a kind of callable.
  """

  def __call__(self, *args, **kwargs):
    """Executes this callable.

    A call behaves like a regular op. In eager mode, execution starts right
    away and the results are returned. In graph mode, the call creates ops
    that return symbolic TensorFlow values (such as `tf.Tensor` or
    `tf.data.Dataset`). `tf.function` callables, for instance, typically
    generate a `tf.raw_ops.PartitionedCall` op, but not always: exactly which
    operations are generated is an internal implementation detail.

    Args:
      *args: positional arguments for this call.
      **kwargs: keyword arguments for this call.

    Returns:
      The execution results.
    """
103
104
@tf_export("types.experimental.ConcreteFunction", v1=[])
class ConcreteFunction(Callable):
  """Base class for graph functions.

  Each `ConcreteFunction` wraps exactly one graph function definition, and it
  can be differentiated inside a `tf.GradientTape` context.
  """
112
113
114# TODO(mdan): Name just `types.Function`, for historic continuity?
@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(Callable):
  """Base class for polymorphic graph functions.

  Graph functions are Python callable objects that dispatch calls to a
  TensorFlow graph. Polymorphic graph functions can be backed by multiple TF
  graphs, and automatically select the appropriate specialization based on the
  type of input they were called with. They may also create specializations on
  the fly if necessary, for example by tracing.

  Also see `tf.function`.
  """

  def get_concrete_function(self, *args, **kwargs) -> ConcreteFunction:
    """Returns a `ConcreteFunction` specialized to input types.

    The arguments specified by `args` and `kwargs` follow normal function call
    rules. The returned `ConcreteFunction` has the same set of positional and
    keyword arguments as `self`, but their types are compatible with the types
    specified by `args` and `kwargs` (though not necessarily equal).

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.constant(1.0))
    >>> f_concrete = f.get_concrete_function(x=tf.constant(1.0))

    Unlike normal calls, `get_concrete_function` allows type specifiers instead
    of TensorFlow objects, so for example `tf.Tensor`s may be replaced with
    `tf.TensorSpec`s.

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))

    If the function definition allows only one specialization, `args` and
    `kwargs` may be omitted altogether.

    >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function()

    The returned `ConcreteFunction` can be called normally:

    >>> f_concrete(tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
    >>> f_concrete(x=tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A `ConcreteFunction`.
    """

  def experimental_get_compiler_ir(self, *args, **kwargs):
    """Returns compiler IR for the compiled function.

    This API is intended *only* for debugging as there are no guarantees on
    backwards compatibility of returned IR or the allowed values of `stage`.

    Args:
      *args: Arguments used for compilation; same arguments as used for calling
        the function. They need to be eager tensors.
      **kwargs: Keyword arguments used for compilation.

    Returns:
      Function callable with the following kwargs:
        - `stage` at which the compiler IR should be serialized. Allowed values
          are:
          - `hlo`: HLO output after conversion from TF
            (https://www.tensorflow.org/xla/operation_semantics).
          - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
            HLO module proto (a bytes object).
          - `optimized_hlo`: HLO after compiler optimizations.
          - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
            output is a serialized HLO module proto (a bytes object).
          - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
            Graphviz.
        - `device_name` can be either None, in which case the preferred device
          is used for compilation, or a device name. It can be a full device
          name, or a partial one, e.g., `/device:CPU:0`.

      For example, for

      ```python
      @tf.function(jit_compile=True)
      def f(x):
        return x + 1

      f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
      ```

      the output is:

      ```
      HloModule a_inference_f_13__.9

      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
        %constant.3 = f32[] constant(1)
        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                     f32[10,10]{1,0} %broadcast.4)
        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
      }
      ```

    Raises:
      ValueError: If an invalid `stage` is selected or if applied to a function
        which is not compiled (`jit_compile=True` is not set).
      TypeError: When called with input in graph mode.
    """
238
239
@runtime_checkable
class TensorProtocol(Protocol):
  """Protocol type for objects that know how to convert themselves to a Tensor.

  Because the protocol is `runtime_checkable`, `isinstance(obj,
  TensorProtocol)` is true for any object that defines a `__tf_tensor__`
  attribute, whatever its base classes. Only the presence of the method is
  checked, not its signature.
  """

  def __tf_tensor__(self, dtype=None, name=None):
    """Converts this object to a Tensor.

    Args:
      dtype: data type for the returned Tensor.
      name: a name for the operations that create the Tensor.

    Returns:
      A Tensor.
    """
254
255
# TODO(rahulkamat): Add missing types that are convertible to Tensor.
# Static-typing alias covering the accepted inputs:
#   - `Tensor` instances;
#   - objects implementing `TensorProtocol` (they define `__tf_tensor__`);
#   - Python scalars and strings (int, float, bool, str, bytes, complex);
#   - Python tuples and lists;
#   - NumPy arrays (`np.ndarray`) and NumPy scalars (`np.generic`).
TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, bytes,
                   complex, tuple, list, np.ndarray, np.generic]
# Register documentation text for the alias (via `doc_typealias.document`).
# `textwrap.dedent` strips the common leading indentation of the literal; the
# `\` after the opening quotes avoids a leading newline in the text.
doc_typealias.document(
    obj=TensorLike,
    doc=textwrap.dedent("""\
      Union of all types that can be converted to a `tf.Tensor` by `tf.convert_to_tensor`.

      This definition may be used in user code. Additional types may be added
      in the future as more input types are supported.

      Example:

      ```
      def foo(x: TensorLike):
        pass
      ```

      This definition passes static type verification for:

      ```
      foo(tf.constant([1, 2, 3]))
      foo([1, 2, 3])
      foo(np.array([1, 2, 3]))
      ```
      """),
)
# Export the module-level constant `TensorLike` under the API name
# "types.experimental.TensorLike". It must run after the alias is defined,
# because it refers to the name in this module (`__name__`).
tf_export("types.experimental.TensorLike").export_constant(
    __name__, "TensorLike")
285