xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/func_graph.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FuncGraph and related functionality."""
16
17import collections as py_collections
18import traceback
19from typing import Any, Hashable, Callable, Mapping
20import weakref
21
22import numpy as np
23
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.python.eager import context
26from tensorflow.python.eager import execute
27from tensorflow.python.eager import tape
28from tensorflow.python.eager.graph_only_ops import graph_placeholder
29from tensorflow.python.framework import auto_control_deps
30from tensorflow.python.framework import composite_tensor
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import indexed_slices
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_spec
37from tensorflow.python.framework import tensor_util
38from tensorflow.python.framework import type_spec
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import handle_data_util
41from tensorflow.python.ops import resource_variable_ops
42from tensorflow.python.ops import tensor_array_ops
43from tensorflow.python.ops import variable_scope
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.saved_model import save_context
46from tensorflow.python.types import core
47from tensorflow.python.util import compat
48from tensorflow.python.util import memory
49from tensorflow.python.util import nest
50from tensorflow.python.util import object_identity
51from tensorflow.python.util import tf_contextlib
52from tensorflow.python.util import tf_decorator
53from tensorflow.python.util import tf_inspect
54from tensorflow.python.util.tf_export import tf_export
55
56
# Collections that a FuncGraph built with `collections=None` shares *by
# reference* with its outer graph (via `get_collection_ref`), so both reads and
# writes go to the outer graph's lists. All other collections are copied from
# the outer graph (via `get_collection`) and so are effectively read-only.
ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]

# Upper bound on the element count (`np.prod(shape)`) of an EagerTensor with a
# value dtype for which `FuncGraph.capture` embeds the value as a Const op.
_EAGER_CONST_THRESHOLD = 128
66
67
class UnknownArgument(object):
  """Placeholder for a signature argument whose type is not handled.

  `convert_structure_to_signature` emits a fresh instance of this class for any
  argument it cannot encode as a spec or pass through unchanged.
  """
71
72
def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Leaves are encoded as follows:
    * `Tensor` -> `TensorSpec` with the tensor's shape and dtype. Its name is
      the op's `_user_specified_name` attr when that differs from the
      top-level argument name; otherwise it is the flattened path joined
      with "/".
    * `CompositeTensor` -> its `_type_spec`.
    * `BaseResourceVariable` -> `VariableSpec.from_value(arg)`.
    * Python primitives, `None`, `DType`, `TensorSpec` and `TypeSpec` are
      returned unchanged.
    * Anything else -> a new `UnknownArgument` instance.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of arguments that has equal number of elements as
      `structure` and is used for naming corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from the
      length of `structure`.
  """

  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        # Presumably raised when the op has no such attr; fall back to the
        # path-derived name below.
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      return resource_variable_ops.VariableSpec.from_value(arg)
    if isinstance(arg, (
        int,
        float,
        bool,
        str,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      # Wrap `arg_names` in a 1-tuple: if it were itself a tuple, `%` would
      # treat it as the format-argument tuple and raise TypeError instead of
      # the intended ValueError.
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." %
          (arg_names,))
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)
137
138
class CapturesContainer(object):
  """An ordered registry of capture identifiers and their value functions.

  Each entry maps a hashable capture identifier to a zero-argument callable
  that produces the capture's current value. Entries keep insertion order.
  """

  def __init__(self):
    # Maps capture identifier -> zero-argument function returning its value.
    self._captures = py_collections.OrderedDict()

  def add_capture(self, identifier: Hashable,
                  func: Callable[[], Any]):
    """Registers `func` under `identifier`, replacing any existing entry."""
    self._captures[identifier] = func

  def update(self, container: "CapturesContainer"):
    """Merges in entries from `container` whose identifiers are not present.

    Entries already in `self` take precedence and are never overwritten.

    Args:
      container: Another `CapturesContainer` to merge from.

    Raises:
      TypeError: If `container` is not a `CapturesContainer`.
    """
    # Explicit check rather than `assert`, so validation survives `python -O`.
    if not isinstance(container, CapturesContainer):
      raise TypeError(
          "Expected a CapturesContainer, got %s of type %s." %
          (container, type(container)))
    for key, func in container.captures.items():
      self._captures.setdefault(key, func)

  def get_snapshot(self) -> Mapping[Hashable, Any]:
    """Calls every capture function and returns a dict identifier -> value.

    Functions are called in insertion order.
    """
    return {key: func() for key, func in self.captures.items()}

  @property
  def captures(self) -> Mapping[Hashable, Callable[[], Any]]:
    """The underlying identifier -> function mapping (not a copy)."""
    return self._captures

  def __len__(self):
    return len(self._captures)
