1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""API for defining graph functions with some additional eager semantics.
17
18def_function.function wraps the function concept in function.py ("defun") to
19allow initializing `tf.Variable`s with subgraphs of the function. For example:
20
21```python
22class M(tf.Module):
23  def __init__(self):
24    self.v_opinit = None
25    self.v_arginit = None
26
27  @tf.function
28  def __call__(self, x):
29    # Variables are only created on the first call to the function. This is a
30    # common pattern in layer libraries.
31    if self.v_opinit is None:
32      # self.v_opinit will outlive the function call, but `tf.ones` is traced as
33      # part of the function body before the `tf.Variable` object is
34      # created. This subgraph is easy to lift out of the function.
35      self.v_opinit = tf.Variable(tf.ones([]))
36
37      # If arguments feed into variable initialization, it can be very tricky to
38      # disentangle from the rest of the function. We don't attempt it.
39      self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.))
40    return self.v_opinit + self.v_arginit + x
41```
42
43With "defun", these patterns raise an error asking the user to put the variable's
44initializer in a lambda. With tf.function they work with eager semantics either
45by lifting the subgraph out of the function and using it to initialize the
46variable, or by initializing variables on the first call to the function (if
47they weren't already initialized by something else, e.g. a checkpoint API). The
48latter requires tf.conds, and is not well supported by TF-XLA, so we only do it
49when necessary.
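
For example (an illustrative usage sketch, not additional API surface), the
module above can be called like any other callable; the variables are created
on the first call and reused on later calls:

```python
m = M()
m(tf.constant([1., 2.]))  # First call: creates v_opinit and v_arginit.
m(tf.constant([3., 4.]))  # Later calls reuse the existing variables.
```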
50
51Since these patterns are relatively common in layer libraries, we expose the
52wrapper in this file as `tf.function`. The function concept in function.py is an
53internal implementation detail.
54
55In order to support these variable initialization patterns, tf.function defines
56a variable subtype (UnliftedInitializerVariable) which collects the input
57subgraph. This type of variable replaces the regular variable type on the first
58tf.function trace. To exclude initializers from the function body (the `tf.ones`
59ops above and associated assignment operations), tf.function traces a second
60time if it sees variables on the first call.
61"""
62
63import functools
64import os
65import threading
66import types as types_lib
67import weakref
68
69from google.protobuf import text_format as _text_format
70from google.protobuf.message import DecodeError
71from tensorflow.core.framework import attr_value_pb2
72from tensorflow.python.distribute.parallel_device import parallel_device
73from tensorflow.python.eager import context
74from tensorflow.python.eager import function as function_lib
75from tensorflow.python.eager import function_spec as function_spec_lib
76from tensorflow.python.eager import lift_to_graph
77from tensorflow.python.eager import monitoring
78from tensorflow.python.framework import composite_tensor
79from tensorflow.python.framework import errors
80from tensorflow.python.framework import func_graph as func_graph_module
81from tensorflow.python.framework import ops
82from tensorflow.python.ops import array_ops
83from tensorflow.python.ops import control_flow_ops
84from tensorflow.python.ops import control_flow_util
85from tensorflow.python.ops import math_ops
86from tensorflow.python.ops import random_ops
87from tensorflow.python.ops import resource_variable_ops
88from tensorflow.python.platform import tf_logging as logging
89from tensorflow.python.profiler import trace
90from tensorflow.python.trackable import base as trackable
91from tensorflow.python.types import core
92from tensorflow.python.util import deprecation
93from tensorflow.python.util import nest
94from tensorflow.python.util import object_identity
95from tensorflow.python.util import tf_decorator
96from tensorflow.python.util import traceback_utils
97from tensorflow.python.util.tf_export import tf_export
98
99FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
100FREQUENT_TRACING_WARNING_THRESHOLD = 5
101FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
102ALLOW_DYNAMIC_VARIABLE_CREATION = False
103
104_tf_function_counter = monitoring.Counter(
105    "/tensorflow/core/tf_function_counter",
106    "Counter for the number of tf.functions created when Eager execution is "
107    "enabled.",
108    # jit_compile is "0" or "1".
109    "jit_compile")
110
111
112class _FrequentTracingDetector(object):
113  """Class keeping track of how many recent calls triggered tracing."""
114
115  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
116
117  def __init__(self):
118    self._calls_per_tracings = []
119    self._total_warning_count = 0
120    self._call_count = 0
121
122  def called_with_tracing(self, function_name, omit_warning):
123    """Updates the list of most recent calls' tracing information.
124
125    Warns the user when recent calls caused retracing too often.
126
127    Args:
128      function_name: the python function being traced.
129      omit_warning: If 'True', this call will not warn the user even if
130        retracing happens too often.
131    """
132    self._call_count += 1
133    self._calls_per_tracings.append(1)
134
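    # Keep only the most recent calls: drop the oldest per-tracing call bucket
    # whenever the newer buckets alone already cover more than
    # FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY calls.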
135    while self._calls_per_tracings:
136      if (self._call_count - self._calls_per_tracings[0] >
137          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
138        self._call_count -= self._calls_per_tracings.pop(0)
139      else:
140        break
141
142    if (omit_warning or self._total_warning_count >=
143        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
144      return
145    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
146      self._total_warning_count += 1
147      logging.warning(
148          "{} out of the last {} calls to {} triggered tf.function "
149          "retracing. Tracing is expensive and the excessive number of "
150          "tracings could be due to (1) creating @tf.function repeatedly in "
151          "a loop, (2) passing tensors with different shapes, (3) passing "
152          "Python objects instead of tensors. For (1), please define your "
153          "@tf.function outside of the loop. For (2), @tf.function has "
154          "@tf.function outside of the loop. For (2), @tf.function has the "
155          "reduce_retracing=True option that can avoid unnecessary "
156          "https://www.tensorflow.org/guide/function#controlling_retracing"
157          " and https://www.tensorflow.org/api_docs/python/tf/function for "
158          "more details.".format(
159              len(self._calls_per_tracings), self._call_count, function_name))
160
161  def called_without_tracing(self):
162    # We don't count tracing when users load a concrete function directly or
163    # call get_concrete_function, so the first call may not be a tracing call.
164    if not self._calls_per_tracings:
165      self._calls_per_tracings = [0]
166    self._calls_per_tracings[-1] += 1
167    self._call_count += 1
168
169
170class _FrequentTracingDetectorManager(object):
171  """Class for the management of all _FrequentTracingDetector objects."""
172
173  __slots__ = ["_detectors", "_lock"]
174
175  def __init__(self):
176    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
177    self._lock = threading.Lock()
178
179  def _get_detector(self, key):
180    if key not in self._detectors:
181      self._detectors[key] = _FrequentTracingDetector()
182    return self._detectors[key]
183
184  def called_without_tracing(self, key):
185    with self._lock:
186      detector = self._get_detector(key)
187      detector.called_without_tracing()
188
189  def called_with_tracing(self, key, function_name, omit_warning):
190    with self._lock:
191      detector = self._get_detector(key)
192      detector.called_with_tracing(function_name, omit_warning)
193
194
195_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
196
197
198class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
199  """Variable which does not lift its initializer out of function context.
200
201  Instances of this variable, when created, build a graph which runs their
202  initializer inside a tf.cond(is_initialized) block.
203
204  This can only be created inside a defun called from (eventually) eager
205  mode. That is, non-function-building graphs are not supported.
206  """
207
208  def __init__(self,
209               initial_value=None,
210               trainable=None,
211               caching_device=None,
212               name=None,
213               dtype=None,
214               constraint=None,
215               add_initializers_to=None,
216               lifted_initializer_graph=None,
217               synchronization=None,
218               aggregation=None,
219               shape=None,
220               **unused_kwargs):
221    """Creates a variable.
222
223    Args:
224      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
225        which is the initial value for the Variable. The initial value must have
226        a shape specified unless `validate_shape` is set to False. Can also be a
227        callable with no argument that returns the initial value when called.
228        (Note that initializer functions from init_ops.py must first be bound
229         to a shape before being used here.)
230      trainable: If `True`, GradientTapes automatically watch uses of this
231        Variable.
232      caching_device: Optional device string or function describing where the
233        Variable should be cached for reading.  Defaults to the Variable's
234        device.  If not `None`, caches on another device.  Typical use is to
235        cache on the device where the Ops using the Variable reside, to
236        deduplicate copying through `Switch` and other conditional statements.
237      name: Optional name for the variable. Defaults to `'Variable'` and gets
238        uniquified automatically.
239      dtype: If set, initial_value will be converted to the given type.
240        If None, either the datatype will be kept (if initial_value is
241        a Tensor) or float32 will be used (if it is a Python object convertible
242        to a Tensor).
243      constraint: An optional projection function to be applied to the variable
244        after being updated by an `Optimizer` (e.g. used to implement norm
245        constraints or value constraints for layer weights). The function must
246        take as input the unprojected Tensor representing the value of the
247        variable and return the Tensor for the projected value
248        (which must have the same shape). Constraints are not safe to
249        use when doing asynchronous distributed training.
250      add_initializers_to: if not None and not in legacy graph mode, the
251        initializer tensor will be added to this map in addition to adding the
252        assignment to the function.
253      lifted_initializer_graph: FuncGraph to try to lift initializers to.
254      synchronization: Indicates when a distributed variable will be
255        aggregated. Accepted values are constants defined in the class
256        `tf.VariableSynchronization`. By default the synchronization is set to
257        `AUTO` and the current `DistributionStrategy` chooses
258        when to synchronize.
259      aggregation: Indicates how a distributed variable will be aggregated.
260        Accepted values are constants defined in the class
261        `tf.VariableAggregation`.
262      shape: (optional) The shape of this variable. If None, the shape of
263        `initial_value` will be used. When setting this argument to
264        `tf.TensorShape(None)` (representing an unspecified shape), the variable
265        can be assigned with values of different shapes.
266
267    Raises:
268      ValueError: If the initial value is not specified, or does not have a
269        shape and `validate_shape` is `True`.
270      RuntimeError: If called outside of a function definition.
271    """
272    with ops.init_scope():
273      self._in_graph_mode = not context.executing_eagerly()
274    if not ops.inside_function():
275      # If we've been init_scope()d out of the function definition, there is
276      # nothing to do here; we can't really do the capturing or conditional logic.
277      resource_variable_ops.ResourceVariable.__init__(
278          self, initial_value=initial_value, trainable=trainable,
279          caching_device=caching_device, name=name, dtype=dtype,
280          constraint=constraint)
281      return
282    if initial_value is None:
283      raise ValueError("`initial_value` must be a Tensor or a Python "
284                       "object convertible to a Tensor. Got None.")
285    init_from_fn = callable(initial_value)
286
287    if constraint is not None and not callable(constraint):
288      raise ValueError(f"`constraint` with type {type(constraint)} must be a "
289                       "callable.")
290
291    with ops.name_scope(name, "Variable", []
292                        if init_from_fn else [initial_value]) as scope_name:
293      with ops.name_scope("Initializer"):
294        if init_from_fn:
295          initial_value = initial_value()
296        if isinstance(initial_value, trackable.CheckpointInitialValue):
297          self._maybe_initialize_trackable()
298          self._update_uid = initial_value.checkpoint_position.restore_uid
299          initial_value = initial_value.wrapped_value
300
301        initial_value = ops.convert_to_tensor(initial_value,
302                                              name="initial_value", dtype=dtype)
303      assert initial_value is not None
304
305      # Don't use `shape or initial_value.shape` since TensorShape has
306      # overridden `__bool__`.
307      if shape is None:
308        shape = initial_value.shape
309
310    # Use the constructor for UninitializedVariable to start. Outside the name
311    # scope so we don't double up the prefix.
312    super().__init__(
313        trainable=trainable,
314        caching_device=caching_device,
315        name=name,
316        shape=shape,
317        dtype=initial_value.dtype,
318        constraint=constraint,
319        synchronization=synchronization,
320        aggregation=aggregation,
321        extra_handle_data=initial_value,
322        **unused_kwargs)
323
324    with ops.name_scope(scope_name):
325      if self._in_graph_mode:
326        with ops.init_scope():
327          outer_graph = ops.get_default_graph()
328        func_graph = ops.get_default_graph()
329        function_placeholders = (
330            func_graph.inputs + func_graph.internal_captures)
331        placeholder_ops = set(
332            [tensor.op for tensor in function_placeholders])
333        lifted_initializer = lift_to_graph.lift_to_graph(
334            [initial_value], outer_graph,
335            disallowed_placeholders=placeholder_ops)[initial_value]
336        with ops.init_scope():
337          self._initial_value = lifted_initializer
338          with ops.name_scope("IsInitialized"):
339            self._is_initialized_op = (
340                resource_variable_ops.var_is_initialized_op(self._handle))
341          if initial_value is not None:
342            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
343              self._initializer_op = resource_variable_ops.assign_variable_op(
344                  self._handle, lifted_initializer, name=n)
345      elif context.executing_eagerly():
346        # In this case, both current scope and init scope are eager.
347        # Assign_variable_op will be executed immediately. So we don't need to
348        # add it to "add_initializers_to" to lift it out.
349        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
350          resource_variable_ops.assign_variable_op(
351              self._handle, initial_value, name=n)
352      else:
353        # Init scope is eager but current scope is graph. We will lift out this
354        # variable by adding it to "add_initializers_to".
355        if add_initializers_to is not None:
356          add_initializers_to.append((self, initial_value))
357
358        def assign_fn():
359          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
360            resource_variable_ops.assign_variable_op(
361                self._handle,
362                initial_value,
363                name=n)
364            # Returning values to keep tf.cond happy.
365          return ops.convert_to_tensor(1)
366        def not_assign_fn():
367          return ops.convert_to_tensor(0)
368        # Note: this cond is always guaranteed to run because we're inside a
369        # defun which will insert automatic control dependencies. It will only
370        # execute assign_fn if lifting failed.
371        graph = ops.get_default_graph()
372
373        # Capture the handle ahead of time in order to avoid querying the shape
374        # of the handle, which helps async execution performance.
375        graph.capture(self._handle, shape=())
376        control_flow_ops.cond(
377            resource_variable_ops.var_is_initialized_op(self._handle),
378            not_assign_fn, assign_fn)
379
380
381JIT_COMPILE_FUNCTIONS = (
382    os.getenv("TF_FUNCTION_JIT_COMPILE_DEFAULT", "false").lower()
383    in ("true", "1"))
384
385RUN_FUNCTIONS_EAGERLY = False
386
387
388@deprecation.deprecated(
389    None,
390    "Use `tf.config.run_functions_eagerly` instead of the experimental "
391    "version.")
392@tf_export("config.experimental_run_functions_eagerly")
393def experimental_run_functions_eagerly(run_eagerly):
394  """Enables / disables eager execution of `tf.function`s.
395
396  Calling `tf.config.experimental_run_functions_eagerly(True)` will make all
397  invocations of `tf.function` run eagerly instead of running as a traced graph
398  function.
399
400  See `tf.config.run_functions_eagerly` for an example.
401
402  Note: This flag has no effect on functions passed into tf.data transformations
403  as arguments. tf.data functions are never executed eagerly and are always
404  executed as a compiled TensorFlow graph.
405
406  Args:
407    run_eagerly: Boolean. Whether to run functions eagerly.
408  """
409  return run_functions_eagerly(run_eagerly)
410
411
412@tf_export("config.run_functions_eagerly")
413def run_functions_eagerly(run_eagerly):
414  """Enables / disables eager execution of `tf.function`s.
415
416  Calling `tf.config.run_functions_eagerly(True)` will make all
417  invocations of `tf.function` run eagerly instead of running as a traced graph
418  function.
419
420  This can be useful for debugging.
421
422  >>> def my_func(a):
423  ...  print("Python side effect")
424  ...  return a + a
425  >>> a_fn = tf.function(my_func)
426
427  >>> # A side effect the first time the function is traced
428  >>> a_fn(tf.constant(1))
429  Python side effect
430  <tf.Tensor: shape=(), dtype=int32, numpy=2>
431
432  >>> # No further side effect, as the traced function is called
433  >>> a_fn(tf.constant(2))
434  <tf.Tensor: shape=(), dtype=int32, numpy=4>
435
436  >>> # Now, switch to eager running
437  >>> tf.config.run_functions_eagerly(True)
438  >>> # Side effect, as the function is called directly
439  >>> a_fn(tf.constant(2))
440  Python side effect
441  <tf.Tensor: shape=(), dtype=int32, numpy=4>
442
443  >>> # Turn this back off
444  >>> tf.config.run_functions_eagerly(False)
445
446  Note: This flag has no effect on functions passed into tf.data transformations
447  as arguments. tf.data functions are never executed eagerly and are always
448  executed as a compiled TensorFlow graph.
449
450  Args:
451    run_eagerly: Boolean. Whether to run functions eagerly.
452  """
453  global RUN_FUNCTIONS_EAGERLY
454  RUN_FUNCTIONS_EAGERLY = bool(run_eagerly)
455
456
457@deprecation.deprecated(
458    None,
459    "Use tf.config.functions_run_eagerly instead of the experimental version.")
460@tf_export("config.experimental_functions_run_eagerly")
461def experimental_functions_run_eagerly():
462  """Returns the value of the `experimental_run_functions_eagerly` setting."""
463  return functions_run_eagerly()
464
465
466@tf_export("config.functions_run_eagerly")
467def functions_run_eagerly():
468  """Returns the value of the `run_functions_eagerly` setting."""
469  return RUN_FUNCTIONS_EAGERLY
470
471
472def _evaluate_var_is_initialized(variables):
473  """Compute booleans indicating whether each variable is initialized."""
474  with ops.init_scope():
475    var_is_initialized = []
476    for v in variables:
477      var_is_initialized.append(
478          resource_variable_ops.var_is_initialized_op(v.handle))
479    try:
480      # Stack all the var_is_initialized values into one tensor and interpret
481      # the numpy value. This will reduce the number of RPCs between client and
482      # worker in the remote case.
483      return array_ops.stack(var_is_initialized).numpy()
484    except errors.UnimplementedError:
485      # Some devices do not support implicit copy-off to host. Fall back to
486      # variable-by-variable processing.
487      for index, v in enumerate(variables):
488        try:
489          numpy_value = var_is_initialized[index].numpy()
490        except errors.UnimplementedError:
491          # This is a variable on a parallel device; we'll extract its value on
492          # each replica and assert that they're identical.
493          components = parallel_device.unpack(var_is_initialized[index])
494          with ops.device(None):
495            components = array_ops.stack(components)
496            all_initialized = math_ops.reduce_all(components).numpy()
497            any_initialized = math_ops.reduce_any(components).numpy()
498          if all_initialized != any_initialized:
499            raise NotImplementedError(
500                f"Some but not all components of a parallel variable {v!r} "
501                "were initialized between their creation in a tf.function and "
502                "the function's trace having completed. This is not "
503                "supported; consider initializing either all or none of the "
504                "components, or moving initialization out of the function.")
505          numpy_value = all_initialized
506        var_is_initialized[index] = numpy_value
507  return var_is_initialized
508
509
510class FunctionDeleter:
511  """An object responsible for cleaning up the function graph."""
512
513  __slots__ = ["func_graph"]
514
515  def __init__(self, func_graph):
516    self.func_graph = func_graph
517
518  def __del__(self):
519    try:
520      func_graph_module.dismantle_func_graph(self.func_graph)
521    except:  # pylint: disable=bare-except
522      # Note: bare except here because this can be noisy at shutdown time.
523      pass
524
525
526class OptionalXlaContext:
527  """Wrapper for XLA context optionally applied under a context manager."""
528
529  def __init__(self, is_compiled):
530    wrap = is_compiled and not control_flow_util.GraphOrParentsInXlaContext(
531        ops.get_default_graph())
532    self.xla_context = (
533        control_flow_ops.XLAControlFlowContext() if wrap else None)
534
535  def __enter__(self):
536    if self.xla_context:
537      self.xla_context.Enter()
538
539  def __exit__(self, t, value, traceback):
540    if self.xla_context:
541      self.xla_context.Exit()
542
543
544# TODO(mdan): Consider exposing this type for instance type checking.
545@tf_export("__internal__.function.Function", v1=[])
546class Function(core.GenericFunction, trackable.Trackable):
547  """A `tf.types.experimental.GenericFunction` created by `tf.function`.
548
549  Currently, individual methods/attributes under this class are not guaranteed
550  by the TF API contract, and are subject to future changes.
551  """
552
553  def __init__(self,
554               python_function,
555               name,
556               input_signature=None,
557               autograph=True,
558               jit_compile=None,
559               reduce_retracing=False,
560               experimental_implements=None,
561               experimental_autograph_options=None,
562               experimental_follow_type_hints=None):
563    """Initializes a `Function`.
564
565    Args:
566      python_function: the function to be wrapped.
567      name: the name given to it.
568      input_signature: See the documentation for `tf.function`.
569      autograph: See the documentation for `tf.function`.
570      jit_compile: See the documentation for `tf.function`.
571      reduce_retracing: See the documentation for `tf.function`.
572      experimental_implements: See the documentation for `tf.function`.
573      experimental_autograph_options: See the documentation for `tf.function`.
574      experimental_follow_type_hints: See the documentation for `tf.function`.
575
576    Raises:
577      ValueError: if `input_signature` is not None and the `python_function`'s
578        argspec has keyword arguments.
579    """
580    self._lock = threading.RLock()
581    self._python_function = python_function
582    self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature(
583        python_function,
584        input_signature,
585        jit_compile=jit_compile,
586        experimental_follow_type_hints=experimental_follow_type_hints,
587    )
588    self._implements = experimental_implements
589    # If `True`, the function uses the rendezvous of the parent. This is only
590    # needed to support code where raw send/recv operations are inserted and
591    # when functions are run in graph mode where they may not be inlined.
592    self._shared_rendezvous = None
593    self._autograph = autograph
594    self._experimental_autograph_options = experimental_autograph_options
595    self._reduce_retracing = reduce_retracing
596    self._jit_compile = jit_compile
597    if experimental_follow_type_hints is None:
598      experimental_follow_type_hints = False
599    self._experimental_follow_type_hints = experimental_follow_type_hints
600    self._created_variables = None  # GUARDED_BY(self._lock)
601    self._stateful_fn = None  # GUARDED_BY(self._lock)
602    self._stateless_fn = None  # GUARDED_BY(self._lock)
603    self._descriptor_cache = weakref.WeakKeyDictionary()
604    self._name = name
605    self._key_for_call_stats = self._get_key_for_call_stats()
606    self._omit_frequent_tracing_warning = False
607    ops._tf_function_api_guage.get_cell().set(True)  # pylint: disable=protected-access
608
609  @property
610  def name(self):
611    return self._name
612
613  def __getstate__(self):
614    """Custom pickling, to omit unpickleable objects."""
615    result = self.__dict__.copy()
616    del result["_lock"]
617    del result["_descriptor_cache"]
618    del result["_key_for_call_stats"]
619    return result
620
621  def __setstate__(self, state):
622    """Restore from pickled state."""
623    self.__dict__ = state
624    self._lock = threading.RLock()
625    self._descriptor_cache = weakref.WeakKeyDictionary()
626    self._key_for_call_stats = self._get_key_for_call_stats()
627
628  def _get_key_for_call_stats(self):
629    """Returns key instance to track call stats and retracings.
630
631    The key instance a best-effort to preserve global consistency.
632    """
633    target_function = self._python_function
634    # `__wrapped__` is a conventional Python attribute in which a higher-order
635    # function keeps the original function it wraps.  We also directly use
636    # this attribute for dealing with a class method.  See
637    # `bound_method_wrapper` in `function.py`.  If we don't use `__wrapped__`,
638    # all class methods will return the same `bound_method_wrapper` instance
639    # from this function.
640    while hasattr(target_function, "__wrapped__"):
641      target_function = target_function.__wrapped__
642
643    if hasattr(target_function, "__func__"):
644      target_function = target_function.__func__
645
646    if hasattr(target_function, "__code__"):
647      return target_function.__code__
648
649    return self._python_function
650
651  def _defun_with_scope(self, scope):
652    """Creates a defun wrapped inside a variable creator scope."""
653
654    weak_wrapped_fn = None
655    compile_with_xla = self._jit_compile
656
657    def wrapped_fn(*args, **kwds):
658      """Wraps `self._python_function` in a variable creator scope."""
659      # We register a variable creator with reduced priority. If an outer
660      # variable creator is just modifying keyword arguments to the variable
661      # constructor, this will work harmoniously. Since the `scope` registered
662      # here actually creates the variable, it taking priority would otherwise
663      # ignore the outer creator.
664      #
665      # If an outer variable creator calls the variable constructor manually,
666      # for example creating a MirroredVariable, then they won't call our
667      # creator. This means we won't be able to trace the initialization graph,
668      # and so variable initializers can't depend on function arguments. This is
669      # better than the alternative, tracing the initialization graph but giving
670      # the user a variable type they didn't want.
671      default_graph = ops.get_default_graph()
672      with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
673        # __wrapped__ allows AutoGraph to swap in a converted function. We give
674        # the function a weak reference to itself to avoid a reference cycle.
675        with OptionalXlaContext(compile_with_xla):
676          out = weak_wrapped_fn().__wrapped__(*args, **kwds)
677        return out
678
679    weak_wrapped_fn = weakref.ref(wrapped_fn)
680
681    return self._defun(tf_decorator.make_decorator(
682        self._python_function,
683        wrapped_fn))
684
685  def _create_implements_attribute(self):
686    """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME."""
687    attributes = {}
688    if isinstance(self._implements, str):
689      # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a
690      # NameAttrList. This is used when apart from the function name being
691      # implemented, a list of attributes is also being specified.
692      # The attributes are specified as key-value pairs in the NameAttrList
693      # of the corresponding AttrValue. The function name will be in the
694      # 'name' field of the NameAttrList. Else, it is just a string
695      # corresponding to the function name.
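      # Illustrative (hypothetical) values for the `experimental_implements`
      # argument handled here: either a plain function name, e.g.
      #     "embedded_matmul"
      # or a NameAttrList in proto text format, e.g.
      #     'name: "embedded_matmul" attr { key: "key" value { s: "value" } }'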
696      try:
697        attr_value = attr_value_pb2.AttrValue()
698        nameattrlist = attr_value_pb2.NameAttrList()
699        _text_format.Merge(self._implements, nameattrlist)
700        attr_value.func.CopyFrom(nameattrlist)
701        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value
702      except (_text_format.ParseError, DecodeError):
703        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
704    return attributes
705
706  def _defun(self, fn):
707    """Returns a defun generated from the input function."""
708    attributes = {}
709
710    if self._implements is not None:
711      attributes = self._create_implements_attribute()
712
713    share = self._shared_rendezvous
714    if share is not None:
715      attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share
716
717    if self._jit_compile is not None:
718      attributes.update(_XlaMustCompile=bool(self._jit_compile))
719      if self._jit_compile:
720        attributes.update(_noinline=True)
721    if not attributes:
722      attributes = None
723    return function_lib.defun_with_attributes(
724        fn,
725        input_signature=self.input_signature,
726        attributes=attributes,
727        autograph=self._autograph,
728        jit_compile=self._jit_compile,
729        reduce_retracing=self._reduce_retracing,
730        experimental_autograph_options=self._experimental_autograph_options,
731        experimental_follow_type_hints=self._experimental_follow_type_hints)
732
733  def _initialize(self, args, kwds, add_initializers_to=None):
734    """Initializes, on the first call.
735
736    Creates two `Function`s, one that will allow creation of variables
737    and one that won't.
738
739    Additionally runs a trace for the `Function` that allows creation
740    of variables.
741
742    Args:
743      args: Arguments to the underlying python callable.
744      kwds: Keyword arguments to the python callable.
745      add_initializers_to: Where to collect variable initializers, if not None.
746    """
747    self.function_spec.validate_input_signature_with_argspec()
748
749    created_variables = []
750    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")
751
752    def variable_capturing_scope(unused_next_creator, **kwds):
753      """Creates UnliftedInitializerVariables and saves references to them."""
754      v = UnliftedInitializerVariable(
755          add_initializers_to=add_initializers_to,
756          lifted_initializer_graph=lifted_initializer_graph, **kwds)
757      created_variables.append(weakref.ref(v))
758      return v
759
760    self._created_variables = created_variables
761    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
762    self._stateful_fn._name = self._name  # pylint: disable=protected-access
763    # Force the definition of the function for these arguments
764    self._lifted_initializer_graph = lifted_initializer_graph
765    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
766    self._concrete_stateful_fn = (
767        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
768            *args, **kwds))
769
770    def invalid_creator_scope(*unused_args, **unused_kwds):
771      """Disables variable creation."""
772      raise ValueError(
773          "tf.function only supports singleton tf.Variables created on the "
774          "first call. Make sure the tf.Variable is only created once or "
775          "created outside tf.function. See "
776          "https://www.tensorflow.org/guide/function#creating_tfvariables "
777          "for more information.")
778
779    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
780    self._stateless_fn._name = self._name  # pylint: disable=protected-access
781
782  def _clone(self, python_function):
783    """Clone the function with different python function."""
784    f = Function(
785        python_function=(self._python_function
786                         if python_function is None else python_function),
787        name=self._name,
788        input_signature=self.input_signature,
789        autograph=self._autograph,
790        jit_compile=self._jit_compile,
791        reduce_retracing=self._reduce_retracing,
792        experimental_implements=self._implements,
793        experimental_autograph_options=self._experimental_autograph_options,
794        experimental_follow_type_hints=self._experimental_follow_type_hints)
795
796    if self._shared_rendezvous:
797      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access
798
799    return f
800
801  def _decorate(self, decorator):
802    """Allows the captured Python function to be decorated in place.
803
804    This method is only safe to call when the Function has not been called by a
805    user. It makes sense to use this method to push a decorator into the
806    function rather than wrapping the function in the decorator.
807
808    We use this in tf.Module to allow user-annotated `tf.functions` to remain as
809    `Function` objects but still automatically enter the Module name_scope
810    when they are evaluated like all other methods.
811
812    Args:
813      decorator: A callable accepting a single argument which is the function
814        to decorate and returning a callable result.
815
816    Raises:
817      ValueError: If the function has already been called.
818    """
819    if self._stateful_fn is not None or self._stateless_fn is not None:
820      raise ValueError(
821          "Functions cannot be decorated after they have been traced.")
822
823    self._python_function = decorator(self._python_function)
824    self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature(
825        self._python_function, self.input_signature)
826
827  # TODO: Remove this private method after updating all its uses
828  # A good moment to do this could be when the experimental label is removed
829  def _get_tracing_count(self):
830    return self.experimental_get_tracing_count()
831
832  def experimental_get_tracing_count(self):
833    """Returns the number of times the function has been traced.
834
835    For more information on when a function is traced and when it is
836    traced multiple times see https://www.tensorflow.org/guide/function.
837    Example:
838
839    >>> @tf.function
840    ... def double(a):
841    ...   return a + a
842    >>> double(tf.constant(1))
843    >>> double(tf.constant(2))
844    >>> double.experimental_get_tracing_count()
845    1
846    >>> double(tf.constant("a"))
847    >>> double.experimental_get_tracing_count()
848    2
849
850
851    The first time experimental_get_tracing_count is called
852    it returns 1, as the function is traced the first
853    time it is called, and the second time the same graph is used
854    since we're calling it with a parameter of the same type.
855
856    The second time experimental_get_tracing_count is called
857    it returns 2, as we called double with a
858    different argument type, and so it was traced again.
859
860    """
861    result = self._stateless_fn.tracing_count if self._stateless_fn else 0
862    result += self._stateful_fn.tracing_count if self._stateful_fn else 0
863    return result
864
865  @property
866  def _run_functions_eagerly(self):
867    return RUN_FUNCTIONS_EAGERLY
868
869  @traceback_utils.filter_traceback
870  def __call__(self, *args, **kwds):
871    # Implements GenericFunction.__call__.
872    if self._run_functions_eagerly:
873      with trace.Trace(self._name, tf_function_call="eager"):
874        return self._python_function(*args, **kwds)
875
876    # Only count the statistics the first time, before initialization took
877    # place.
878    if self._created_variables is None:
879      compiled = bool(self._jit_compile and
880                      not control_flow_util.GraphOrParentsInXlaContext(
881                          ops.get_default_graph()))
882      # For nested functions, increment the counter only when a function with
883      # jit_compile=True is called within a function with jit_compile=False. We
884      # count this special case to correctly record that both jit_compile=True
885      # and jit_compile=False are being used for parts of the outer function.
886      if ops.executing_eagerly_outside_functions() and (
887          context.executing_eagerly() or compiled):
888        # Labels must be strings in Python, so we convert 'compiled' to a string
889        _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)
890
891    tracing_count = self.experimental_get_tracing_count()
892    with trace.Trace(self._name) as tm:
893      # TODO(cheshire): Do not duplicate the XLAControlFlowContext annotation.
894      compiler = "xla" if self._jit_compile else "nonXla"
895
896      with OptionalXlaContext(self._jit_compile):
897        result = self._call(*args, **kwds)
898
899      new_tracing_count = self.experimental_get_tracing_count()
900      without_tracing = (tracing_count == new_tracing_count)
901      execution_mode = "notTraced" if without_tracing else "traced"
902      tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
903                      tracing_count=new_tracing_count)
904
905    if context.executing_eagerly():
906      if without_tracing:
907        _frequent_tracing_detector_manager.called_without_tracing(
908            self._key_for_call_stats)
909      else:
910        _frequent_tracing_detector_manager.called_with_tracing(
911            self._key_for_call_stats, self._python_function,
912            self._omit_frequent_tracing_warning)
913
914    return result
915
916  def _call(self, *args, **kwds):
917    """Calls the graph function."""
918    self._lock.acquire()
919    if ALLOW_DYNAMIC_VARIABLE_CREATION:
920      condition = self._created_variables and self._stateful_fn is None
921    else:
922      condition = self._created_variables
923    if condition:
924      # Release the lock early so that multiple threads can perform the call
925      # in parallel.
926      self._lock.release()
927      # In this case we have created variables on the first call, so we run the
928      # defunned version which is guaranteed to never create variables.
929      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
930    elif self._stateful_fn is not None:
931      # Release the lock early so that multiple threads can perform the call
932      # in parallel.
933      self._lock.release()
934      # In this case we have not created variables on the first call. So we can
935      # run the first trace but we should fail if variables are created.
936      results = self._stateful_fn(*args, **kwds)
937      if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
938        raise ValueError("Creating variables on a non-first call to a function"
939                         " decorated with tf.function.")
940      return results
941
942    try:
943      # This is the first call of __call__, so we have to initialize.
944      initializers = []
945      self._initialize(args, kwds, add_initializers_to=initializers)
946    finally:
947      # At this point we know that the initialization is complete (or less
948      # interestingly an exception was raised) so we no longer need a lock.
949      self._lock.release()
950
951    if self._created_variables:
952      try:
953        # Attempt to initialize variables eagerly and without conds by lifting
954        # out initialization graphs. This is the only initialization strategy
955        # compatible with XLA at the moment.
956        self._initialize_uninitialized_variables(initializers)
957      except lift_to_graph.UnliftableError:
958        pass  # Fall through to cond-based initialization.
959      else:
960        # Lifting succeeded, so variables are initialized and we can run the
961        # stateless function.
962        return self._stateless_fn(*args, **kwds)
963    else:
964      _, _, filtered_flat_args = (
965          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
966              args, kwds))
967      # If we did not create any variables the trace we have is good enough.
968      return self._concrete_stateful_fn._call_flat(
969          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
970
971    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
972      """Conditionally runs initialization if it's needed."""
973      condition = True
974      for v, _ in initializers:
975        condition = math_ops.logical_and(
976            condition, resource_variable_ops.var_is_initialized_op(
977                v.handle))
978      # We want to call stateless_fn if possible because it avoids recomputing
979      # potentially expensive initializers.
980      return control_flow_ops.cond(
981          condition,
982          lambda: self._stateless_fn(*inner_args, **inner_kwds),
983          functools.partial(
984              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
985              inner_filtered_flat_args,
986              captured_inputs=self._concrete_stateful_fn.captured_inputs))
987
988    # We've created variables and are unable to lift the initialization graphs,
989    # so we fall back to initializing with conds while running the function.
990    canon_args, canon_kwds, filtered_flat_args = (
991        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
992            args, kwds))
993    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
994                                            filtered_flat_args)
995
996  def experimental_get_compiler_ir(self, *args, **kwargs):
997    # Implements GenericFunction.experimental_get_compiler_ir
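    # Illustrative usage sketch (hypothetical function `f` and input):
    #
    #   f = tf.function(lambda x: x + 1, jit_compile=True)
    #   print(f.experimental_get_compiler_ir(tf.constant(1.))(stage="hlo"))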
998    context.ensure_initialized()
999    if not self._jit_compile:
1000      raise ValueError("Compiler IR can only be returned for functions marked "
1001                       "with 'jit_compile=True'")
1002
1003    concrete_fn = self.get_concrete_function(*args, **kwargs)
1004    fn_name = concrete_fn.name
1005
1006    # pylint: disable=protected-access
1007    _, _, filtered_flat_args = (
1008        concrete_fn._function_spec.canonicalize_function_inputs(args, kwargs))
1009
1010    def compiler_ir_generator(stage="hlo", device_name=None):
1011      # TODO(cheshire): This is a hack to get the current "preferred" device,
1012      # there is no current API to get it otherwise.
1013      if device_name is None:
1014        device_name = random_ops.random_normal([]).device
1015      res_bytes = context.context().get_compiler_ir(
1016          device_name=device_name,
1017          stage=stage,
1018          function_name=fn_name,
1019          args=list(filtered_flat_args) + concrete_fn.captured_inputs)
1020      if stage in ("hlo_serialized", "optimized_hlo_serialized",
1021                   "optimized_hlo_proto_serialized"):
1022        return res_bytes
1023      else:
1024        return res_bytes.decode("utf-8")
1025
1026    return compiler_ir_generator
1027
1028  @property
1029  def python_function(self):
1030    """The python function wrapped in this tf.function."""
1031    return self._python_function
1032
1033  @property
1034  def input_signature(self):
1035    return self._function_spec.input_signature
1036
1037  @property
1038  def function_spec(self):
1039    return self._function_spec
1040
1041  def pretty_printed_concrete_signatures(self, verbose=True):
1042    joiner = "\n\n" if verbose else "\n"
1043    return joiner.join([
1044        c.pretty_printed_signature(verbose=verbose)
1045        for c in self._list_all_concrete_functions()
1046    ])
1047
1048  def _initialize_uninitialized_variables(self, initializers):
1049    """Make and call a `ConcreteFunction` which initializes variables."""
1050
1051    if not initializers:
1052      return
1053
1054    var_is_initialized = _evaluate_var_is_initialized(
1055        [v for v, _ in initializers])
1056
1057    # Note: using defun here avoids an infinite recursion.
1058    # Most of the code in this function runs eagerly with init_scope, where
1059    # autograph is not necessary.
1060    @function_lib.defun(autograph=False)
1061    def initialize_variables():
1062      op_map = object_identity.ObjectIdentityDictionary()
1063
1064      inits = []
1065      for (v, init), is_initialized in zip(initializers, var_is_initialized):
1066        with ops.init_scope():
1067          if is_initialized:
1068            continue
1069        inits.append(init)
1070
1071      if inits:
1072        op_map = lift_to_graph.lift_to_graph(
1073            inits, ops.get_default_graph(), op_map=op_map)
1074      for (v, init), is_initialized in zip(initializers, var_is_initialized):
1075        with ops.init_scope():
1076          if is_initialized:
1077            continue
1078        v.assign(op_map[init], read_value=False)
1079
1080    with ops.init_scope():
1081      return initialize_variables.get_concrete_function()()
1082
1083  def get_initialization_function(self, *args, **kwargs):
1084    """Returns a `ConcreteFunction` which initializes this function's variables.
1085
1086    Requires that this function hasn't been accessed yet through either calling
1087    it or calling get_concrete_function. Fails if we cannot build an initializer
1088    function which does not depend on the concrete values of the inputs to this
1089    function.
1090
1091    Note that running this function will overwrite any values currently assigned
1092    to variables, for example values restored from a checkpoint.
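
    An illustrative sketch (assuming eager execution and a `tf.function` that
    has not been called yet):

    >>> @tf.function
    ... def f():
    ...   v = tf.Variable(2.)
    ...   return v * 3.
    >>> init_fn = f.get_initialization_function()
    >>> init_fn()  # Runs the variable initializer without calling `f` itself.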
1093
1094    Args:
1095      *args: arguments to the underlying python callable.
1096      **kwargs: keyword arguments to the python callable.
1097
1098    Returns:
1099      A `ConcreteFunction` object which initializes the variables of this
1100      function.
1101
1102    Raises:
1103      RuntimeError: if called after the variables have been initialized.
1104    """
1105    with self._lock:
1106      if self._stateful_fn is not None:
1107        raise RuntimeError(
1108            "get_initialization_function cannot be called after the function "
1109            "has been used")
1110      # Here we trace the function, collect the initializers, and attempt to
1111      # extract them and run them eagerly. Fail only if we cannot do so.
1112      initializers = []
1113      self._initialize(args, kwargs, add_initializers_to=initializers)
1114
1115    # Note: using defun here avoids an infinite recursion.
1116    @function_lib.defun
1117    def initialize_variables():
1118      for v, init in initializers:
1119        v.assign(
1120            lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
1121            read_value=False)
1122
1123    return initialize_variables.get_concrete_function()
1124
1125  def _list_all_concrete_functions(self):
1126    """Returns all concrete functions."""
1127    if self.input_signature is not None:
1128      self.get_concrete_function()
1129    concrete_functions = []
1130    # pylint: disable=protected-access
1131    if self._stateful_fn:
1132      concrete_functions.extend(
1133          self._stateful_fn._list_all_concrete_functions())
1134    if self._stateless_fn:
1135      concrete_functions.extend(
1136          self._stateless_fn._list_all_concrete_functions())
1137    # pylint: enable=protected-access
1138    return concrete_functions
1139
1140  def _list_all_concrete_functions_for_serialization(self):
1141    """Returns all concrete functions for serialization.
1142
1143    Returns:
1144      A list of instances of `ConcreteFunction`.
1145    """
1146    concrete_functions = self._list_all_concrete_functions()
1147    seen_signatures = []
1148    for concrete_function in concrete_functions:
1149      signature = concrete_function.structured_input_signature
1150      flattened = nest.flatten(signature)
1151      if any(
1152          isinstance(arg, func_graph_module.UnknownArgument)
1153          for arg in flattened):
1154        logging.info("Unsupported signature for serialization: %s.", signature)
1155        continue
1156      equal_to_signature = functools.partial(
1157          function_spec_lib.is_same_structure, signature, check_values=True)
1158      if not any(equal_to_signature(s) for s in seen_signatures):
1159        seen_signatures.append(signature)
1160
1161    # Re-create concrete functions for these signatures. Re-creating ensures
1162    # that if the cache key has changed, the function will be traced again.
1163    concrete_functions = []
1164    for args, kwargs in seen_signatures:
1165      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
1166    return concrete_functions
1167
1168  def _trackable_children(self, save_type="checkpoint", **kwargs):
1169    """For implementing `Trackable`."""
1170    if save_type == "checkpoint":
1171      return {}
1172    return {f"trace_{n}": fn for n, fn in
1173            enumerate(self._list_all_concrete_functions_for_serialization())}
1174
1175  def _deserialization_dependencies(self, children):
1176    """Returns concrete functions which must be loaded before this object."""
1177    return children
1178
1179  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
1180    """Returns a `ConcreteFunction` specialized to inputs and execution context.
1181
1182    Unlike `get_concrete_function(...)`, the graph will be deleted when the
1183    returned function is deleted.  It's useful to avoid creating a reference
1184    cycle when you know for sure that the graph will no longer be used without
1185    the returned function.
1186
1187    Args:
1188      *args: inputs to specialize on.
1189      **kwargs: inputs to specialize on.
1190
1191    Returns:
1192      A TensorFlow function which takes exactly one `tf.Tensor` per argument.
1193
1194    Raises:
1195      ValueError: if this object has not yet been called on concrete values.
1196    """
1197    with self._lock:
1198      if self._stateful_fn is None:
1199        initializers = []
1200        self._initialize(args, kwargs, add_initializers_to=initializers)
1201        self._initialize_uninitialized_variables(initializers)
1202
1203    if self._created_variables:
1204      # In this case we have created variables on the first call, so we run the
1205      # defunned version which is guaranteed to never create variables.
1206      return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1207          *args, **kwargs)
1208    elif self._stateful_fn is not None:
1209      # In this case we have not created variables on the first call. So we can
1210      # run the first trace but we should fail if variables are created.
1211      concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1212          *args, **kwargs)
1213      if self._created_variables:
1214        raise ValueError("Creating variables on a non-first call to a function"
1215                         " decorated with tf.function.")
1216      return concrete
1217
1218  def get_concrete_function(self, *args, **kwargs):
1219    # Implements GenericFunction.get_concrete_function.
1220    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1221    concrete._garbage_collector.release()  # pylint: disable=protected-access
1222    return concrete
1223
1224  def __get__(self, instance, owner):
1225    """Makes it possible to defun instance methods."""
1226    del owner
1227    # `instance` here is the instance that this `Function` was accessed through
1228    # e.g., for
1229    #
1230    #   class Foo:
1231    #
1232    #     @function.defun
1233    #     def bar(self):
1234    #       ...
1235    #
1236    #   foo = Foo()
1237    #   foo.bar()  # `foo.bar` is a `Function` instance
1238    #
1239    # then `instance` will be `foo` (and `owner` will be `Foo`).  For composite
1240    # tensors, we can just treat `instance` as a normal parameter.  But for
1241    # other types, we create a new instance of `Function` here to allow
1242    # different instances each to create variables once, thereby allowing
1243    # methods to be decorated with tf.function. Keeps a cache to avoid retracing
1244    # the function every time the descriptor is accessed.
1245    # TODO(mdan): Identify types which can just be parameters more generically.
1246    #
1247    # The check for instance._type_spec=None is used because certain classes
1248    # (including subclasses of tf.linalg.LinearOperator) are subclasses of
1249    # CompositeTensor but do not actually implement the required APIs.
1250    # TODO(b/199278478): Fix those classes, then remove the check for
1251    # `instance._type_spec is not None`.
1252    if (isinstance(instance, composite_tensor.CompositeTensor) and
1253        instance._type_spec is not None):  # pylint: disable=protected-access
1254      return types_lib.MethodType(self, instance)
1255    if instance not in self._descriptor_cache:
1256      if instance is None:
1257        return self
1258      # TODO(mdan): If the CompositeTensor path works, do the same here.
1259      # It's unclear whether we need the tf-decorator, or could just call
1260      # MethodType(self.clone(), instance)
1261      self._descriptor_cache[instance] = (
1262          function_lib.class_method_to_instance_method(self, instance))
1263    return self._descriptor_cache[instance]
1264
1265
1266@tf_export("function")
1267@deprecation.deprecated_args(None,
1268                             "experimental_compile is deprecated, use "
1269                             "jit_compile instead", "experimental_compile")
1270@deprecation.deprecated_args(None,
1271                             "experimental_relax_shapes is deprecated, use "
1272                             "reduce_retracing instead",
1273                             "experimental_relax_shapes")
1274def function(func=None,
1275             input_signature=None,
1276             autograph=True,
1277             jit_compile=None,
1278             reduce_retracing=False,
1279             experimental_implements=None,
1280             experimental_autograph_options=None,
1281             experimental_relax_shapes=None,
1282             experimental_compile=None,
1283             experimental_follow_type_hints=None) -> core.GenericFunction:
1284  """Compiles a function into a callable TensorFlow graph.
1285
1286  `tf.function` constructs a `tf.types.experimental.GenericFunction` that
1287  executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the
1288  TensorFlow operations in `func`. More information on the topic can be found
1289  in [Introduction to Graphs and tf.function]
1290  (https://www.tensorflow.org/guide/intro_to_graphs).
1291
1292  See [Better Performance with tf.function]
1293  (https://www.tensorflow.org/guide/function) for tips on performance and
1294  known limitations.
1295
1296  Example usage:
1297
1298  >>> @tf.function
1299  ... def f(x, y):
1300  ...   return x ** 2 + y
1301  >>> x = tf.constant([2, 3])
1302  >>> y = tf.constant([3, -2])
1303  >>> f(x, y)
1304  <tf.Tensor: ... numpy=array([7, 7], ...)>
1305
1306  The trace-compilation allows non-TensorFlow operations to execute, but only under
1307  special conditions. In general, only TensorFlow operations are guaranteed to
1308  run and create fresh results whenever the `GenericFunction` is called.
1309
1310  ## Features
1311
1312  `func` may use data-dependent Python control flow statements, including `if`,
1313  `for`, `while`, `break`, `continue` and `return`:
1314
1315  >>> @tf.function
1316  ... def f(x):
1317  ...   if tf.reduce_sum(x) > 0:
1318  ...     return x * x
1319  ...   else:
1320  ...     return -x // 2
1321  >>> f(tf.constant(-2))
1322  <tf.Tensor: ... numpy=1>
1323
1324  `func`'s closure may include `tf.Tensor` and `tf.Variable` objects:
1325
1326  >>> @tf.function
1327  ... def f():
1328  ...   return x ** 2 + y
1329  >>> x = tf.constant([-2, -3])
1330  >>> y = tf.Variable([3, -2])
1331  >>> f()
1332  <tf.Tensor: ... numpy=array([7, 7], ...)>
1333
1334  `func` may also use ops with side effects, such as `tf.print`, `tf.Variable`
1335  and others:
1336
1337  >>> v = tf.Variable(1)
1338  >>> @tf.function
1339  ... def f(x):
1340  ...   for i in tf.range(x):
1341  ...     v.assign_add(i)
1342  >>> f(3)
1343  >>> v
1344  <tf.Variable ... numpy=4>
1345
1346  Important: Any Python side-effects (appending to a list, printing with
1347  `print`, etc.) will only happen once, when `func` is traced. To have
1348  side-effects executed each time your `tf.function` runs, they need to be
1349  written as TF ops:
1350
1351  >>> l = []
1352  >>> @tf.function
1353  ... def f(x):
1354  ...   for i in x:
1355  ...     l.append(i + 1)    # Caution! Will only happen once when tracing
1356  >>> f(tf.constant([1, 2, 3]))
1357  >>> l
1358  [<tf.Tensor ...>]
1359
1360  Instead, use TensorFlow collections like `tf.TensorArray`:
1361
1362  >>> @tf.function
1363  ... def f(x):
1364  ...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
1365  ...   for i in range(len(x)):
1366  ...     ta = ta.write(i, x[i] + 1)
1367  ...   return ta.stack()
1368  >>> f(tf.constant([1, 2, 3]))
1369  <tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
1370
1371  ## `tf.function` creates polymorphic callables
1372
1373  Internally, `tf.types.experimental.GenericFunction` may contain multiple
1374  `tf.types.experimental.ConcreteFunction`s, each specialized to arguments with
1375  different data types or shapes, since TensorFlow can perform more
1376  optimizations on graphs of specific shapes, dtypes and values of constant
1377  arguments. `tf.function` treats any pure Python values as opaque objects (best
1378  thought of as compile-time constants), and builds a separate `tf.Graph` for
1379  each set of Python arguments that it encounters.
1380  For more information, see the
1381  [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
1382
1383  Executing a `GenericFunction` will select and execute the appropriate
1384  `ConcreteFunction` based on the argument types and values.
1385
1386  To obtain an individual `ConcreteFunction`, use the
1387  `GenericFunction.get_concrete_function` method. It can be called with the
1388  same arguments as `func` and returns a
1389  `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a
1390  single `tf.Graph`:
1391
1392  >>> @tf.function
1393  ... def f(x):
1394  ...   return x + 1
1395  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
1396  True
1397
1398  `ConcreteFunction`s can be executed just like `GenericFunction`s, but their
1399  input is restricted to the types to which they're specialized.
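
  For instance (a minimal sketch reusing `f` from above), a `ConcreteFunction`
  traced for a float scalar can be called directly with a compatible input:

  >>> cf = f.get_concrete_function(tf.constant(1.0))
  >>> cf(tf.constant(2.0))
  <tf.Tensor: ... numpy=3.0>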
1400
1401  ## Retracing
1402
1403  `ConcreteFunction`s are built (traced) on the fly, as the `GenericFunction` is
1404  called with new TensorFlow types or shapes, or with new Python values as
1405  arguments. When `GenericFunction` builds a new trace, it is said that `func`
1406  is retraced. Retracing is a frequent performance concern for `tf.function` as
1407  it can be considerably slower than executing a graph that's already been
1408  traced. It is ideal to minimize the amount of retracing in your code.
1409
1410  Caution: Passing Python scalars or lists as arguments to `tf.function` will
1411  usually retrace. To avoid this, pass numeric arguments as Tensors whenever
1412  possible:
1413
1414  >>> @tf.function
1415  ... def f(x):
1416  ...   return tf.abs(x)
1417  >>> f1 = f.get_concrete_function(1)
1418  >>> f2 = f.get_concrete_function(2)  # Slow - compiles new graph
1419  >>> f1 is f2
1420  False
1421  >>> f1 = f.get_concrete_function(tf.constant(1))
1422  >>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
1423  >>> f1 is f2
1424  True
1425
1426  Python numerical arguments should only be used when they take few distinct
1427  values, such as hyperparameters like the number of layers in a neural network.
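
  As a minimal sketch (`stack` and `num_layers` below are hypothetical), passing
  the layer count as a Python integer is reasonable because it only takes a
  handful of values, and the loop is simply unrolled into each trace:

  >>> @tf.function
  ... def stack(x, num_layers):
  ...   for _ in range(num_layers):  # Python loop, unrolled at trace time
  ...     x = x + 1
  ...   return x
  >>> stack(tf.constant(1), 2)
  <tf.Tensor: ... numpy=3>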
1428
1429  ## Input signatures
1430
1431  For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for
1432  every unique set of input shapes and datatypes. The example below creates two
1433  separate `ConcreteFunction`s, each specialized to a different shape:
1434
1435  >>> @tf.function
1436  ... def f(x):
1437  ...   return x + 1
1438  >>> vector = tf.constant([1.0, 1.0])
1439  >>> matrix = tf.constant([[3.0]])
1440  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1441  False
1442
1443  An "input signature" can be optionally provided to `tf.function` to control
1444  this process. The input signature specifies the shape and type of each
1445  Tensor argument to the function using a `tf.TensorSpec` object. More general
1446  shapes can be used. This ensures only one `ConcreteFunction` is created, and
1447  restricts the `GenericFunction` to the specified shapes and types. It is
1448  an effective way to limit retracing when Tensors have dynamic shapes.
1449
1450  >>> @tf.function(
1451  ...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
1452  ... def f(x):
1453  ...   return x + 1
1454  >>> vector = tf.constant([1.0, 1.0])
1455  >>> matrix = tf.constant([[3.0]])
1456  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1457  True
1458
1459  ## Variables may only be created once
1460
1461  `tf.function` only allows creating new `tf.Variable` objects when it is called
1462  for the first time:
1463
1464  >>> class MyModule(tf.Module):
1465  ...   def __init__(self):
1466  ...     self.v = None
1467  ...
1468  ...   @tf.function
1469  ...   def __call__(self, x):
1470  ...     if self.v is None:
1471  ...       self.v = tf.Variable(tf.ones_like(x))
1472  ...     return self.v * x
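
  For example, the first call below creates `self.v` from the shape of `x`, and
  subsequent calls reuse it (a minimal usage sketch of the module above):

  >>> m = MyModule()
  >>> m(tf.constant([1.0, 2.0]))
  <tf.Tensor: ... numpy=array([1., 2.], ...)>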
1473
1474  In general, it is recommended to create `tf.Variable`s outside of
1475  `tf.function`.
1476  In simple cases, persisting state across `tf.function` boundaries may be
1477  implemented using a pure functional style in which state is represented by
1478  `tf.Tensor`s passed as arguments and returned as return values.
1479
1480  Contrast the two styles below:
1481
1482  >>> state = tf.Variable(1)
1483  >>> @tf.function
1484  ... def f(x):
1485  ...   state.assign_add(x)
1486  >>> f(tf.constant(2))  # Non-pure functional style
1487  >>> state
1488  <tf.Variable ... numpy=3>
1489
1490  >>> state = tf.constant(1)
1491  >>> @tf.function
1492  ... def f(state, x):
1493  ...   state += x
1494  ...   return state
1495  >>> state = f(state, tf.constant(2))  # Pure functional style
1496  >>> state
1497  <tf.Tensor: ... numpy=3>
1498
1499  ## Python operations execute only once per trace
1500
1501  `func` may contain TensorFlow operations mixed with pure Python operations.
1502  However, when the function is executed, only the TensorFlow operations will
1503  run. The Python operations run only once, at trace time. If TensorFlow
1504  operations depend on results from Python operations, those results will be
1505  frozen into the graph.
1506
1507  >>> @tf.function
1508  ... def f(a, b):
1509  ...   print('this runs at trace time; a is', a, 'and b is', b)
1510  ...   return b
1511  >>> f(1, tf.constant(1))
1512  this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
1513  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1514
1515  >>> f(1, tf.constant(2))
1516  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1517
1518  >>> f(2, tf.constant(1))
1519  this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
1520  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1521
1522  >>> f(2, tf.constant(2))
1523  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1524
1525  ## Using type annotations to improve performance
1526
1527  `experimental_follow_type_hints` can be used along with type annotations to
1528  reduce retracing by automatically casting any Python values to `tf.Tensor`
1529  (something that is not done by default, unless you use input signatures).
1530
1531  >>> @tf.function(experimental_follow_type_hints=True)
1532  ... def f_with_hints(x: tf.Tensor):
1533  ...   print('Tracing')
1534  ...   return x
1535  >>> @tf.function(experimental_follow_type_hints=False)
1536  ... def f_no_hints(x: tf.Tensor):
1537  ...   print('Tracing')
1538  ...   return x
1539  >>> f_no_hints(1)
1540  Tracing
1541  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1542  >>> f_no_hints(2)
1543  Tracing
1544  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1545  >>> f_with_hints(1)
1546  Tracing
1547  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1548  >>> f_with_hints(2)
1549  <tf.Tensor: shape=(), dtype=int32, numpy=2>
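
  ## Using `jit_compile`

  As a minimal sketch of the `jit_compile` argument described below, setting it
  to `True` requests that the traced graph be compiled with
  [XLA](https://tensorflow.org/xla):

  >>> @tf.function(jit_compile=True)
  ... def g(x):
  ...   return x + 1
  >>> g(tf.constant(1))
  <tf.Tensor: shape=(), dtype=int32, numpy=2>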
1550
1551  Args:
1552    func: The function to be compiled. If `func` is None, `tf.function` returns
1553      a decorator that can be invoked with a single argument - `func`. In other
1554      words, `tf.function(input_signature=...)(func)` is equivalent to
1555      `tf.function(func, input_signature=...)`. The former can be used as
1556      a decorator.
1557    input_signature: A possibly nested sequence of `tf.TensorSpec` objects
1558      specifying the shapes and dtypes of the Tensors that will be supplied to
1559      this function. If `None`, a separate function is instantiated for each
1560      inferred input signature.  If input_signature is specified, every input to
1561      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
1562    autograph: Whether autograph should be applied on `func` before tracing a
1563      graph. Data-dependent Python control flow statements require
1564      `autograph=True`. For more information, see the
1565      [tf.function and AutoGraph guide](
1566      https://www.tensorflow.org/guide/function#autograph_transformations).
1567    jit_compile: If `True`, compiles the function using
1568      [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
1569      such as fusion, and attempts to emit more efficient code. This may
1570      drastically improve performance. If set to `True`,
1571      the whole function needs to be compilable by XLA, or an
1572      `errors.InvalidArgumentError` is thrown.
1573      If `None` (default), compiles the function with XLA when running on TPU
1574      and goes through the regular function execution path when running on
1575      other devices.
1576      If `False`, executes the function without XLA compilation.  Set this value
1577      to `False` when directly running a multi-device function on TPUs (e.g. two
1578      TPU cores, one TPU core and its host CPU).
1579      Not all functions are compilable, see a list of
1580      [sharp corners](https://tensorflow.org/xla/known_issues).
1581    reduce_retracing: When True, `tf.function` attempts to reduce the
1582      amount of retracing, for example by using more generic shapes. This
1583      can be controlled for user objects by customizing their associated
1584      `tf.types.experimental.TraceType`.
1585    experimental_implements: If provided, contains a name of a "known" function
1586      this implements. For example "mycompany.my_recurrent_cell".
1587      This is stored as an attribute of the inference function, which can
1588      then be detected when processing the serialized function.
1589      See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)  # pylint: disable=line-too-long
1590      for details.  For an example of utilizing this attribute see this
1591      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
1592      The code above automatically detects and substitutes functions that
1593      implement "embedded_matmul" and allows TFLite to substitute its own
1594      implementations. For instance, a TensorFlow user can use this
1595      attribute to mark that their function also implements
1596      `embedded_matmul` (perhaps more efficiently!) by specifying it using
1597      this parameter:
1598      `@tf.function(experimental_implements="embedded_matmul")`
1599      This can either be specified as just the string name of the function or
1600      a NameAttrList corresponding to a list of key-value attributes associated
1601      with the function name. The name of the function will be in the 'name'
1602      field of the NameAttrList. To define a formal TF op for this function
1603      implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1604      project.
1605    experimental_autograph_options: Optional tuple of
1606      `tf.autograph.experimental.Feature` values.
1607    experimental_relax_shapes: Deprecated. Use `reduce_retracing`
1608      instead.
1609    experimental_compile: Deprecated alias to 'jit_compile'.
1610    experimental_follow_type_hints: When True, the function may use type
1611      annotations from `func` to optimize the tracing performance. For example,
1612      arguments annotated with `tf.Tensor` will automatically be converted
1613      to a Tensor.
1614
1615  Returns:
1616     If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
1617     If `func` is None, returns a decorator that, when invoked with a single
1618     `func` argument, returns a `tf.types.experimental.GenericFunction`.
1619
1620  Raises:
1621     `ValueError` when attempting to use `jit_compile=True`, but XLA support is
1622     not available.
1623  """
1624  if experimental_follow_type_hints is None:
1625    experimental_follow_type_hints = False
1626
1627  if jit_compile is None and JIT_COMPILE_FUNCTIONS:
1628    jit_compile = True
1629
1630  # TODO(b/224808187): Remove after renaming usages.
1631  if experimental_relax_shapes:
1632    reduce_retracing = True
1633
1634  def decorated(inner_function):
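    # Some callables (e.g. `functools.partial` objects) have no `__name__`
    # attribute, so fall back to a generic name.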
1635    try:
1636      name = inner_function.__name__
1637    except AttributeError:
1638      name = "function"
1639    return tf_decorator.make_decorator(
1640        inner_function,
1641        decorator_name="tf.function",
1642        decorator_func=Function(
1643            inner_function,
1644            name,
1645            input_signature=input_signature,
1646            autograph=autograph,
1647            experimental_autograph_options=experimental_autograph_options,
1648            reduce_retracing=reduce_retracing,
1649
1650            # TODO(b/171825496): Update once `experimental_compile` is removed
1651            # entirely in favor of 'jit_compile'.
1652            jit_compile=deprecation.deprecated_argument_lookup(
1653                "jit_compile",
1654                jit_compile,
1655                "experimental_compile",
1656                experimental_compile),
1657            experimental_implements=experimental_implements,
1658            experimental_follow_type_hints=experimental_follow_type_hints))
1659
1660  # This code path is for the `foo = tf.function(foo, ...)` use case
1661  if func is not None:
1662    return decorated(func)
1663
1664  # This code path is for the
1665  #
1666  # @tf.function(...)
1667  # def foo(...):
1668  #    ...
1669  #
1670  # use case, which is equivalent to `foo = tf.function(...)(foo)`
1671  return decorated
1672