# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

import collections as py_collections
import traceback
from typing import Any, Callable, Hashable, Mapping
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import save_context
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]

_EAGER_CONST_THRESHOLD = 128


class UnknownArgument(object):
  """Signifies an argument which is not currently handled."""
  pass


def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where the top-level collection is a list
      or a tuple.
    arg_names: Optional list of argument names with the same number of
      elements as `structure`, used for naming the corresponding TensorSpecs.

  Returns:
    An identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
  """

  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      return resource_variable_ops.VariableSpec.from_value(arg)
    if isinstance(arg, (
        int,
        float,
        bool,
        str,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)
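

# A hedged, illustrative sketch (comments only, not executed): suppose tracing
# `def f(x, n)` produced a float32 placeholder of shape [2] for `x` while `n`
# stayed a Python int. Then
#
#   convert_structure_to_signature([x_placeholder, 3], arg_names=["x", "n"])
#
# returns [tensor_spec.TensorSpec([2], dtypes.float32, name="x"), 3], and any
# unsupported leaf (e.g. an arbitrary callable) comes back as
# UnknownArgument(). `x_placeholder` is a hypothetical name for illustration.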


class CapturesContainer(object):
  """A container class to store captures with a dict."""

  def __init__(self):
    # A dict that maps capture identifier -> function.
    self._captures = py_collections.OrderedDict()

  def add_capture(self, identifier: Hashable,
                  func: Callable[[], Any]):
    self._captures[identifier] = func

  def update(self, container: "CapturesContainer"):
    # Add captures from the other container if they are not already present.
    assert isinstance(container, CapturesContainer)
    for key, func in container.captures.items():
      if key not in self._captures:
        self._captures[key] = func

  def get_snapshot(self) -> Mapping[Hashable, Any]:
    snapshot = {}
    for key, func in self.captures.items():
      snapshot[key] = func()
    return snapshot

  @property
  def captures(self) -> Mapping[Hashable, Any]:
    return self._captures

  def __len__(self):
    return len(self._captures)
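

# A minimal usage sketch of CapturesContainer (hypothetical identifier and
# closure, kept as comments so nothing runs at import time):
#
#   container = CapturesContainer()
#   counter = [0]
#   container.add_capture("counter", lambda: counter[0])
#   counter[0] = 7
#   container.get_snapshot()  # -> {"counter": 7}
#
# `get_snapshot` re-evaluates every stored closure, which is how the runtime
# value of a by-reference capture is observed at call time.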


@tf_export("__internal__.FuncGraph", v1=[])
class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well
      as captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are
      in this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another
      FuncGraph or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input
      placeholder). The entries are in the order they were captured.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self,
               name,
               collections=None,
               capture_by_value=None,
               structured_input_signature=None,
               structured_outputs=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not allowlisted, and both
        read and write to the outer graph's collections that are allowlisted.
        The current allowlisted collections are the global variables, the
        local variables, and the trainable variables. Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherits
        from outer graphs, and failing that will default to False.
      structured_input_signature: Optional. The structured input signature to
        use for initializing the FuncGraph. See the docstring for FuncGraph
        for more information.
      structured_outputs: Optional. The structured outputs to use for
        initializing the FuncGraph. See the docstring for FuncGraph for more
        information.
    """
    super(FuncGraph, self).__init__()
    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = object_identity.ObjectIdentitySet()
    self.structured_input_signature = structured_input_signature
    self.structured_outputs = structured_outputs
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    self._captures = py_collections.OrderedDict()
    # Maps capture identifier -> lambda function that returns capture values.
    # Used to get runtime values to determine whether retracing is needed.
    self._capture_func_lib = CapturesContainer()
    # Maps capture identifier -> a container with the same structure as
    # the original side input, except tensors are replaced with placeholders.
    # Used to fetch existing placeholders and prevent repeated creation.
    self._capture_placeholder_lib = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(self.outer_graph,
                                                     FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf-data user migration from TF1.0 to 2.0] seed_used keeps track
      # of any None op_seed for a random_op in the function, in which case we
      # end up using the function seed, which could be unintended behavior for
      # the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope.
    self._scope_exit_callbacks = None
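
  # A hedged construction sketch (internal API; `func_graph_from_py_func`
  # below is the usual entry point). Ops created while this graph is the
  # default are recorded into it; `inputs` and `outputs` are maintained by the
  # tracing machinery rather than by `graph_placeholder` itself. `math_ops` is
  # a hypothetical import for illustration:
  #
  #   g = FuncGraph("f")
  #   with g.as_default():
  #     x = graph_placeholder(dtypes.float32, [])
  #     y = math_ops.add(x, x)  # recorded into `g`
  #   g.inputs.append(x)
  #   g.outputs.append(y)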

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph."""
    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self,
                              closure,
                              spec,
                              key=None,
                              default_value=None,
                              placeholder=None):
    """Returns a placeholder which at call time has the value closure().

    `tf.function` supports the notion of captures, that is, it allows Python
    functions to have closure variables, which bind to values outside the
    function. However, this name binding is "early binding" performed before
    the program is run, i.e.,
    ```
    @tf.function
    def f():
      return x

    x = tf.constant(1)
    f()  # returns 1

    x = tf.constant(2)
    f()  # still returns 1!
    ```
    while in Python, name binding is performed as the program is running.
    ```
    def f():
      return x

    x = 1
    f()  # returns 1

    x = 2
    f()  # returns 2
    ```
    `capture_call_time_value` allows tf.function to mimic late binding as a
    Python function does, by passing in a `closure` callable argument to be
    executed when the tf.function is invoked eagerly. E.g.
    ```
    @tf.function
    def f():
      return ops.get_default_graph().capture_call_time_value(
          lambda: x, tf.TensorSpec([], tf.int32))

    x = tf.constant(1)
    f()  # returns 1

    x = tf.constant(2)
    f()  # returns 2
    ```
    Note that a `capture_call_time_value` function itself does not work well
    in the saving process (since the tf.function in which it's called is not
    invoked eagerly) unless passed a `default_value` argument. At saving time,
    the `default_value` argument is returned instead.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to `capture_call_time_value`
        with the same key in the same graph will return the same placeholder,
        and the first closure will be used at function call time.
      default_value: optional value to return in environments that cannot
        safely evaluate closure.
      placeholder: optional. If not None, the graph will take the passed-in
        `placeholder` as the internal capture instead of creating a new one.
        This is useful when loading from a SavedModel.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      ValueError: at function call time, if the return value of closure() is
        not compatible with `spec`.
    """
    if key is None:
      key = object()
    if key not in self._deferred_captures:

      if placeholder is None:

        def convert_to_placeholder(s):
          if not isinstance(s, tensor_spec.DenseSpec):
            raise TypeError(
                "Expected a nest of `TypeSpec` objects, found %s of type %s." %
                (s, type(s)))
          return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

        placeholder = nest.map_structure(
            convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():

        # One major case requiring a `default_value` is passing a concrete
        # function to `save`, i.e.
        # serving_fn = serve_fn.get_concrete_function(...)
        # model.save(save_dir, signatures={"serving_default": serving_fn})
        # `serving_fn` has deferred captures added through
        # `capture_call_time_value`. It can't be saved correctly since
        # `wrapped_closure` will end up executing under a default Graph
        # instead of a FuncGraph. The user of `capture_call_time_value` also
        # cannot conditionally avoid this call since the presence of
        # `save_context` when executing `wrapped_closure` is not known at
        # tracing time of `serving_fn`.
        if save_context.in_save_context() and default_value is not None:
          return default_value
        # TODO(wxinyi): raise an error if in save context but no default value.

        if not context.executing_eagerly():
          graph = ops.get_default_graph()

          # In the case of control flow, we need to capture the
          # external_captures (deferred or not) of the body_graph (i.e.
          # `WhileBodyFuncGraph`) in `cond_graph` (i.e. `WhileCondFuncGraph`)
          # and create the corresponding placeholders in `cond_graph` so that
          # it expects to receive these as arguments. However, doing so
          # requires having evaluated the call_time_value already (and maybe
          # repeatedly), so we skip adding deferred_captures to the control
          # flow graph but add them to its outer graph.
          while graph.is_control_flow_graph:
            graph = graph.outer_graph

          with graph.as_default():
            ret_nest = graph.capture_call_time_value(
                closure, spec, key=key, default_value=default_value)
        else:
          ret_nest = closure()

        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(
            lambda s, r: s._to_components(r),
            spec,
            ret_nest,
            expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      wrapped_closure.output_spec = spec
      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which must be
        executed or computed before running the operations defined in the
        context. Can also be `None` to clear the control dependencies.

    Returns:
      A context manager that specifies control dependencies for all
      operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable.
      if (isinstance(c, indexed_slices.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(graph_element, "graph",
                                               None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)

  def as_default(self):
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's
      # placement, so that the same trace of the Python function may be placed
      # on several different devices and saved functions may be placed on new
      # devices when restored.
      # However, we need to preserve the outer device stack in the following
      # cases in a non-eager context:
      # 1. The device stack is callable.
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body.
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key

    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.

    Functions may capture Tensors from graphs they are nested in (transitive).

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @outer_graph.setter
  def outer_graph(self, new_outer_graph):
    """Sets `outer_graph` to `new_outer_graph`."""
    self._weak_outer_graph = weakref.ref(new_outer_graph)

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.
    """

    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]
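
  # A hedged sketch of the weak-reference pitfall noted above (public
  # tf.function API, comments only):
  #
  #   v = tf.Variable(1.0)
  #
  #   @tf.function
  #   def f():
  #     return v + 1.0
  #
  #   cf = f.get_concrete_function()
  #   del v  # the FuncGraph only holds a weak reference to the variable
  #   cf()   # error: `deref` above raises when the weak reference is dead
  #
  # Keeping a strong reference alive (e.g. storing the variable as an object
  # attribute) avoids this.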

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # When capturing by value, do the read outside.
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type, uncaptured_inputs, dtypes, input_types, name, attrs,
            op_def, compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function
    graphs if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the
        `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of
        the tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in [
        "ReadVariableOp", "ResourceGather"
    ]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
                                    attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse
    # loop validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create
    # placeholders to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # Use a different list to avoid modifying the original inputs list.
    captured_inputs = []
    for inp in inputs:
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      captured_inputs.append(inp)
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs. Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor`
    is from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      shape: Optional shape if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
        bypasses the mechanisms required for the data dependencies to be
        correctly wired.
    """
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())

      # Small EagerTensors are captured with Const ops.
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
        return self.capture_eager_tensor(tensor, name)

      # Large EagerTensors and resources are captured with Placeholder ops.
      return self._capture_helper(tensor, name, shape)
    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          try:
            tb = tensor.op.traceback
          except AttributeError:
            tensor_traceback = "<unknown>"
          else:
            tensor_traceback_list = []
            for frame in traceback.format_list(tb.get_user_frames()):
              tensor_traceback_list.extend(
                  [f"  {line}" for line in frame.split("\n") if line.strip()])
            tensor_traceback = "\n".join(tensor_traceback_list)
          # Keep in sync with tfe_wrapper.cc.
          # TODO(b/200991648): Unify those two paths.
          raise errors.InaccessibleTensorError(
              f"{tensor!r} is out of scope and cannot be used here. Use return "
              "values, explicit Python locals or TensorFlow collections to "
              "access it.\n"
              "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values "
              "for more information.\n\n"
              f"{tensor!r} was defined here:\n{tensor_traceback}\n\n"
              f"The tensor {tensor!r} cannot be accessed from {self}, because "
              f"it was defined in {tensor.graph}, which is out of scope.")
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor
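
  # A hedged illustration of the two eager-capture paths above (assuming a
  # tf.function that closes over eager tensors):
  #
  #   small = tf.constant([1, 2, 3])  # 3 elements <= _EAGER_CONST_THRESHOLD:
  #                                   # inlined via capture_eager_tensor as a
  #                                   # Const op in the graph
  #   big = tf.zeros([64, 64])        # 4096 elements > threshold: captured as
  #                                   # a placeholder fed with the eager value
  #                                   # at call time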

  def _capture_helper(self, tensor, name, shape=None):
    capture = self._captures.get(id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
      # Record the composite device as an attribute of the placeholder.
      # This attribute will be propagated into the arg_attr of the
      # FunctionDef. Currently, a packed eager tensor is always placed on a
      # CompositeDevice.
      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_composite_device",
            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation(
        "captured_value", [placeholder], [tensor],
        backward_function=lambda x: [x],
        forward_function=lambda x: [x])
    return placeholder

  def _experimental_capture_side_input_by_ref(self, identifier: Hashable,
                                              func: Callable[[], Any]) -> ...:
    """Implements capturing side input by reference for tf.function.

    Args:
      identifier: A hashable object as the key for the capture.
      func: A Python function that takes no arguments and returns the value of
        the side input. The function is evaluated at function call time.

    Returns:
      A nested structure with the same structure as the side input. Tensors
        are replaced with placeholders, and non-tensors remain the same.
    """
    # Supporting manual capture for inner nested tf.functions is not possible
    # at the moment. "Inner" here means any tf.function wrapped by another
    # tf.function; usage inside the outermost tf.function only is fine. The
    # limitation exists because it is impossible to determine the definition
    # scope of the captured side input, and that information is needed when
    # propagating inner tf.function captures to the outer tf.function.
    if isinstance(self.outer_graph, FuncGraph):
      raise NotImplementedError(
          ("Manual side input usage for inner nested tf.function is not "
           f"supported. Got side input: {identifier}."))

    # Prevent repeated captures.
    if identifier in self._capture_placeholder_lib:
      return self._capture_placeholder_lib[identifier]

    nested_placeholder = self._maybe_create_capture_placeholder(func)
    self._capture_func_lib.add_capture(identifier, func)
    self._capture_placeholder_lib[identifier] = nested_placeholder
    return nested_placeholder
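
  # A hedged usage sketch (experimental API; assumes `g` is the outermost
  # FuncGraph being traced and `side` is defined in eager mode):
  #
  #   side = [tf.constant(1.0)]
  #   ph = g._experimental_capture_side_input_by_ref("side", lambda: side[0])
  #
  # During tracing, `ph` stands in for `side[0]`; at each call the lambda is
  # re-evaluated and its current value is fed to the placeholder, so
  # reassigning `side[0]` later is observed by the function (by-reference,
  # late-binding semantics).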

  def _maybe_create_capture_placeholder(self, func: Callable[[], Any]) -> ...:
    """Creates placeholders for any tensors in the side input."""
    values_nest = func()

    if context.executing_eagerly():
      return values_nest

    values_flat = nest.flatten(values_nest)
    # Return values in flat format. It consists of placeholders and non-tensor
    # values.
    return_flat = []
    tensor_spec_flat = []
    # Create return_flat and replace tensors with None. Later, each None is
    # replaced again by the corresponding placeholder.
    for value in values_flat:
      if isinstance(value, core.Tensor):
        return_flat.append(None)
        tensor_spec_flat.append(type_spec.type_spec_from_value(value))
      elif isinstance(value, (set, frozenset)):
        raise NotImplementedError(
            (f"Side input returned by '{tf_inspect.getsource(func).strip()}' "
             f"has an element of {type(value)} type, which is currently not "
             "supported by tf.function."))
      else:
        return_flat.append(value)
    if tensor_spec_flat:

      def tensor_func():
        values = nest.flatten(func())
        return [value for value in values if isinstance(value, core.Tensor)]

      placeholder_flat = self.capture_call_time_value(
          tensor_func, tensor_spec_flat)
      # Replace each None that represents a tensor with its placeholder.
      flat_ptr = 0
      for idx, item in enumerate(return_flat):
        if item is None:
          return_flat[idx] = placeholder_flat[flat_ptr]
          flat_ptr += 1
    return_nest = nest.pack_sequence_as(values_nest, return_flat)
    return return_nest

  @property
  def captures(self):
    """Ordered list of tuples containing external and internal captures."""
    return self._captures.values()

  def add_capture(self, tensor, placeholder):
    """Captures a specific tensor and utilizes the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replaces an already existing capture."""
    self._captures[id(tensor)] = (tensor, placeholder)

  def replace_capture_with_deferred_capture(self,
                                            tensor,
                                            closure,
                                            spec,
                                            placeholder,
                                            default_value=None):
    """Replaces existing capture `tensor` with a deferred capture `closure`.

    Caution: It is the caller's responsibility to make sure that, after
    calling this function, the TypeSpec of the `inputs` (i.e. internal
    placeholders) and the `_captured_inputs` (i.e. external captures) of a
    concrete function that wraps this function graph are still compatible.
    Thus the caller should pair usage of this function with
    `ConcreteFunction.set_external_captures` to make sure the order still
    matches. For example,
    ```
    # concrete_fn._captured_inputs == [tensor1, tensor2, tensor3]
    # concrete_fn.inputs == [placeholder1, placeholder2, placeholder3]
    # Replace external capture `tensor2` with a deferred capture, i.e. a
    # closure, `closure2`. Note the argument order: `spec` precedes
    # `placeholder`.
    concrete_fn.graph.replace_capture_with_deferred_capture(
        tensor2, closure2, spec2, placeholder2, default_value2)
    concrete_fn.set_external_captures([tensor1, closure2, tensor3])
    ```

    Args:
      tensor: Tensor already captured.
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      placeholder: the internal placeholder corresponding to the captured
        `tensor`.
      default_value: optional value to use in environments that cannot safely
        evaluate closure.
    """
    if id(tensor) in self._captures:
      self.pop_capture(tensor)
    self.capture_call_time_value(
        closure,
        spec,
        key=id(tensor),
        default_value=default_value,
        placeholder=placeholder)

  def reset_captures(self, capture_list):
    """Sets the captures with the provided list of captures & placeholders."""
    self._captures = py_collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    """Removes the capture and returns the generated placeholder."""
    capture = self._captures.pop(id(tensor), None)
    if capture is None:
      return None

    return capture[1]

  def clear_captures(self):
    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
    # Clearing captures using clear() leaves some cycles around.
    while self._captures:
      self._captures.popitem()
    memory.dismantle_ordered_dict(self._captures)
    while self._deferred_captures:
      self._deferred_captures.popitem()
    memory.dismantle_ordered_dict(self._deferred_captures)

  def capture_distributed_variable(self, variable, placeholder):
    """Adds a given distributed variable to captures with given placeholder."""
    self._captures[id(variable)] = (variable, placeholder)
    tape.record_operation(
        "captured_value", [placeholder], [variable],
        backward_function=lambda x: [x],
        forward_function=lambda x: [x])

  def capture_eager_tensor(self, tensor, name):
    capture = self._captures.get(id(tensor))
    if capture is None:
      with ops.control_dependencies(None):
        constant_value = tensor_util.constant_value(tensor)
        if constant_value is None:
          # Some eager tensors, e.g. parallel tensors, are not convertible to
          # a single constant. We'll use a placeholder for this case.
          return self._capture_helper(tensor, name)
        graph_const = constant_op.constant(
            constant_value, dtype=tensor.dtype, shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation(
        "captured_value", [graph_const], [tensor],
        backward_function=lambda x: [x],
        forward_function=lambda x: [x])
    return graph_const

  def captured(self, tensor):
    """Checks if the specified tensor has been captured."""
    return id(tensor) in self._captures

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return [c[0] for c in self._captures.values()]

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding to captured tensors."""
    return [c[1] for c in self._captures.values()]

  @property
  def deferred_external_captures(self):
    """Ordered nest of tensors whose placeholders will be fed at call time."""
    return [c[0] for c in self._deferred_captures.values()]

  @property
  def deferred_internal_captures(self):
    """List of nests of placeholders which at call time will be fed."""
    return [c[1] for c in self._deferred_captures.values()]

  @property
  def variable_captures(self):
    """Map of python object ids of variables to variables which are captured."""
    return {
        id(self._captures[id(v)][1]): v
        for v in self.variables
        if id(v) in self._captures
    }

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the
    specified message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)
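
  # A hedged sketch of the unsaveable-marking flow (hypothetical reason
  # string):
  #
  #   graph.mark_as_unsaveable("uses a non-serializable op")
  #   graph.saveable       # -> False
  #   graph.saving_errors  # -> {"uses a non-serializable op"}
  #
  # The SavedModel exporter is expected to consult `saveable` and raise with
  # the accumulated `saving_errors` when it reaches this graph.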

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns the set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope."""
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph. Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)


# TODO(mdan): Too many threaded arguments. Accept an ACD ctx manager instead.
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            acd_record_initial_resource_uses=False):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the
      shapes and dtypes of the arguments. When a signature is provided, `args`
      and `kwargs` are ignored, and `python_func` is traced with Tensors
      conforming to `signature`. If `None`, the shapes and dtypes are inferred
      from the inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph; otherwise a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`. See
      https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`. See
      https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input
      placeholders recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start with.
      If not specified (None), the FuncGraph will read (but not write to) the
      outer graph's collections that are not allowlisted, and both read and
      write to the outer graph's collections that are allowlisted. The current
      allowlisted collections are the global variables, the local variables,
      and the trainable variables. Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will
      capture Variables by value instead of reference. By default inherits
      from outer graphs, and failing that will default to False.
    acd_record_initial_resource_uses: If `True` and `add_control_dependencies`
      is enabled, the results (those marked with
      AutomaticControlDependencies.mark_result) will be annotated with a
      private attribute, "_res_first_used_by", which points to the first nodes
      which used any of the resources that the result op is using.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is not `None`, a
      `Tensor`, or a `tf.experimental.ExtensionType`.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(
        name, collections=collections, capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies(
        record_initial_resource_uses=acd_record_initial_resource_uses)
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None:
      args = signature
      kwargs = {}
    func_args = _get_defun_inputs_from_args(args, arg_names)
    func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

    # Convert all Tensors into TensorSpecs before saving the structured
    # inputs. If storing pure concrete functions that are not called through
    # polymorphic functions, we don't have access to FunctionSpec, so we need
    # to call the TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (convert_structure_to_signature(
        func_args, arg_names), convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args, expand_composites=True)
    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [
        arg for arg in flat_func_args + flat_func_kwargs
        if isinstance(arg, ops.Tensor)
    ]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function.
    # Copy the recursive list, tuple and map structure, but not base objects.
    func_args_before = nest.pack_sequence_as(
        func_args, flat_func_args, expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs, expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps,
        # so this won't work if x needs to be captured (i.e. if python_func
        # returns captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.function, Python functions "
              "must return zero or more Tensors or ExtensionTypes or None "
              f"values; in compilation of {str(python_func)}, found return "
              f"value of type {type(x).__name__}, which is not a Tensor or "
              "ExtensionType.")
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def autograph_handler(*args, **kwargs):
          """Calls a converted version of original_func."""
          # TODO(mdan): Push this block higher in tf.function's call stack.
          try:
            return autograph.converted_call(
                original_func,
                args,
                kwargs,
                options=autograph.ConversionOptions(
                    recursive=True,
                    optional_features=autograph_options,
                    user_requested=True,
                ))
          except Exception as e:  # pylint:disable=broad-except
            if hasattr(e, "ag_error_metadata"):
              raise e.ag_error_metadata.to_exception(e)
            else:
              raise

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func,
                                                     autograph_handler)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      else:
        _, original_func = tf_decorator.unwrap(python_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = nest.map_structure(
          convert, func_outputs, expand_composites=True)

      check_func_mutation(func_args_before, func_kwargs_before, func_args,
                          func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
    arg_variables = object_identity.ObjectIdentitySet()
    inputs = []
    for arg in (nest.flatten(func_args, expand_composites=True) +
                nest.flatten(func_kwargs, expand_composites=True)):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.pop_capture(arg.handle)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph
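

# A hedged usage sketch of `func_graph_from_py_func` (internal API; names are
# illustrative). Tracing a lambda against a TensorSpec yields a FuncGraph
# whose inputs are placeholders and whose outputs are the converted returns:
#
#   fg = func_graph_from_py_func(
#       "double",
#       lambda x: x * 2.0,
#       args=(tensor_spec.TensorSpec([], dtypes.float32, name="x"),),
#       kwargs={})
#   # fg.inputs  -> [<"x" placeholder>]
#   # fg.outputs -> [<result of the mul op>]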


def maybe_captured(tensor):
  """If `tensor` is a captured value placeholder, returns the original value.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  return tensor


def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable."""
  return any(
      callable(spec._device_name_or_function)  # pylint: disable=protected-access
      for spec in device_stack.peek_objs())


def has_mutation(n1, n2):
  """Returns true if n1 and n2 are different (using `is` to compare leaves)."""
  try:
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    return True

  for arg1, arg2 in zip(
      nest.flatten(n1, expand_composites=True),
      nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      return True

  return False


def check_func_mutation(old_args, old_kwargs, new_args, new_kwargs, func):
  """Checks that the arguments to a function are not modified."""
  if not has_mutation((old_args, old_kwargs), (new_args, new_kwargs)):
    return

  # Mutation detected; construct a useful error message.
  func_name = getattr(func, "__qualname__", getattr(func, "__name__", func))
  signature = tf_inspect.signature(func)
  try:
    old_bound = signature.bind(*old_args, **old_kwargs).arguments
    new_bound = signature.bind(*new_args, **new_kwargs).arguments
  except TypeError as e:
    # This occurs when the function is called with the (deprecated)
    # "flat signature". See ConcreteFunction._call_with_flat_signature. In
    # this case, we can't report which arguments were modified.
    raise ValueError(
        f"{func_name}{signature} should not modify its Python input "
        f"arguments. Check if it modifies any lists or dicts passed as "
        f"arguments. Modifying a copy is allowed.") from e

  assert set(old_bound) == set(new_bound)
  modified_args = [
      arg_name for arg_name in new_bound
      if has_mutation(old_bound[arg_name], new_bound[arg_name])
  ]
  changes = ", ".join(modified_args)
  raise ValueError(f"{func_name}{signature} should not modify its Python "
                   f"input arguments. Modifying a copy is allowed. The "
                   f"following parameter(s) were modified: {changes}")
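

# A hedged sketch of the mutation check above (through the public tf.function
# API): a python_func that mutates an argument in place fails tracing with the
# ValueError constructed in check_func_mutation.
#
#   @tf.function
#   def f(items):
#     items.append(tf.constant(1.0))  # mutates the traced structure
#     return items
#
#   f([])  # ValueError: f(items) should not modify its Python input
#          # arguments ... The following parameter(s) were modified: items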


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence
  ]


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` and `flat_sequence` are not compatible.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i in range(len(flat_sequence)):
    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=flattened_structure[i], flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)
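

# A hedged round-trip sketch: `flatten` lowers a TensorArray to its flow
# tensor, and `pack_sequence_as` rebuilds a TensorArray around a (possibly
# new) flow, letting TensorArrays thread through code that only understands
# flat tensor lists:
#
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
#   flat = flatten({"ta": ta})                 # -> [ta.flow]
#   rebuilt = pack_sequence_as({"ta": ta}, flat)
#   # rebuilt["ta"] is a TensorArray built around the same flow tensor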


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  handle_data_util.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(args, names, structured_args=args)


def _get_defun_inputs_from_kwargs(kwargs):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(args, names, structured_args=kwargs)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structured_args):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A list of user-specified arguments. If `structured_args` is a list,
      `args` is the same as `structured_args`. If `structured_args` is a
      dict, `args` is the values of the dict.
    names: A list of strings with user-specified argument names, same length
      as `args`. May be `None`, in which case a generic name is used.
    structured_args: The original argument list or dictionary.

  Returns:
    Placeholders with the same structure as `structured_args`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)

  for arg_value, name in zip(args, names):
    # Replace any composite tensors with their TypeSpecs. This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)
    flat_args = nest.flatten(arg_value, expand_composites=True)

    for arg in flat_args:
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        try:
          placeholder = graph_placeholder(
              arg.dtype, arg.shape, name=requested_name)
        except ValueError as e:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          logging.warning(e)
          placeholder = graph_placeholder(arg.dtype, arg.shape)
        if not arg_is_spec:
          handle_data_util.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different
          # than the uniquified name, for validation when exporting
          # signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        if isinstance(arg, resource_variable_ops.VariableSpec):
          name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(
                dtypes.resource, arg.shape, name=name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=name,
                trainable=arg.trainable)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise
        # we'd just add them back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        function_inputs.append(arg)
  return nest.pack_sequence_as(
      structured_args, function_inputs, expand_composites=True)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  func_graph._name_stack = name_scope  # pylint: disable=protected-access