169
170
171@tf_export("__internal__.FuncGraph", v1=[])
172class FuncGraph(ops.Graph):
173  """Graph representing a function body.
174
175  Attributes:
176    name: The name of the function.
177    inputs: Placeholder tensors representing the inputs to this function. The
178      tensors are in this FuncGraph. This represents "regular" inputs as well as
179      captured inputs (i.e. the values of self.captures), with the regular
180      inputs coming first.
181    outputs: Tensors that will be returned by this function. The tensors are in
182      this FuncGraph.
183    control_outputs: Operations that must be executed before the function
184      represented by this graph can be said to have been executed.
185    structured_input_signature: A tuple of (args, kwargs), which are both
186      possibly-nested python objects that were received by this function. Note
187      that these structures might contain Python `None`s.
188    structured_outputs: A possibly-nested python object which will be returned
189      by this function. The Tensors in this structure are the same as those of
190      self.outputs. Note that this structure might contain Python `None`s.
191    variables: Variables that should be watched during function execution.
192    outer_graph: The graph this function is defined in. May be another FuncGraph
193      or the global default Graph.
194    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
195      The entries are in the order they were captured.
196    control_captures: Set of external ops on which this graph has a control
197      dependency.
198    seed: The graph-level random seed.
199    capture_by_value: If True, the func graph will capture Variables by value
200      instead of reference.
201  """
202
  def __init__(self,
               name,
               collections=None,
               capture_by_value=None,
               structured_input_signature=None,
               structured_outputs=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start with.
        If not specified (None), the FuncGraph will read (but not write to) the
        outer graph's collections that are not allowlisted, and both read and
        write to the outer graph's collections that are allowlisted. The current
        allowlisted collections are the global variables, the local variables,
        the trainable variables, and the variable-store and variable-scope-store
        collections (see `ALLOWLIST_COLLECTIONS`). Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit from
        outer graphs, and failing that will default to False.
      structured_input_signature: Optional. The structured input signature to
        use for initializing the FuncGraph. See the docstring for FuncGraph for
        more information.
      structured_outputs: Optional. The structured outputs to use for
        initializing the FuncGraph. See the docstring for FuncGraph for more
        information.
    """
    super(FuncGraph, self).__init__()
    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = object_identity.ObjectIdentitySet()
    self.structured_input_signature = structured_input_signature
    self.structured_outputs = structured_outputs
    # Variables are held weakly; see the `variables` property.
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    # The immediate outer graph (the current default graph) is held only
    # weakly; the `outer_graph` property falls back to `_fallback_outer_graph`
    # once it has been collected.
    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    # Walk up past every graph that is itself building a function.
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    self._captures = py_collections.OrderedDict()
    # Maps capture identifier -> lambda function that returns capture values
    # Used to get runtime value to determine if retracing is needed.
    self._capture_func_lib = CapturesContainer()
    # Maps capture identifier -> a container with the same structure as
    # the original side input, except tensors are replaced with placeholders.
    # Used to fetch existing placeholders and prevent repeated creation.
    self._capture_placeholder_lib = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(self.outer_graph,
                                                     FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    # The immediate outer graph (resolved through the `outer_graph` property),
    # not necessarily the outermost fallback graph computed above.
    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf-data user migration from TF1.0 to 2.0] seed_used keeps track of
      # any None op_seed for random_op in the function, in which case we end up
      # using function seed, which could be unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      # Non-allowlisted collections: copied from the outer graph.
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      # Allowlisted collections: shared list objects with the outer graph.
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope
    # (`as_default` replaces this with a fresh list while the scope is active).
    self._scope_exit_callbacks = None
316
317  def __str__(self):
318    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))
319
320  def watch_variable(self, v):
321    """Marks the variable v as accessed while building this graph."""
322    while self is not None and isinstance(self, FuncGraph):
323      self._watched_variables.add(v)
324      self = self.outer_graph
325
  def capture_call_time_value(self,
                              closure,
                              spec,
                              key=None,
                              default_value=None,
                              placeholder=None):
    """Returns a placeholder which at call time has the value closure().

    The `tf.function` supports the notion of captures, that is, it allows Python
    functions to have closure variables, which bind over some value outside the
    function. However, this name binding is "early binding" performed before the
    program is run, i.e.,
    ```
    @tf.function
    def f():
      return x

    x = tf.constant(1)
    f()  # returns 1

    x = tf.constant(2)
    f()  # still returns 1!
    ```
    while in Python, name binding is performed as the program is running.
    ```
    def f():
      return x

    x = 1
    f()  # returns 1

    x = 2
    f()  # returns 2
    ```
    `capture_call_time_value` allows tf.function to mimic late binding as a
    Python function does, by passing in a `closure` callable argument to be
    executed when the tf.function is invoked eagerly.  E.g.
    ```
    @tf.function
    def f():
      return ops.get_default_graph().capture_call_time_value(
          lambda: x, tf.TensorSpec([], tf.int32))

    x = tf.constant(1)
    f()  # returns 1

    x = tf.constant(2)
    f()  # returns 2
    ```
    Note that a `capture_call_time_value` function itself does not work well in
    the saving process (since the tf.function in which it's called is not
    invoked eagerly) unless passed a `default_value` argument. At saving time,
    the `default_value` argument is returned instead.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to `capture_call_time_value`
        with the same key in the same graph will return the same placeholder,
        and the first closure will be used at function call time. (Later calls
        with an existing key ignore their `closure`, `spec` and `placeholder`.)
      default_value: optional value to return in environments that cannot safely
        evaluate closure.
      placeholder: optional. If not None, the graph will take the passed-in
        `placeholder` as the internal capture instead of creating a new one.
        This is useful when loading from a SavedModel.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      TypeError: immediately, if `placeholder` is None and `spec` (with
        composites expanded) contains anything other than `DenseSpec`s.
      ValueError: at function call time, if the return value of closure() is
       not compatible with `spec`.
    """
    if key is None:
      # A fresh sentinel never collides, so every key-less call creates a new
      # deferred capture.
      key = object()
    if key not in self._deferred_captures:

      if placeholder is None:

        def convert_to_placeholder(s):
          if not isinstance(s, tensor_spec.DenseSpec):
            raise TypeError(
                "Expected a nest of `TypeSpec` objects, found %s of type %s." %
                (s, type(s)))
          return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

        placeholder = nest.map_structure(
            convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        """Evaluates `closure` and returns its value flattened to components."""

        # One major case requiring returning a `default_value` is when passing a
        # concrete function to `save`, i.e.
        # serving_fn = serve_fn.get_concrete_function(...)
        # model.save(save_dir, signatures={"serving_default": serving_fn})
        # `serving_fn` has deferred captures added through
        # `capture_call_time_value`. It can't be saved correctly since
        # `wrapped_closure` will end up executing under a default Graph instead
        # of FuncGraph. The user of `capture_call_time_value` also cannot
        # conditionally avoid this call since presence of `save_context` when
        # executing `wrapped_closure` is not known at tracing time of
        # `serving_fn`.
        # Note: `default_value` is returned as-is, without the structure check
        # and flattening applied below.
        if save_context.in_save_context() and default_value is not None:
          return default_value
        # TODO(wxinyi): raise an error if in save context but no default value.

        if not context.executing_eagerly():
          graph = ops.get_default_graph()

          # In the case of control flow, we need to capture the
          # external_captures (deferred or not) of the body_graph (i.e.
          # `WhileBodyFuncGraph) in `cond_graph` (i.e. WhileCondFuncGraph) and
          # create the corresponding placeholders in `cond_graph` so that it
          # expects to receive these as arguments. However, doing so requires
          # having evaluated the call_time_value already (and maybe repeatedly),
          # so we skip adding deferred_captures to the control flow graph but
          # add it to its outer graph.
          while graph.is_control_flow_graph:
            graph = graph.outer_graph

          # Re-register the same key in the (non-control-flow) calling graph so
          # the value becomes a deferred capture there as well.
          with graph.as_default():
            ret_nest = graph.capture_call_time_value(
                closure, spec, key=key, default_value=default_value)
        else:
          ret_nest = closure()

        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(
            lambda s, r: s._to_components(r),
            spec,
            ret_nest,
            expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      wrapped_closure.output_spec = spec
      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]
468
469  def control_dependencies(self, control_inputs):
470    """Handles control dependencies.
471
472    FuncGraph wraps Graph's control_dependencies logic by first filtering out
473    any external tensors / operations and storing them in the graph's
474    control_captures member. Any consumers of this function graph must then
475    decide how to handle the control captures.
476
477    Args:
478      control_inputs: A list of `Operation` or `Tensor` objects which must be
479        executed or computed before running the operations defined in the
480        context.  Can also be `None` to clear the control dependencies.
481
482    Returns:
483     A context manager that specifies control dependencies for all
484     operations constructed within the context.
485
486    Raises:
487      TypeError: If `control_inputs` is not a list of `Operation` or
488        `Tensor` objects.
489    """
490    if control_inputs is None:
491      return super(FuncGraph, self).control_dependencies(control_inputs)
492
493    filtered_control_inputs = []
494    for c in control_inputs:
495      # Check for _UnreadVariable
496      if (isinstance(c, indexed_slices.IndexedSlices) or
497          (hasattr(c, "_handle") and hasattr(c, "op"))):
498        c = c.op
499      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
500      if graph_element is None:
501        graph_element = c
502      if graph_element is not None and getattr(graph_element, "graph",
503                                               None) is not self:
504        self.control_captures.add(graph_element)
505      else:
506        filtered_control_inputs.append(graph_element)
507    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)
508
  def as_default(self):
    """Returns a context manager that makes this graph the default graph.

    On entry, the distribution strategy stack, variable creator stack and graph
    key are taken from the graph that was the default before entering (and, in
    some non-eager cases, the device function stack as well). On exit, any
    callbacks registered in `self._scope_exit_callbacks` run in order, and all of
    the above state is restored to its pre-entry values.
    """
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      # Evaluated before `outer_cm` is entered, so this is the graph that was
      # the default *before* this FuncGraph.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      # However, we need to preserve the outer device stack in the following
      # cases in non eager context:
      # 1. device stack is callable
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      # Note: the creator stack is shared by reference (not copied), unlike the
      # strategy stack above.
      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      # Start a fresh callback list for this scope; the previous list (possibly
      # None) is restored on exit, so nested scopes keep separate callbacks.
      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          # If a callback raises, later callbacks are skipped, but the inner
          # `finally` still restores all saved state.
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key

    return inner_cm()
569
570  @property
571  def outer_graph(self):
572    """The Graph this FuncGraph is nested in.
573
574    Functions may capture Tensors from graphs they are nested in (transitive).
575
576    Returns:
577      A Graph object. Initially set to the current default graph when the
578      FuncGraph was created. If the previous `outer_graph` was deleted because
579      the function that owns it was deleted, `outer_graph` is reset to the
580      outermost default graph active when the FuncGraph was created. This
581      FuncGraph won't have captured anything from the new `outer_graph` (and
582      likely not from the previous setting, since that would have created a
583      strong reference), but it is returned so that FuncGraphs always have a
584      parent.
585    """
586    current = self._weak_outer_graph()
587    if current is None:
588      return self._fallback_outer_graph
589    return current
590
591  @outer_graph.setter
592  def outer_graph(self, new_outer_graph):
593    """Sets `outer_graph` to `new_outer_graph`."""
594    self._weak_outer_graph = weakref.ref(new_outer_graph)
595
596  @property
597  def output_types(self):
598    return [t.dtype for t in self.outputs]
599
600  @property
601  def output_shapes(self):
602    return [t.shape for t in self.outputs]
603
604  @property
605  def trainable_variables(self):
606    """A sequence of trainable variables accessed by this FuncGraph.
607
608    Note that functions keep only weak references to variables. Calling the
609    function after a variable it accesses has been deleted is an error.
610
611    Returns:
612      Sequence of trainable variables for this func graph.
613    """
614    return tuple(v for v in self.variables if v.trainable)
615
616  @property
617  def variables(self):
618    """A sequence of variables accessed by this FuncGraph.
619
620    Note that functions keep only weak references to variables. Calling the
621    function after a variable it accesses has been deleted is an error.
622
623    Returns:
624      Sequence of variables for this func graph.
625    """
626
627    def deref(weak_v):
628      v = weak_v()
629      if v is None:
630        raise AssertionError(
631            "Called a function referencing variables which have been deleted. "
632            "This likely means that function-local variables were created and "
633            "not referenced elsewhere in the program. This is generally a "
634            "mistake; consider storing variables in an object attribute on "
635            "first call.")
636      return v
637
638    return tuple(deref(v) for v in self._weak_variables)
639
640  @variables.setter
641  def variables(self, var_list):
642    self._weak_variables = [weakref.ref(v) for v in var_list]
643
  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Runs a read op outside this graph and captures its result by value.

    Inputs that are internal placeholders of existing captures are swapped back
    for their external counterparts. `op_type` is then executed under
    `ops.init_scope()`, eagerly or in the outer default graph. Only its first
    output is kept, and that output is captured into this graph.

    Args:
      op_type: Name of the op to run (callers pass "ReadVariableOp" or
        "ResourceGather").
      inputs: Input tensors as seen from inside this graph.
      dtypes: Output dtypes; only used on the graph (non-eager) path.
      input_types: Forwarded to `_create_op_internal` on the graph path.
      name: Forwarded to `_create_op_internal` on the graph path.
      attrs: Op attrs. On the eager path only `attrs["dtype"]` is used.
      op_def: Forwarded to `_create_op_internal` on the graph path.
      compute_device: Forwarded to `_create_op_internal` on the graph path.

    Returns:
      The op of the tensor returned by `self.capture` for the computed value.
    """
    # When capturing by value, do the read outside
    # Assumes `self.captures` iterates (external, internal) pairs -- it is
    # defined outside this view; confirm. Keys are `id(internal)` so lookup is
    # by identity.
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        # NOTE(review): only the "dtype" attr is forwarded here. Ops with other
        # required attrs (e.g. ResourceGather's index type) may fail on this
        # path -- confirm.
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type, uncaptured_inputs, dtypes, input_types, name, attrs,
            op_def, compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op
670
671  def _create_op_internal(
672      self,
673      op_type,
674      inputs,
675      dtypes=None,  # pylint: disable=redefined-outer-name
676      input_types=None,
677      name=None,
678      attrs=None,
679      op_def=None,
680      compute_device=True):
681    """Like Graph.create_op, except handles external input tensors.
682
683    This overload adds functionality to create_op to "capture" any external
684    input tensors, i.e. tensors from the eager context or outer function graphs
685    if this is a nested function. See `capture` for more information.
686
687    Args:
688      op_type: The `Operation` type to create. This corresponds to the
689        `OpDef.name` field for the proto that defines the operation.
690      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
691      dtypes: (Optional) A list of `DType` objects that will be the types of the
692        tensors that the operation produces.
693      input_types: (Optional.) A list of `DType`s that will be the types of the
694        tensors that the operation consumes. By default, uses the base `DType`
695        of each input in `inputs`. Operations that expect reference-typed inputs
696        must specify `input_types` explicitly.
697      name: (Optional.) A string name for the operation. If not specified, a
698        name is generated based on `op_type`.
699      attrs: (Optional.) A dictionary where the key is the attribute name (a
700        string) and the value is the respective `attr` attribute of the
701        `NodeDef` proto that will represent the operation (an `AttrValue`
702        proto).
703      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
704        the operation will have.
705      compute_device: (Optional.) If True, device functions will be executed to
706        compute the device property of the Operation.
707
708    Returns:
709      An `Operation` object.
710    """
711    if self.capture_by_value and op_type in [
712        "ReadVariableOp", "ResourceGather"
713    ]:
714      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
715                                    attrs, op_def, compute_device)
716
717    # This capturing logic interacts poorly with control flow contexts which
718    # want to replace inputs of ops far too late in the process. This can lead
719    # the context to get confused and try to create an Enter for an Enter. We
720    # can detect this here and skip the additional Enter which can confuse loop
721    # validation logic.
722    if op_type == "Enter" and inputs[0].op.type == "Enter":
723      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
724        return inputs[0].op
725    # Calling AddValue on the control flow contexts to force creation of the
726    # backward accumulators in the original graph before we create placeholders
727    # to capture the inputs.
728    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
729    # Use a different list to avoid modifying the original inputs list.
730    captured_inputs = []
731    for inp in inputs:
732      # TPU Estimator defines a control flow context with no AddValue method.
733      if ctxt is not None and hasattr(ctxt, "AddValue"):
734        inp = ctxt.AddValue(inp)
735      inp = self.capture(inp)
736      captured_inputs.append(inp)
737    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
738        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
739        compute_device)
740
741  def capture(self, tensor, name=None, shape=None):
742    """Captures `tensor` if it's external to this graph.
743
744    If `tensor` is from a different graph, returns a placeholder for it.
745    `tensor` and the placeholder will appear in self.captures, and the
746    placeholder will appear in self.inputs.  Multiple calls to this method with
747    the same `tensor` argument will return the same placeholder. If `tensor` is
748    from this graph, returns `tensor`.
749
750    Args:
751      tensor: Tensor. May be from this FuncGraph or a different graph.
752      name: Optional name if a placeholder is created.
753      shape: Optional shape if a placeholder is created.
754
755    Returns:
756      Tensor from this FuncGraph.
757
758    Raises:
759      InaccessibleTensorError: if any tensors are accessed in a manner that
760      bypasses the mechanisms required for the data dependencies to be correctly
761      wired.
762    """
763    if isinstance(tensor, ops.EagerTensor):
764      if name is None:
765        name = str(ops.uid())
766
767      # Small EagerTensors are captured with Const ops
768      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
769          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
770        return self.capture_eager_tensor(tensor, name)
771
772      # Large EagerTensors and resources are captured with Placeholder ops
773      return self._capture_helper(tensor, name, shape)
774    if tensor.graph is not self:
775      if name is None:
776        name = tensor.op.name
777      inner_graph = tensor.graph
778      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
779        if inner_graph is self:
780          try:
781            tb = tensor.op.traceback
782          except AttributeError:
783            tensor_traceback = "<unknown>"
784          else:
785            tensor_traceback_list = []
786            for frame in traceback.format_list(tb.get_user_frames()):
787              tensor_traceback_list.extend(
788                  [f"  {line}" for line in frame.split("\n") if line.strip()])
789            tensor_traceback = "\n".join(tensor_traceback_list)
790          # Keep in sync with tfe_wrapper.cc.
791          # TODO(b/200991648): Unify those two paths.
792          raise errors.InaccessibleTensorError(
793              f"{tensor!r} is out of scope and cannot be used here. Use return "
794              "values, explicit Python locals or TensorFlow collections to "
795              "access it.\n"
796              "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values "
797              "for more information.\n\n"
798              f"{tensor!r} was defined here:\n{tensor_traceback}\n\n"
799              f"The tensor {tensor!r} cannot be accessed from {self}, because "
800              f"it was defined in {tensor.graph}, which is out of scope.")
801        inner_graph = inner_graph.outer_graph
802      return self._capture_helper(tensor, name)
803    return tensor
804
805  def _capture_helper(self, tensor, name, shape=None):
806    capture = self._captures.get(id(tensor))
807    if capture is None:
808      placeholder = _create_substitute_placeholder(
809          tensor, name=name, dtype=tensor.dtype, shape=shape)
810      # Record the composite device as an attribute to the placeholder.
811      # This attribute would be propogated into the arg_attr of the FunctionDef.
812      # Currently, a packed eager tensor is always placed on a CompositeDevice.
813      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
814        placeholder.op._set_attr(  # pylint: disable=protected-access
815            "_composite_device",
816            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
817      self.add_capture(tensor, placeholder)
818    else:
819      placeholder = capture[1]
820    tape.record_operation(
821        "captured_value", [placeholder], [tensor],
822        backward_function=lambda x: [x],
823        forward_function=lambda x: [x])
824    return placeholder
825
826  def _experimental_capture_side_input_by_ref(self, identifier: Hashable,
827                                              func: Callable[[], Any]) ->...:
828    """Implement capturing side input by reference for tf.function.
829
830    Args:
831      identifier: A hashable object as the key for the capture.
832      func: A Python function that takes no arguments and returns the value of
833        side input. The function is evaluated at function call time.
834
835    Returns:
836      A nested structure with the same structure as the side input. Tensors
837        are replaced with placehoders, and non-tensors remain the same.
838
839    """
840    # Support manual capture for inner nested tf.function is not possible at the
841    # moment. Inner here means any tf.function wrapped by another tf.function.
842    # Usage inside the outer most tf.function only is fine.
843    # The infeasibility is due to it's impossible to determine the
844    # definition scope of the captured side input. This info is needed when
845    # propagating inner tf.function captures to outer tf.function.
846    if isinstance(self.outer_graph, FuncGraph):
847      raise NotImplementedError(
848          ("Manual side input usage for inner nested tf.function is not "
849           f"supported. Got side input: {identifier}."))
850
851    # Prevent repeated captures
852    if identifier in self._capture_placeholder_lib:
853      return self._capture_placeholder_lib[identifier]
854
855    nested_placeholder = self._maybe_create_capture_placeholder(func)
856    self._capture_func_lib.add_capture(identifier, func)
857    self._capture_placeholder_lib[identifier] = nested_placeholder
858    return nested_placeholder
859
860  def _maybe_create_capture_placeholder(self, func: Callable[[], Any]) -> ...:
861    """Create placeholder if the input is tensor."""
862    values_nest = func()
863
864    if context.executing_eagerly():
865      return values_nest
866
867    values_flat = nest.flatten(values_nest)
868    # Return values in flat format. It consists of placeholders and non-tensor
869    # values.
870    return_flat = []
871    tensor_spec_flat = []
872    # Create return_flat and replace tensors with None. Later, each None is
873    # replaced again by corresponding placeholders
874    for value in values_flat:
875      if isinstance(value, core.Tensor):
876        return_flat.append(None)
877        tensor_spec_flat.append(type_spec.type_spec_from_value(value))
878      elif isinstance(value, set) or isinstance(value, frozenset):
879        raise NotImplementedError(
880            (f"Side input returned by '{tf_inspect.getsource(func).strip()}' "
881             f"has element of {type(value)} type, which is currently not "
882             "supported by tf.function."))
883      else:
884        return_flat.append(value)
885    if tensor_spec_flat:
886
887      def tensor_func():
888        values = nest.flatten(func())
889        return [value for value in values if isinstance(value, core.Tensor)]
890
891      placeholder_flat = self.capture_call_time_value(
892          tensor_func, tensor_spec_flat)
893      # replace None that represents tensors with placehoders
894      flat_ptr = 0
895      for idx, item in enumerate(return_flat):
896        if item is None:
897          return_flat[idx] = placeholder_flat[flat_ptr]
898          flat_ptr += 1
899    return_nest = nest.pack_sequence_as(values_nest, return_flat)
900    return return_nest
901
902  @property
903  def captures(self):
904    """Order list of tuples containing external and internal captures."""
905    return self._captures.values()
906
907  def add_capture(self, tensor, placeholder):
908    """Capture a specific tensor and utilize the provided placeholder.
909
910    Args:
911      tensor: Tensor to captures.
912      placeholder: Provided placeholder for the tensor.
913    """
914    self._captures[id(tensor)] = (tensor, placeholder)
915    self.inputs.append(placeholder)
916
917  def replace_capture(self, tensor, placeholder):
918    """Replace already existing capture."""
919    self._captures[id(tensor)] = (tensor, placeholder)
920
921  def replace_capture_with_deferred_capture(self,
922                                            tensor,
923                                            closure,
924                                            spec,
925                                            placeholder,
926                                            default_value=None):
927    """Replaces existing capture `tensor` with a deferred capture `closure`.
928
929    Caution: It is the caller's responsibility to make sure that, after calling
930    this function, the TypeSpec of the `inputs` (i.e. internal placeholders) and
931    the `_captured_inputs` (i.e. external captures) of a concrete function that
932    wraps this function graph are still compatible. Thus user should pairing
933    usage of this function with `ConcreteFunction.set_external_captures` to make
934    sure the order still matches. For example,
935    ```
936    # concrete_fn._captured_inputs == [tensor1, tensor2, tensor3]
937    # concrete_fn.inputs == [placeholder1, placeholder2, placeholder3]
938    # replace external capture `tensor2` with a deferred_capture, i.e., a
939    # closure, `closure2`
940    concrete_fn.graph.replace_capture_with_deferred_capture(tensor2,
941                                                            closure2,
942                                                            placeholder2,
943                                                            some_spec,
944                                                            some_default)
945    concrete_fn.set_external_captures([tensor1, closure2, tensor3])
946    ```
947
948    Args:
949      tensor: Tensor already captured.
950      closure: function which takes no arguments, to be evaluated at function
951        call time, returning a nest of tensors compatible with `spec`.
952      spec: nest of TypeSpec for the value to capture.
953      placeholder: the internal placeholder corresponding to the captured
954        `tensor`.
955      default_value: optional value to use in environments that cannot safely
956        evaluate closure.
957    """
958    if id(tensor) in self._captures:
959      self.pop_capture(tensor)
960    self.capture_call_time_value(
961        closure,
962        spec,
963        key=id(tensor),
964        default_value=default_value,
965        placeholder=placeholder)
966
967  def reset_captures(self, capture_list):
968    """Set the captures with the provided list of captures & placeholder."""
969    self._captures = py_collections.OrderedDict()
970    for tensor, placeholder in capture_list:
971      self._captures[id(tensor)] = (tensor, placeholder)
972
973  def pop_capture(self, tensor):
974    """Remove the capture and return the generated placeholder."""
975    capture = self._captures.pop(id(tensor), None)
976    if capture is None:
977      return None
978
979    return capture[1]
980
981  def clear_captures(self):
982    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
983    # Clearing captures using clear() leaves some cycles around.
984    while self._captures:
985      self._captures.popitem()
986    memory.dismantle_ordered_dict(self._captures)
987    while self._deferred_captures:
988      self._deferred_captures.popitem()
989    memory.dismantle_ordered_dict(self._deferred_captures)
990
991  def capture_distributed_variable(self, variable, placeholder):
992    """Add given distributed variable to captures with given placeholder."""
993    self._captures[id(variable)] = (variable, placeholder)
994    tape.record_operation(
995        "captured_value", [placeholder], [variable],
996        backward_function=lambda x: [x],
997        forward_function=lambda x: [x])
998
999  def capture_eager_tensor(self, tensor, name):
1000    capture = self._captures.get(id(tensor))
1001    if capture is None:
1002      with ops.control_dependencies(None):
1003        constant_value = tensor_util.constant_value(tensor)
1004        if constant_value is None:
1005          # Some eager tensors, e.g. parallel tensors, are not convertible to a
1006          # single constant. We'll use a placeholder for this case.
1007          return self._capture_helper(tensor, name)
1008        graph_const = constant_op.constant(
1009            constant_value, dtype=tensor.dtype, shape=tensor.shape, name=name)
1010      self.add_capture(tensor, graph_const)
1011    else:
1012      graph_const = capture[1]
1013    tape.record_operation(
1014        "captured_value", [graph_const], [tensor],
1015        backward_function=lambda x: [x],
1016        forward_function=lambda x: [x])
1017    return graph_const
1018
1019  def captured(self, tensor):
1020    """Check if the specified tensor has been captured."""
1021    return id(tensor) in self._captures
1022
1023  @property
1024  def external_captures(self):
1025    """External tensors captured by this function."""
1026    return [c[0] for c in self._captures.values()]
1027
1028  @property
1029  def internal_captures(self):
1030    """Placeholders in this function corresponding captured tensors."""
1031    return [c[1] for c in self._captures.values()]
1032
1033  @property
1034  def deferred_external_captures(self):
1035    """Ordered nest of tensors whose placeholders will be fed at call time."""
1036    return [c[0] for c in self._deferred_captures.values()]
1037
1038  @property
1039  def deferred_internal_captures(self):
1040    """List of nest of placeholders which at call time will be fed."""
1041    return [c[1] for c in self._deferred_captures.values()]
1042
1043  @property
1044  def variable_captures(self):
1045    """Map of python object ids of variables to variables which are captured."""
1046    return {
1047        id(self._captures[id(v)][1]): v
1048        for v in self.variables
1049        if id(v) in self._captures
1050    }
1051
1052  def mark_as_unsaveable(self, error_message):
1053    """Marks this FuncGraph as unsaveable.
1054
1055    Any attempts to export this FuncGraph will raise an error with the specified
1056    message.
1057
1058    Args:
1059      error_message: List or string containing the error message to be raised
1060        when saving this FuncGraph to SavedModel.
1061    """
1062    self._saveable = False
1063    if isinstance(error_message, str):
1064      error_message = [error_message]
1065    self._saving_errors.update(error_message)
1066
1067  @property
1068  def saveable(self):
1069    """Returns whether this FuncGraph is saveable."""
1070    return self._saveable
1071
1072  @property
1073  def saving_errors(self):
1074    """Returns set of errors preventing this FuncGraph from being saved."""
1075    return self._saving_errors
1076
1077  def _add_scope_exit_callback(self, fn):
1078    """Add a function to call when this graph exits the default scope."""
1079    if not callable(fn):
1080      raise TypeError("fn is not callable: {}".format(fn))
1081    if self._scope_exit_callbacks is None:
1082      raise RuntimeError(
1083          "Attempting to add a scope exit callback, but the default graph is "
1084          "not the context scope graph.  Did you forget to call "
1085          "'with graph.as_default(): ...'?")
1086    self._scope_exit_callbacks.append(fn)
1087
1088
1089# TODO(mdan): Too many threaded arguments. Accept an ACD ctx manager instead.
1090def func_graph_from_py_func(name,
1091                            python_func,
1092                            args,
1093                            kwargs,
1094                            signature=None,
1095                            func_graph=None,
1096                            autograph=False,
1097                            autograph_options=None,
1098                            add_control_dependencies=True,
1099                            arg_names=None,
1100                            op_return_value=None,
1101                            collections=None,
1102                            capture_by_value=None,
1103                            acd_record_initial_resource_uses=False):
1104  """Returns a `FuncGraph` generated from `python_func`.
1105
1106  Args:
1107    name: an identifier for the function.
1108    python_func: the Python function to trace.
1109    args: the positional args with which the Python function should be called;
1110      ignored if a signature is provided.
1111    kwargs: the keyword args with which the Python function should be called;
1112      ignored if a signature is provided.
1113    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
1114      and dtypes of the arguments. When a signature is provided, `args` and
1115      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
1116      to `signature`. If `None`, the shapes and dtypes are inferred from the
1117      inputs.
1118    func_graph: Optional. An instance of FuncGraph. If provided, we will use
1119      this graph else a new one is built and returned.
1120    autograph: whether to use autograph to compile `python_func`.
1121      See https://www.tensorflow.org/guide/autograph for more information.
1122    autograph_options: additional knobs to control when `autograph=True`.
1123      See https://www.tensorflow.org/guide/autograph for more information.
1124    add_control_dependencies: If True, automatically adds control dependencies
1125      to ensure program order matches execution order and stateful ops always
1126      execute.
1127    arg_names: Optional list of argument names, used to give input placeholders
1128      recognizable names.
1129    op_return_value: Optional. A Tensor. If set and `python_func` returns
1130      Operations, those return values will be replaced with this value. If not
1131      set, returning an Operation triggers an error.
1132    collections: a dictionary of collections this FuncGraph should start with.
1133      If not specified (None), the FuncGraph will read (but not write to) the
1134      outer graph's collections that are not allowlisted, and both read and
1135      write to the outer graph's collections that are allowlisted. The current
1136      allowlisted collections are the global variables, the local variables, and
1137      the trainable variables. Defaults to None.
1138    capture_by_value: An optional boolean. If True, the func graph will capture
1139      Variables by value instead of reference. By default inherit from outer
1140      graphs, and failing that will default to False.
1141    acd_record_initial_resource_uses: If `True` and `add_control_dependencies`
1142      is enabled, the results (those marked with
1143      AutomaticControlDependencies.mark_result) will be annotated with a private
1144      attribute, "_res_first_used_by", which points to the first nodes which
1145      used the any of the resources that the result op is using.
1146
1147  Returns:
1148    A FuncGraph.
1149
1150  Raises:
1151    TypeError: If any of `python_func`'s return values is neither `None`, a
1152      `Tensor` or a `tf.experimental.ExtensionType`.
1153  """
1154  if op_return_value is not None:
1155    assert isinstance(op_return_value, ops.Tensor), op_return_value
1156  if func_graph is None:
1157    func_graph = FuncGraph(
1158        name, collections=collections, capture_by_value=capture_by_value)
1159  assert isinstance(func_graph, FuncGraph)
1160  if add_control_dependencies:
1161    deps_control_manager = auto_control_deps.AutomaticControlDependencies(
1162        record_initial_resource_uses=acd_record_initial_resource_uses)
1163  else:
1164    deps_control_manager = ops.NullContextmanager()
1165
1166  with func_graph.as_default(), deps_control_manager as deps_ctx:
1167    current_scope = variable_scope.get_variable_scope()
1168    default_use_resource = current_scope.use_resource
1169    current_scope.set_use_resource(True)
1170
1171    if signature is not None:
1172      args = signature
1173      kwargs = {}
1174    func_args = _get_defun_inputs_from_args(args, arg_names)
1175    func_kwargs = _get_defun_inputs_from_kwargs(kwargs)
1176
1177    # Convert all Tensors into TensorSpecs before saving the structured inputs.
1178    # If storing pure concrete functions that are not called through polymorphic
1179    # functions, we don't have access to FunctionSpec, so we need to call the
1180    # TensorSpecs by their `arg_names` for later binding.
1181    func_graph.structured_input_signature = (convert_structure_to_signature(
1182        func_args, arg_names), convert_structure_to_signature(func_kwargs))
1183
1184    flat_func_args = nest.flatten(func_args, expand_composites=True)
1185    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
1186    # Temporarily set inputs to allow graph building code to inspect
1187    # them. Reassigned below.
1188    func_graph.inputs = [
1189        arg for arg in flat_func_args + flat_func_kwargs
1190        if isinstance(arg, ops.Tensor)
1191    ]
1192
1193    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
1194    # Variables to help check whether mutation happens in calling the function
1195    # Copy the recursive list, tuple and map structure, but not base objects
1196    func_args_before = nest.pack_sequence_as(
1197        func_args, flat_func_args, expand_composites=True)
1198    func_kwargs_before = nest.pack_sequence_as(
1199        func_kwargs, flat_func_kwargs, expand_composites=True)
1200
1201    def convert(x):
1202      """Converts a function output to a Tensor."""
1203      if x is None:
1204        return None
1205      if op_return_value is not None and isinstance(x, ops.Operation):
1206        # TODO(b/79881896): we currently can't capture external control deps, so
1207        # this won't work if x needs to be captured (i.e. if python_func returns
1208        # captured Operations).
1209        with ops.control_dependencies([x]):
1210          x = array_ops.identity(op_return_value)
1211      elif not isinstance(x, tensor_array_ops.TensorArray):
1212        try:
1213          x = ops.convert_to_tensor_or_composite(x)
1214        except (ValueError, TypeError):
1215          raise TypeError(
1216              "To be compatible with tf.function, Python functions "
1217              "must return zero or more Tensors or ExtensionTypes or None "
1218              f"values; in compilation of {str(python_func)}, found return "
1219              f"value of type {type(x).__name__}, which is not a Tensor or "
1220              "ExtensionType.")
1221      if add_control_dependencies:
1222        x = deps_ctx.mark_as_return(x)
1223      return x
1224
1225    try:
1226      if autograph:
1227        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
1228        _, original_func = tf_decorator.unwrap(python_func)
1229
1230        def autograph_handler(*args, **kwargs):
1231          """Calls a converted version of original_func."""
1232          # TODO(mdan): Push this block higher in tf.function's call stack.
1233          try:
1234            return autograph.converted_call(
1235                original_func,
1236                args,
1237                kwargs,
1238                options=autograph.ConversionOptions(
1239                    recursive=True,
1240                    optional_features=autograph_options,
1241                    user_requested=True,
1242                ))
1243          except Exception as e:  # pylint:disable=broad-except
1244            if hasattr(e, "ag_error_metadata"):
1245              raise e.ag_error_metadata.to_exception(e)
1246            else:
1247              raise
1248
1249        # Wrapping around a decorator allows checks like tf_inspect.getargspec
1250        # to be accurate.
1251        converted_func = tf_decorator.make_decorator(original_func,
1252                                                     autograph_handler)
1253        python_func = tf_decorator.rewrap(python_func, original_func,
1254                                          converted_func)
1255
1256      else:
1257        _, original_func = tf_decorator.unwrap(python_func)
1258
1259      func_outputs = python_func(*func_args, **func_kwargs)
1260
1261      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
1262      # TensorArrays and `None`s.
1263      func_outputs = nest.map_structure(
1264          convert, func_outputs, expand_composites=True)
1265
1266      check_func_mutation(func_args_before, func_kwargs_before, func_args,
1267                          func_kwargs, original_func)
1268    finally:
1269      current_scope.set_use_resource(default_use_resource)
1270
1271    # Variables in `func_args`, `func_kwargs` should be explicit inputs
1272    # to the function, not captured inputs.
1273    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
1274    arg_variables = object_identity.ObjectIdentitySet()
1275    inputs = []
1276    for arg in (nest.flatten(func_args, expand_composites=True) +
1277                nest.flatten(func_kwargs, expand_composites=True)):
1278      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
1279        # Even if an argument variable was not used in the function, we've
1280        # already manually captured the resource Tensor when creating argument
1281        # placeholders.
1282        resource_placeholder = func_graph.pop_capture(arg.handle)
1283        if resource_placeholder is None:
1284          continue
1285        arg_variables.add(arg)
1286        inputs.append(resource_placeholder)
1287      elif isinstance(arg, ops.Tensor):
1288        inputs.append(arg)
1289    variables = [v for v in graph_variables if v not in arg_variables]
1290    func_graph.inputs = (
1291        inputs + func_graph.internal_captures + nest.flatten(
1292            func_graph.deferred_internal_captures, expand_composites=True))
1293    func_graph.structured_outputs = func_outputs
1294    # Returning a closed-over tensor does not trigger convert_to_tensor.
1295    func_graph.outputs.extend(
1296        func_graph.capture(x)
1297        for x in flatten(func_graph.structured_outputs)
1298        if x is not None)
1299
1300    func_graph.variables = variables
1301
1302  if add_control_dependencies:
1303    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
1304    func_graph.collective_manager_ids_used = (
1305        deps_control_manager.collective_manager_ids_used)
1306
1307  return func_graph
1308
1309
1310def maybe_captured(tensor):
1311  """If t is a captured value placeholder, returns the original captured value.
1312
1313  Args:
1314    tensor: Tensor.
1315
1316  Returns:
1317    A tensor, potentially from a different Graph/FuncGraph.
1318  """
1319  if (not isinstance(tensor, ops.EagerTensor) and
1320      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
1321    for input_t, placeholder_t in tensor.op.graph.captures:
1322      if tensor == placeholder_t:
1323        return maybe_captured(input_t)
1324  # pylint: enable=protected-access
1325  return tensor
1326
1327
1328def device_stack_has_callable(device_stack):
1329  """Checks whether a device stack contains a callable."""
1330  return any(
1331      callable(spec._device_name_or_function)  # pylint: disable=protected-access
1332      for spec in device_stack.peek_objs())
1333
1334
1335def has_mutation(n1, n2):
1336  """Returns true if n1 and n2 are different (using `is` to compare leaves)."""
1337  try:
1338    nest.assert_same_structure(n1, n2, expand_composites=True)
1339  except ValueError:
1340    return True
1341
1342  for arg1, arg2 in zip(
1343      nest.flatten(n1, expand_composites=True),
1344      nest.flatten(n2, expand_composites=True)):
1345    if arg1 is not arg2:
1346      return True
1347
1348  return False
1349
1350
1351def check_func_mutation(old_args, old_kwargs, new_args, new_kwargs, func):
1352  """Checks that the arguments to a function are not modified."""
1353  if not has_mutation((old_args, old_kwargs), (new_args, new_kwargs)):
1354    return
1355
1356  # Mutation detected; construct a useful error message.
1357  func_name = getattr(func, "__qualname__", getattr(func, "__name__", func))
1358  signature = tf_inspect.signature(func)
1359  try:
1360    old_bound = signature.bind(*old_args, **old_kwargs).arguments
1361    new_bound = signature.bind(*new_args, **new_kwargs).arguments
1362  except TypeError as e:
1363    # This occurs when the function is called with the (deprecated)
1364    # "flat signature".  See ConcreteFunction._call_with_flat_signature.  In
1365    # this case, we can't report which arguments were modified.
1366    raise ValueError(
1367        f"{func_name}{signature} should not modify its Python input "
1368        f"arguments. Check if it modifies any lists or dicts passed as "
1369        f"arguments. Modifying a copy is allowed.") from e
1370
1371  assert set(old_bound) == set(new_bound)
1372  modified_args = [
1373      arg_name for arg_name in new_bound
1374      if has_mutation(old_bound[arg_name], new_bound[arg_name])
1375  ]
1376  changes = ", ".join(modified_args)
1377  raise ValueError(f"{func_name}{signature} should not modify its Python "
1378                   f"input arguments. Modifying a copy is allowed. The "
1379                   f"following parameter(s) were modified: {changes}")
1380
1381
1382# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
1383def flatten(sequence):
1384  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.
1385
1386  Args:
1387    sequence: A nested structure of Tensors, CompositeTensors, and TensorArrays.
1388
1389  Returns:
1390    A list of tensors.
1391  """
1392  flat_sequence = nest.flatten(sequence, expand_composites=True)
1393  return [
1394      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
1395      for item in flat_sequence
1396  ]
1397
1398
1399# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
1400def pack_sequence_as(structure, flat_sequence):
1401  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.
1402
1403  Args:
1404    structure: The structure to pack into. May contain Tensors,
1405      CompositeTensors, or TensorArrays.
1406    flat_sequence: An iterable containing tensors.
1407
1408  Returns:
1409    A nested structure.
1410
1411  Raises:
1412    AssertionError if `structure` and `flat_sequence` are not compatible.
1413  """
1414  flat_sequence = list(flat_sequence)
1415  flattened_structure = nest.flatten(structure, expand_composites=True)
1416  if len(flattened_structure) != len(flat_sequence):
1417    raise ValueError("Mismatch in element count")
1418  for i in range(len(flat_sequence)):
1419    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
1420      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
1421          old_ta=flattened_structure[i], flow=flat_sequence[i])
1422  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)
1423
1424
def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Builds a placeholder standing in for `value`, copying its handle data.

  Args:
    value: The tensor the placeholder substitutes for.
    name: Optional name for the placeholder op.
    dtype: Optional dtype; falls back to `value.dtype` when falsy.
    shape: Optional shape; falls back to `value.shape` when `None`.

  Returns:
    The new placeholder tensor.
  """
  target_shape = value.shape if shape is None else shape
  target_dtype = dtype or value.dtype
  # Clearing control dependencies keeps capturing placeholders outside of any
  # enclosing control flow context.
  with ops.control_dependencies(None):
    substitute = graph_placeholder(
        dtype=target_dtype, shape=target_shape, name=name)
  handle_data_util.copy_handle_data(value, substitute)
  return substitute
1436
1437
def _get_defun_inputs_from_args(args, names):
  """Converts positional Python args into graph-construction inputs.

  The positional args are used both as the per-argument values and as the
  structure the resulting inputs are packed back into.
  """
  return _get_defun_inputs(args=args, names=names, structured_args=args)
1441
1442
def _get_defun_inputs_from_kwargs(kwargs):
  """Converts keyword Python args into graph-construction inputs.

  Keyword arguments are visited in sorted-key order. Presumably this lines up
  with how `nest` orders dict entries when packing results back into `kwargs`;
  verify before relying on it.
  """
  if not kwargs:
    names, args = [], []
  else:
    names, args = zip(*sorted(kwargs.items()))
  return _get_defun_inputs(args, names, structured_args=kwargs)
1451
1452
def _get_composite_tensor_spec(x):
  """Maps a CompositeTensor to its TypeSpec; returns any other value as-is."""
  if not isinstance(x, composite_tensor.CompositeTensor):
    return x
  return x._type_spec  # pylint: disable=protected-access
1457
1458
def _get_defun_inputs(args, names, structured_args):
  """Maps python function args to graph-construction inputs.

  First, CompositeTensors inside each element of `args` are replaced by their
  TypeSpecs. The element is then flattened with `expand_composites=True`, and
  every leaf is converted as follows:

  * `Tensor` / `TensorSpec` leaves become new placeholders in the current
    default graph.
  * `BaseResourceVariable` / `VariableSpec` leaves have their handle captured
    into the current default graph. A `VariableSpec` is first turned into a
    `BaseResourceVariable` whose handle is a resource placeholder created in the
    default graph's `outer_graph`. The variable itself is emitted in that
    position.
  * Any other leaf is passed through unchanged.

  Args:
    args: A list of user-specified arguments. If `structured_args` is a list,
      `args` is the same with `structured_args`. If `structured_args` is a dict,
      `args` is the values of the dict.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structured_args: The original argument list or dictionary.

  Returns:
    Placeholders with the same structure as `structured_args`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)

  for arg_value, name in zip(args, names):
    # Replace any composite tensors with their TypeSpecs.  This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)
    flat_args = nest.flatten(arg_value, expand_composites=True)

    for arg in flat_args:
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        try:
          placeholder = graph_placeholder(
              arg.dtype, arg.shape, name=requested_name)
        except ValueError as e:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          logging.warning(e)
          placeholder = graph_placeholder(arg.dtype, arg.shape)
        if not arg_is_spec:
          handle_data_util.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        # Keep the resolved name local to this leaf. Reassigning `name` here
        # would leak a VariableSpec's own name to the remaining leaves of the
        # same argument.
        var_name = name
        if isinstance(arg, resource_variable_ops.VariableSpec):
          var_name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(
                dtypes.resource, arg.shape, name=var_name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=var_name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=var_name,
                trainable=arg.trainable)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=var_name)
        if var_name is not None:
          # Same guard as the Tensor branch: compat.as_bytes(None) would raise.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(var_name)))
        function_inputs.append(arg)
      else:
        function_inputs.append(arg)
  return nest.pack_sequence_as(
      structured_args, function_inputs, expand_composites=True)
1537
1538
def dismantle_func_graph(func_graph):
  """Tears down `func_graph` to break the reference cycles it holds.

  Use this so the FuncGraph can be reclaimed without waiting for the garbage
  collector, e.g. in tests decorated with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) that
  use defun.

  Args:
    func_graph: The `FuncGraph` to tear down. It must not be used afterwards.
  """
  # Drop the graph's recorded captures before handing it to the generic
  # graph teardown.
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)
1552
1553
def override_func_graph_name_scope(func_graph, name_scope):
  """Overwrites the private `_name_stack` attribute of `func_graph`.

  Args:
    func_graph: The graph whose name stack is replaced.
    name_scope: The value stored as the graph's new `_name_stack`.
  """
  setattr(func_graph, "_name_stack", name_scope)
1556