# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""API for defining graph functions with some additional eager semantics.

def_function.function wraps the function concept in function.py ("defun") to
allow initializing `tf.Variable`s with subgraphs of the function. For example:

```python
class M(tf.Module):
  def __init__(self):
    self.v_opinit = None
    self.v_arginit = None

  @tf.function
  def __call__(self, x):
    # Variables are only created on the first call to the function. This is a
    # common pattern in layer libraries.
    if self.v_opinit is None:
      # self.v_opinit will outlive the function call, but `tf.ones` is traced
      # as part of the function body before the `tf.Variable` object is
      # created. This subgraph is easy to lift out of the function.
      self.v_opinit = tf.Variable(tf.ones([]))

      # If arguments feed into variable initialization, it can be very tricky
      # to disentangle from the rest of the function. We don't attempt it.
      self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.))
    return self.v_opinit + self.v_arginit + x
```

These patterns with "defun" throw an error asking the user to put the
variable's initializer in a lambda. With tf.function they work with eager
semantics either by lifting the subgraph out of the function and using it to
initialize the variable, or by initializing variables on the first call to the
function (if they weren't already initialized by something else, e.g. a
checkpoint API). The latter requires tf.conds, and is not well supported by
TF-XLA, so we only do it when necessary.

Since these patterns are relatively common in layer libraries, we expose the
wrapper in this file as `tf.function`. The function concept in function.py is
an internal implementation detail.

In order to support these variable initialization patterns, tf.function
defines a variable subtype (UnliftedInitializerVariable) which collects the
input subgraph. This type of variable replaces the regular variable type on
the first tf.function trace. To exclude initializers from the function body
(the `tf.ones` ops above and associated assignment operations), tf.function
traces a second time if it sees variables on the first call.
"""
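
# Illustrative note (sketch, not an API guarantee): in the module docstring's
# example above, `v_opinit`'s initializer (`tf.ones([])`) does not depend on
# the call arguments, so its init subgraph can be lifted out of the function
# and run once, eagerly. `v_arginit`'s initializer depends on `x`, so it is
# not lifted; its assignment instead runs on the first call, guarded by a cond
# on the variable's initialization status (see UnliftedInitializerVariable
# below).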

import functools
import os
import threading
import types as types_lib
import weakref

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import function_spec as function_spec_lib
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.tf_export import tf_export

FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
FREQUENT_TRACING_WARNING_THRESHOLD = 5
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
ALLOW_DYNAMIC_VARIABLE_CREATION = False

_tf_function_counter = monitoring.Counter(
    "/tensorflow/core/tf_function_counter",
    "Counter for the number of tf.functions created when Eager execution is "
    "enabled.",
    # jit_compile is "0" or "1".
    "jit_compile")


class _FrequentTracingDetector(object):
  """Class keeping track of how many recent calls triggered tracing."""

  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]

  def __init__(self):
    self._calls_per_tracings = []
    self._total_warning_count = 0
    self._call_count = 0

  def called_with_tracing(self, function_name, omit_warning):
    """Updates the list of most recent calls' tracing information.

    Warns the user when recent calls caused retracing too often.

    Args:
      function_name: the python function being traced.
      omit_warning: If 'True', this call will not warn the user even if
        retracing happens too often.
    """
    self._call_count += 1
    self._calls_per_tracings.append(1)

    while self._calls_per_tracings:
      if (self._call_count - self._calls_per_tracings[0] >
          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
        self._call_count -= self._calls_per_tracings.pop(0)
      else:
        break

    if (omit_warning or self._total_warning_count >=
        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
      return
    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
      self._total_warning_count += 1
      logging.warning(
          "{} out of the last {} calls to {} triggered tf.function "
          "retracing. Tracing is expensive and the excessive number of "
          "tracings could be due to (1) creating @tf.function repeatedly in "
          "a loop, (2) passing tensors with different shapes, (3) passing "
          "Python objects instead of tensors. For (1), please define your "
          "@tf.function outside of the loop. For (2), @tf.function has "
          "reduce_retracing=True option that can avoid unnecessary "
          "retracing. For (3), please refer to "
          "https://www.tensorflow.org/guide/function#controlling_retracing"
          " and https://www.tensorflow.org/api_docs/python/tf/function for "
          "more details.".format(
              len(self._calls_per_tracings), self._call_count, function_name))

  def called_without_tracing(self):
    # We don't count tracing when users load a concrete function directly or
    # call get_concrete_function, so the first call may not be a tracing call.
    if not self._calls_per_tracings:
      self._calls_per_tracings = [0]
    self._calls_per_tracings[-1] += 1
    self._call_count += 1


class _FrequentTracingDetectorManager(object):
  """Class for the management of all _FrequentTracingDetector objects."""

  __slots__ = ["_detectors", "_lock"]

  def __init__(self):
    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
    self._lock = threading.Lock()

  def _get_detector(self, key):
    if key not in self._detectors:
      self._detectors[key] = _FrequentTracingDetector()
    return self._detectors[key]

  def called_without_tracing(self, key):
    with self._lock:
      detector = self._get_detector(key)
      detector.called_without_tracing()

  def called_with_tracing(self, key, function_name, omit_warning):
    with self._lock:
      detector = self._get_detector(key)
      detector.called_with_tracing(function_name, omit_warning)


_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()


class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
  """Variable which does not lift its initializer out of function context.

  Instances of this variable, when created, build a graph which runs their
  initializer inside a tf.cond(is_initialized) block.

  This can only be created inside a defun called from (eventually) eager
  mode. That is, non-function-building graphs are not supported.
  """

  def __init__(self,
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               synchronization=None,
               aggregation=None,
               shape=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device.
        Typical use is to cache on the device where the Ops using the
        Variable reside, to deduplicate copying through `Switch` and other
        conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If
        None, either the datatype will be kept (if initial_value is a Tensor)
        or float32 will be used (if it is a Python object convertible to a
        Tensor).
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition nothing to
      # do here; we can't really do the capturing or conditional logic.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    if initial_value is None:
      raise ValueError("`initial_value` must be a Tensor or a Python "
                       "object convertible to a Tensor. Got None.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError(f"`constraint` with type {type(constraint)} must be a "
                       "callable.")

    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as scope_name:
      with ops.name_scope("Initializer"):
        if init_from_fn:
          initial_value = initial_value()
        if isinstance(initial_value, trackable.CheckpointInitialValue):
          self._maybe_initialize_trackable()
          self._update_uid = initial_value.checkpoint_position.restore_uid
          initial_value = initial_value.wrapped_value

        initial_value = ops.convert_to_tensor(initial_value,
                                              name="initial_value",
                                              dtype=dtype)
      assert initial_value is not None

      # Don't use `shape or initial_value.shape` since TensorShape has
      # overridden `__bool__`.
      if shape is None:
        shape = initial_value.shape

    # Use the constructor for UninitializedVariable to start. Outside the name
    # scope so we don't double up the prefix.
    super().__init__(
        trainable=trainable,
        caching_device=caching_device,
        name=name,
        shape=shape,
        dtype=initial_value.dtype,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        extra_handle_data=initial_value,
        **unused_kwargs)

    with ops.name_scope(scope_name):
      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
      elif context.executing_eagerly():
        # In this case, both current scope and init scope are eager.
        # Assign_variable_op will be executed immediately. So we don't need to
        # add it to "add_initializers_to" to lift it out.
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle, initial_value, name=n)
      else:
        # Init scope is eager but current scope is graph. We will lift out
        # this variable by adding it into "add_initializers_to".
        if add_initializers_to is not None:
          add_initializers_to.append((self, initial_value))

        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
            return ops.convert_to_tensor(1)

        def not_assign_fn():
          return ops.convert_to_tensor(0)

        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies. It will only
        # execute assign_fn if lifting failed.
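
        # Illustrative sketch (comments only, assuming lifting failed) of the
        # graph the cond below builds:
        #
        #   if var_is_initialized_op(handle):              # already assigned
        #     pass                                         # not_assign_fn
        #   else:
        #     assign_variable_op(handle, initial_value)    # assign_fn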
        graph = ops.get_default_graph()

        # Capture the handle ahead of time in order to avoid querying the
        # shape of the handle which helps async execution performance
        graph.capture(self._handle, shape=())
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)


JIT_COMPILE_FUNCTIONS = (
    os.getenv("TF_FUNCTION_JIT_COMPILE_DEFAULT", "false").lower()
    in ("true", "1"))

RUN_FUNCTIONS_EAGERLY = False


@deprecation.deprecated(
    None,
    "Use `tf.config.run_functions_eagerly` instead of the experimental "
    "version.")
@tf_export("config.experimental_run_functions_eagerly")
def experimental_run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Calling `tf.config.experimental_run_functions_eagerly(True)` will make all
  invocations of `tf.function` run eagerly instead of running as a traced
  graph function.

  See `tf.config.run_functions_eagerly` for an example.

  Note: This flag has no effect on functions passed into tf.data
  transformations as arguments. tf.data functions are never executed eagerly
  and are always executed as a compiled Tensorflow Graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
  """
  return run_functions_eagerly(run_eagerly)


@tf_export("config.run_functions_eagerly")
def run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Calling `tf.config.run_functions_eagerly(True)` will make all
  invocations of `tf.function` run eagerly instead of running as a traced
  graph function.

  This can be useful for debugging.

  >>> def my_func(a):
  ...   print("Python side effect")
  ...   return a + a
  >>> a_fn = tf.function(my_func)

  >>> # A side effect the first time the function is traced
  >>> a_fn(tf.constant(1))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  >>> # No further side effect, as the traced function is called
  >>> a_fn(tf.constant(2))
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Now, switch to eager running
  >>> tf.config.run_functions_eagerly(True)
  >>> # Side effect, as the function is called directly
  >>> a_fn(tf.constant(2))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Turn this back off
  >>> tf.config.run_functions_eagerly(False)

  Note: This flag has no effect on functions passed into tf.data
  transformations as arguments. tf.data functions are never executed eagerly
  and are always executed as a compiled Tensorflow Graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
  """
  global RUN_FUNCTIONS_EAGERLY
  RUN_FUNCTIONS_EAGERLY = bool(run_eagerly)


@deprecation.deprecated(
    None,
    "Use tf.config.functions_run_eagerly instead of the experimental version.")
@tf_export("config.experimental_functions_run_eagerly")
def experimental_functions_run_eagerly():
  """Returns the value of the `experimental_run_functions_eagerly` setting."""
  return functions_run_eagerly()


@tf_export("config.functions_run_eagerly")
def functions_run_eagerly():
  """Returns the value of the `run_functions_eagerly` setting."""
  return RUN_FUNCTIONS_EAGERLY


def _evaluate_var_is_initialized(variables):
  """Compute booleans indicating whether each variable is initialized."""
  with ops.init_scope():
    var_is_initialized = []
    for v in variables:
      var_is_initialized.append(
          resource_variable_ops.var_is_initialized_op(v.handle))
    try:
      # Stack all the var_is_initialized values into one tensor and interpret
      # the numpy value. This will reduce the number of RPCs between client
      # and worker in the remote case.
      return array_ops.stack(var_is_initialized).numpy()
    except errors.UnimplementedError:
      # Some devices do not support implicit copy-off to host. Fall back to
      # variable-by-variable processing.
      for index, v in enumerate(variables):
        try:
          numpy_value = var_is_initialized[index].numpy()
        except errors.UnimplementedError:
          # This is a variable on a parallel device; we'll extract its value
          # on each replica and assert that they're identical.
          components = parallel_device.unpack(var_is_initialized[index])
          with ops.device(None):
            components = array_ops.stack(components)
            all_initialized = math_ops.reduce_all(components).numpy()
            any_initialized = math_ops.reduce_any(components).numpy()
          if all_initialized != any_initialized:
            raise NotImplementedError(
                f"Some but not all components of a parallel variable {v!r} "
                "were initialized between their creation in a tf.function and "
                "the function's trace having completed. This is not "
                "supported; consider initializing either all or none of the "
                "components, or moving initialization out of the function.")
          numpy_value = all_initialized
        var_is_initialized[index] = numpy_value
      return var_is_initialized


class FunctionDeleter:
  """An object responsible for cleaning up the function graph."""

  __slots__ = ["func_graph"]

  def __init__(self, func_graph):
    self.func_graph = func_graph

  def __del__(self):
    try:
      func_graph_module.dismantle_func_graph(self.func_graph)
    except:  # pylint: disable=bare-except
      # Note: bare except here because this can be noisy at shutdown time.
      pass


class OptionalXlaContext:
  """Wrapper for XLA context optionally applied under a context manager."""

  def __init__(self, is_compiled):
    wrap = is_compiled and not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph())
    self.xla_context = control_flow_ops.XLAControlFlowContext() \
        if wrap else None

  def __enter__(self):
    if self.xla_context:
      self.xla_context.Enter()

  def __exit__(self, t, value, traceback):
    if self.xla_context:
      self.xla_context.Exit()


# TODO(mdan): Consider exposing this type for instance type checking.
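
# Note (illustrative): `Function` below is the object returned by the
# `tf.function` decorator defined at the bottom of this file, e.g.:
#
#   @tf.function
#   def f(x):
#     return x + 1
#
#   # `f` is a `Function` instance; the original Python callable remains
#   # reachable via `f.python_function`.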
@tf_export("__internal__.function.Function", v1=[])
class Function(core.GenericFunction, trackable.Trackable):
  """A `tf.types.experimental.GenericFunction` created by `tf.function`.

  Currently, individual methods/attributes under this class are not guaranteed
  by the TF API contract, and are subject to future changes.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               autograph=True,
               jit_compile=None,
               reduce_retracing=False,
               experimental_implements=None,
               experimental_autograph_options=None,
               experimental_follow_type_hints=None):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: See the documentation for `tf.function`.
      autograph: See the documentation for `tf.function`.
      jit_compile: See the documentation for `tf.function`.
      reduce_retracing: See the documentation for `tf.function`.
      experimental_implements: See the documentation for `tf.function`.
      experimental_autograph_options: See the documentation for `tf.function`.
      experimental_follow_type_hints: See the documentation for `tf.function`.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    self._lock = threading.RLock()
    self._python_function = python_function
    self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature(
        python_function,
        input_signature,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
    )
    self._implements = experimental_implements
    # If `True`, the function uses the rendezvous of the parent. This is only
    # needed to support code where raw send/recv operations are inserted and
    # when functions are run in graph mode where they may not be inlined.
    self._shared_rendezvous = None
    self._autograph = autograph
    self._experimental_autograph_options = experimental_autograph_options
    self._reduce_retracing = reduce_retracing
    self._jit_compile = jit_compile
    if experimental_follow_type_hints is None:
      experimental_follow_type_hints = False
    self._experimental_follow_type_hints = experimental_follow_type_hints
    self._created_variables = None  # GUARDED_BY(self._lock)
    self._stateful_fn = None  # GUARDED_BY(self._lock)
    self._stateless_fn = None  # GUARDED_BY(self._lock)
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._name = name
    self._key_for_call_stats = self._get_key_for_call_stats()
    self._omit_frequent_tracing_warning = False
    ops._tf_function_api_guage.get_cell().set(True)  # pylint: disable=protected-access

  @property
  def name(self):
    return self._name

  def __getstate__(self):
    """Custom pickling, to omit unpickleable objects."""
    result = self.__dict__.copy()
    del result["_lock"]
    del result["_descriptor_cache"]
    del result["_key_for_call_stats"]
    return result

  def __setstate__(self, state):
    """Restore from pickled state."""
    self.__dict__ = state
    self._lock = threading.RLock()
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._key_for_call_stats = self._get_key_for_call_stats()

  def _get_key_for_call_stats(self):
    """Returns key instance to track call stats and retracings.

    The key instance is a best-effort attempt to preserve global consistency.
    """
    target_function = self._python_function
    # `__wrapped__` is a conventional Python attribute in which a higher-order
    # function keeps a reference to the original function it wraps. We also
    # directly use this attribute for dealing with a class method. See
    # `bound_method_wrapper` in `function.py`. If we don't use `__wrapped__`,
    # all class methods will return the same `bound_method_wrapper` instance
    # from this function.
    while hasattr(target_function, "__wrapped__"):
      target_function = target_function.__wrapped__

    if hasattr(target_function, "__func__"):
      target_function = target_function.__func__

    if hasattr(target_function, "__code__"):
      return target_function.__code__

    return self._python_function

  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope."""

    weak_wrapped_fn = None
    compile_with_xla = self._jit_compile

    def wrapped_fn(*args, **kwds):
      """Wraps `self._python_function` in a variable creator scope."""
      # We register a variable creator with reduced priority. If an outer
      # variable creator is just modifying keyword arguments to the variable
      # constructor, this will work harmoniously. Since the `scope` registered
      # here actually creates the variable, it taking priority would otherwise
      # ignore the outer creator.
      #
      # If an outer variable creator calls the variable constructor manually,
      # for example creating a MirroredVariable, then they won't call our
      # creator. This means we won't be able to trace the initialization
      # graph, and so variable initializers can't depend on function
      # arguments. This is better than the alternative, tracing the
      # initialization graph but giving the user a variable type they didn't
      # want.
      default_graph = ops.get_default_graph()
      with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
        # __wrapped__ allows AutoGraph to swap in a converted function. We
        # give the function a weak reference to itself to avoid a reference
        # cycle.
        with OptionalXlaContext(compile_with_xla):
          out = weak_wrapped_fn().__wrapped__(*args, **kwds)
        return out

    weak_wrapped_fn = weakref.ref(wrapped_fn)

    return self._defun(tf_decorator.make_decorator(
        self._python_function,
        wrapped_fn))

  def _create_implements_attribute(self):
    """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME."""
    attributes = {}
    if isinstance(self._implements, str):
      # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a
      # NameAttrList. This is used when, apart from the function name being
      # implemented, a list of attributes is also being specified.
      # The attributes are specified as key-value pairs in the NameAttrList
      # of the corresponding AttrValue. The function name will be in the
      # 'name' field of the NameAttrList. Otherwise, it is just a string
      # corresponding to the function name.
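      #
      # For example (illustrative only), `experimental_implements` may be a
      # plain function name such as "mycompany.my_recurrent_cell", or a
      # serialized NameAttrList in text format such as
      #   'name: "embedded_matmul" attr { key: "some_key" value { b: true } }'
      # where the attr entries carry additional key/value metadata.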
      try:
        attr_value = attr_value_pb2.AttrValue()
        nameattrlist = attr_value_pb2.NameAttrList()
        _text_format.Merge(self._implements, nameattrlist)
        attr_value.func.CopyFrom(nameattrlist)
        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value
      except (_text_format.ParseError, DecodeError):
        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
    return attributes

  def _defun(self, fn):
    """Returns a defun generated from the input function."""
    attributes = {}

    if self._implements is not None:
      attributes = self._create_implements_attribute()

    share = self._shared_rendezvous
    if share is not None:
      attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share

    if self._jit_compile is not None:
      attributes.update(_XlaMustCompile=bool(self._jit_compile))
      if self._jit_compile:
        attributes.update(_noinline=True)
    if not attributes:
      attributes = None
    return function_lib.defun_with_attributes(
        fn,
        input_signature=self.input_signature,
        attributes=attributes,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        reduce_retracing=self._reduce_retracing,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_follow_type_hints=self._experimental_follow_type_hints)

  def _initialize(self, args, kwds, add_initializers_to=None):
    """Initializes, on the first call.

    Creates two `Function`s, one that will allow creation of variables
    and one that won't.

    Additionally runs a trace for the `Function` that allows creation
    of variables.

    Args:
      args: Arguments to the underlying python callable.
      kwds: Keyword arguments to the python callable.
      add_initializers_to: Where to collect variable initializers, if not
        None.
    """
    self.function_spec.validate_input_signature_with_argspec()

    created_variables = []
    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")

    def variable_capturing_scope(unused_next_creator, **kwds):
      """Creates UnliftedInitializerVariables and saves references to them."""
      v = UnliftedInitializerVariable(
          add_initializers_to=add_initializers_to,
          lifted_initializer_graph=lifted_initializer_graph, **kwds)
      created_variables.append(weakref.ref(v))
      return v

    self._created_variables = created_variables
    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
    self._stateful_fn._name = self._name  # pylint: disable=protected-access
    # Force the definition of the function for these arguments
    self._lifted_initializer_graph = lifted_initializer_graph
    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    self._concrete_stateful_fn = (
        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
            *args, **kwds))

    def invalid_creator_scope(*unused_args, **unused_kwds):
      """Disables variable creation."""
      raise ValueError(
          "tf.function only supports singleton tf.Variables created on the "
          "first call. Make sure the tf.Variable is only created once or "
          "created outside tf.function. See "
          "https://www.tensorflow.org/guide/function#creating_tfvariables "
          "for more information.")

    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
    self._stateless_fn._name = self._name  # pylint: disable=protected-access

  def _clone(self, python_function):
    """Clone the function with different python function."""
    f = Function(
        python_function=(self._python_function
                         if python_function is None else python_function),
        name=self._name,
        input_signature=self.input_signature,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        reduce_retracing=self._reduce_retracing,
        experimental_implements=self._implements,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_follow_type_hints=self._experimental_follow_type_hints)

    if self._shared_rendezvous:
      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access

    return f

  def _decorate(self, decorator):
    """Allows the captured Python function to be decorated in place.

    This method is only safe to call when the Function has not been called by
    a user. It makes sense to use this method to push a decorator into the
    function rather than wrapping the function in the decorator.

    We use this in tf.Module to allow user-annotated `tf.function`s to remain
    as `Function` objects but still automatically enter the Module name_scope
    when they are evaluated like all other methods.

    Args:
      decorator: A callable accepting a single argument which is the function
        to decorate and returning a callable result.

    Raises:
      ValueError: If the function has already been called.
    """
    if self._stateful_fn is not None or self._stateless_fn is not None:
      raise ValueError(
          "Functions cannot be decorated after they have been traced.")

    self._python_function = decorator(self._python_function)
    self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature(
        self._python_function, self.input_signature)

  # TODO: Remove this private method after updating all its uses
  # A good moment to do this could be when the experimental label is removed
  def _get_tracing_count(self):
    return self.experimental_get_tracing_count()

  def experimental_get_tracing_count(self):
    """Returns the number of times the function has been traced.

    For more information on when a function is traced and when it is
    traced multiple times see https://www.tensorflow.org/guide/function.
    Example:

    >>> @tf.function
    ... def double(a):
    ...   return a + a
    >>> double(tf.constant(1))
    >>> double(tf.constant(2))
    >>> double.experimental_get_tracing_count()
    1
    >>> double(tf.constant("a"))
    >>> double.experimental_get_tracing_count()
    2

    The first time experimental_get_tracing_count is called
    it returns 1, as the function is traced the first
    time it is called, and the second time the same graph is used
    since we're calling it with a parameter of the same type.

    The second time experimental_get_tracing_count is called
    it returns 2, as we called double with a
    different argument type, and so it was traced again.

    """
    result = self._stateless_fn.tracing_count if self._stateless_fn else 0
    result += self._stateful_fn.tracing_count if self._stateful_fn else 0
    return result

  @property
  def _run_functions_eagerly(self):
    return RUN_FUNCTIONS_EAGERLY

  @traceback_utils.filter_traceback
  def __call__(self, *args, **kwds):
    # Implements GenericFunction.__call__.
    if self._run_functions_eagerly:
      with trace.Trace(self._name, tf_function_call="eager"):
        return self._python_function(*args, **kwds)

    # Only count the statistics the first time, before initialization took
    # place.
    if self._created_variables is None:
      compiled = bool(self._jit_compile and
                      not control_flow_util.GraphOrParentsInXlaContext(
                          ops.get_default_graph()))
      # For nested functions, increment the counter only when a function with
      # jit_compile=True is called within a function with jit_compile=False.
      # We count this special case to correctly record that both
      # jit_compile=True and jit_compile=False are being used for parts of the
      # outer function.
      if ops.executing_eagerly_outside_functions() and (
          context.executing_eagerly() or compiled):
        # Labels must be strings in Python, so we convert 'compiled' to a
        # string.
        _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)

    tracing_count = self.experimental_get_tracing_count()
    with trace.Trace(self._name) as tm:
      # TODO(cheshire): Do not duplicate the XLAControlFlowContext annotation.
      compiler = "xla" if self._jit_compile else "nonXla"

      with OptionalXlaContext(self._jit_compile):
        result = self._call(*args, **kwds)

      new_tracing_count = self.experimental_get_tracing_count()
      without_tracing = (tracing_count == new_tracing_count)
      execution_mode = "notTraced" if without_tracing else "traced"
      tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
                      tracing_count=new_tracing_count)

    if context.executing_eagerly():
      if without_tracing:
        _frequent_tracing_detector_manager.called_without_tracing(
            self._key_for_call_stats)
      else:
        _frequent_tracing_detector_manager.called_with_tracing(
            self._key_for_call_stats, self._python_function,
            self._omit_frequent_tracing_warning)

    return result

  def _call(self, *args, **kwds):
    """Calls the graph function."""
    self._lock.acquire()
    if ALLOW_DYNAMIC_VARIABLE_CREATION:
      condition = self._created_variables and self._stateful_fn is None
    else:
      condition = self._created_variables
    if condition:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have created variables on the first call, so we run
      # the defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have not created variables on the first call. So we
      # can run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    try:
      # This is the first call of __call__, so we have to initialize.
      initializers = []
      self._initialize(args, kwds, add_initializers_to=initializers)
    finally:
      # At this point we know that the initialization is complete (or less
      # interestingly an exception was raised) so we no longer need a lock.
      self._lock.release()

    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by
        # lifting out initialization graphs. This is the only initialization
        # strategy compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializers)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      _, _, filtered_flat_args = (
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              args, kwds))
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._call_flat(
          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access

    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
      """Conditionally runs initialization if it's needed."""
      condition = True
      for v, _ in initializers:
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                v.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(
              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
              inner_filtered_flat_args,
              captured_inputs=self._concrete_stateful_fn.captured_inputs))

    # We've created variables and are unable to lift the initialization
    # graphs, so we fall back to initializing with conds while running the
    # function.
    canon_args, canon_kwds, filtered_flat_args = (
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            args, kwds))
    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
                                            filtered_flat_args)

  def experimental_get_compiler_ir(self, *args, **kwargs):
    # Implements GenericFunction.experimental_get_compiler_ir
    context.ensure_initialized()
    if not self._jit_compile:
      raise ValueError("Compiler IR can only be returned for functions marked "
                       "with 'jit_compile=True'")

    concrete_fn = self.get_concrete_function(*args, **kwargs)
    fn_name = concrete_fn.name

    # pylint: disable=protected-access
    _, _, filtered_flat_args = (
        concrete_fn._function_spec.canonicalize_function_inputs(args, kwargs))

    def compiler_ir_generator(stage="hlo", device_name=None):
      # TODO(cheshire): This is a hack to get the current "preferred" device,
      # there is no current API to get it otherwise.
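      # Illustrative note: when `device_name` is not given, the code below
      # creates a tiny throwaway op (`random_normal`) and reuses whichever
      # device the runtime places it on as the device to compile for.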
      if device_name is None:
        device_name = random_ops.random_normal([]).device
      res_bytes = context.context().get_compiler_ir(
          device_name=device_name,
          stage=stage,
          function_name=fn_name,
          args=list(filtered_flat_args) + concrete_fn.captured_inputs)
      if stage in ("hlo_serialized", "optimized_hlo_serialized",
                   "optimized_hlo_proto_serialized"):
        return res_bytes
      else:
        return res_bytes.decode("utf-8")

    return compiler_ir_generator

  @property
  def python_function(self):
    """The python function wrapped in this tf.function."""
    return self._python_function

  @property
  def input_signature(self):
    return self._function_spec.input_signature

  @property
  def function_spec(self):
    return self._function_spec

  def pretty_printed_concrete_signatures(self, verbose=True):
    joiner = "\n\n" if verbose else "\n"
    return joiner.join([
        c.pretty_printed_signature(verbose=verbose)
        for c in self._list_all_concrete_functions()
    ])

  def _initialize_uninitialized_variables(self, initializers):
    """Make and call a `ConcreteFunction` which initializes variables."""

    if not initializers:
      return

    var_is_initialized = _evaluate_var_is_initialized(
        [v for v, _ in initializers])

    # Note: using defun here avoids an infinite recursion.
    # Most of the code in this function runs eagerly with init_scope, where
    # autograph is not necessary.
    @function_lib.defun(autograph=False)
    def initialize_variables():
      op_map = object_identity.ObjectIdentityDictionary()

      inits = []
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        inits.append(init)

      if inits:
        op_map = lift_to_graph.lift_to_graph(
            inits, ops.get_default_graph(), op_map=op_map)
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        v.assign(op_map[init], read_value=False)

    with ops.init_scope():
      return initialize_variables.get_concrete_function()()

  def get_initialization_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` which initializes this function's variables.

    Requires that this function hasn't been accessed yet through either
    calling it or calling get_concrete_function. Fails if we cannot build an
    initializer function which does not depend on the concrete values of the
    inputs to this function.

    Note that running this function will overwrite any values currently
    assigned to variables, for example restores from a checkpoint.

    Args:
      *args: arguments to the underlying python callable.
      **kwargs: keyword arguments to the python callable.

    Returns:
      A `ConcreteFunction` object which initializes the variables of this
      function.

    Raises:
      RuntimeError: if called after the variables have been initialized.
    """
    with self._lock:
      if self._stateful_fn is not None:
        raise RuntimeError(
            "get_initialization_function cannot be called after the function "
            "has been used")
      # Here we trace the function, collect the initializers, and attempt to
      # extract them and run them eagerly. Fail only if we cannot do so.
      initializers = []
      self._initialize(args, kwargs, add_initializers_to=initializers)

    # Note: using defun here avoids an infinite recursion.
    @function_lib.defun
    def initialize_variables():
      for v, init in initializers:
        v.assign(
            lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
            read_value=False)

    return initialize_variables.get_concrete_function()

  def _list_all_concrete_functions(self):
    """Returns all concrete functions."""
    if self.input_signature is not None:
      self.get_concrete_function()
    concrete_functions = []
    # pylint: disable=protected-access
    if self._stateful_fn:
      concrete_functions.extend(
          self._stateful_fn._list_all_concrete_functions())
    if self._stateless_fn:
      concrete_functions.extend(
          self._stateless_fn._list_all_concrete_functions())
    # pylint: enable=protected-access
    return concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    """Returns all concrete functions for serialization.

    Returns:
      A list of instances of `ConcreteFunction`.
    """
    concrete_functions = self._list_all_concrete_functions()
    seen_signatures = []
    for concrete_function in concrete_functions:
      signature = concrete_function.structured_input_signature
      flattened = nest.flatten(signature)
      if any(
          isinstance(arg, func_graph_module.UnknownArgument)
          for arg in flattened):
        logging.info("Unsupported signature for serialization: %s.", signature)
        continue
      equal_to_signature = functools.partial(
          function_spec_lib.is_same_structure, signature, check_values=True)
      if not any(equal_to_signature(s) for s in seen_signatures):
        seen_signatures.append(signature)

    # Re-create concrete functions for these signatures. Re-creating ensures
    # that if the cache key has changed, the function will be traced again.
    concrete_functions = []
    for args, kwargs in seen_signatures:
      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
    return concrete_functions

  def _trackable_children(self, save_type="checkpoint", **kwargs):
    """For implementing `Trackable`."""
    if save_type == "checkpoint":
      return {}
    return {f"trace_{n}": fn for n, fn in
            enumerate(self._list_all_concrete_functions_for_serialization())}

  def _deserialization_dependencies(self, children):
    """Returns concrete functions which must be loaded before this object."""
    return children

  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Unlike `get_concrete_function(...)`, the graph will be deleted when the
    returned function is deleted. It's useful to avoid creating a reference
    cycle when you know for sure that the graph will no longer be used without
    the returned function.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A TensorFlow function which takes exactly one `tf.Tensor` per argument.

    Raises:
      ValueError: if this object has not yet been called on concrete values.
    """
    with self._lock:
      if self._stateful_fn is None:
        initializers = []
        self._initialize(args, kwargs, add_initializers_to=initializers)
        self._initialize_uninitialized_variables(initializers)

    if self._created_variables:
      # In this case we have created variables on the first call, so we run
      # the defunned version which is guaranteed to never create variables.
      return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
          *args, **kwargs)
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we
      # can run the first trace but we should fail if variables are created.
      concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
          *args, **kwargs)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return concrete

  def get_concrete_function(self, *args, **kwargs):
    # Implements GenericFunction.get_concrete_function.
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
    concrete._garbage_collector.release()  # pylint: disable=protected-access
    return concrete

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `Function` was accessed
    # through e.g., for
    #
    # class Foo:
    #
    #   @function.defun
    #   def bar(self):
    #     ...
    #
    # foo = Foo()
    # foo.bar()  # `foo.bar` is a `Function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`). For composite
    # tensors, we can just treat `instance` as a normal parameter. But for
    # other types, we create a new instance of `Function` here to allow
    # different instances each to create variables once, thereby allowing
    # methods to be decorated with tf.function. Keeps a cache to avoid
    # retracing the function every time the descriptor is accessed.
    # TODO(mdan): Identify types which can just be parameters more generically.
    #
    # The check for instance._type_spec=None is used because certain classes
    # (including subclasses of tf.linalg.LinearOperator) are subclasses of
    # CompositeTensor but do not actually implement the required APIs.
    # TODO(b/199278478): Fix those classes, then remove the check for
    # `instance._type_spec is not None`.
    if (isinstance(instance, composite_tensor.CompositeTensor) and
        instance._type_spec is not None):  # pylint: disable=protected-access
      return types_lib.MethodType(self, instance)
    if instance not in self._descriptor_cache:
      if instance is None:
        return self
      # TODO(mdan): If the CompositeTensor path works, do the same here.
      # It's unclear whether we need the tf-decorator, or could just call
      # MethodType(self.clone(), instance)
      self._descriptor_cache[instance] = (
          function_lib.class_method_to_instance_method(self, instance))
    return self._descriptor_cache[instance]


@tf_export("function")
@deprecation.deprecated_args(None,
                             "experimental_compile is deprecated, use "
                             "jit_compile instead", "experimental_compile")
@deprecation.deprecated_args(None,
                             "experimental_relax_shapes is deprecated, use "
                             "reduce_retracing instead",
                             "experimental_relax_shapes")
def function(func=None,
             input_signature=None,
             autograph=True,
             jit_compile=None,
             reduce_retracing=False,
             experimental_implements=None,
             experimental_autograph_options=None,
             experimental_relax_shapes=None,
             experimental_compile=None,
             experimental_follow_type_hints=None) -> core.GenericFunction:
  """Compiles a function into a callable TensorFlow graph.

  `tf.function` constructs a `tf.types.experimental.GenericFunction` that
  executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the
  TensorFlow operations in `func`. More information on the topic can be found
  in [Introduction to Graphs and tf.function]
  (https://www.tensorflow.org/guide/intro_to_graphs).

  See [Better Performance with tf.function]
  (https://www.tensorflow.org/guide/function) for tips on performance and
  known limitations.

  Example usage:

  >>> @tf.function
  ... def f(x, y):
  ...   return x ** 2 + y
  >>> x = tf.constant([2, 3])
  >>> y = tf.constant([3, -2])
  >>> f(x, y)
  <tf.Tensor: ... numpy=array([7, 7], ...)>

  The trace-compilation allows non-TensorFlow operations to execute, but only
  under special conditions. In general, only TensorFlow operations are
  guaranteed to run and create fresh results whenever the `GenericFunction` is
  called.

  ## Features

  `func` may use data-dependent Python control flow statements, including
  `if`, `for`, `while`, `break`, `continue` and `return`:

  >>> @tf.function
  ... def f(x):
  ...   if tf.reduce_sum(x) > 0:
  ...     return x * x
  ...   else:
  ...     return -x // 2
  >>> f(tf.constant(-2))
  <tf.Tensor: ... numpy=1>

  `func`'s closure may include `tf.Tensor` and `tf.Variable` objects:

  >>> @tf.function
  ... def f():
  ...   return x ** 2 + y
  >>> x = tf.constant([-2, -3])
  >>> y = tf.Variable([3, -2])
  >>> f()
  <tf.Tensor: ... numpy=array([7, 7], ...)>

  `func` may also use ops with side effects, such as `tf.print`, `tf.Variable`
  and others:

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  >>> f(3)
  >>> v
  <tf.Variable ... numpy=4>

  Important: Any Python side-effects (appending to a list, printing with
  `print`, etc) will only happen once, when `func` is traced. To have
  side-effects executed by your `tf.function` they need to be written as TF
  ops:

  >>> l = []
  >>> @tf.function
  ... def f(x):
  ...   for i in x:
  ...     l.append(i + 1)    # Caution! Will only happen once when tracing
  >>> f(tf.constant([1, 2, 3]))
  >>> l
  [<tf.Tensor ...>]

  Instead, use TensorFlow collections like `tf.TensorArray`:

  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  ...   for i in range(len(x)):
  ...     ta = ta.write(i, x[i] + 1)
  ...   return ta.stack()
  >>> f(tf.constant([1, 2, 3]))
  <tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

  ## `tf.function` creates polymorphic callables

  Internally, `tf.types.experimental.GenericFunction` may contain multiple
  `tf.types.experimental.ConcreteFunction`s, each specialized to arguments
  with different data types or shapes, since TensorFlow can perform more
  optimizations on graphs of specific shapes, dtypes and values of constant
  arguments. `tf.function` treats any pure Python values as opaque objects
  (best thought of as compile-time constants), and builds a separate
  `tf.Graph` for each set of Python arguments that it encounters.
  For more information, see the
  [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing)

  Executing a `GenericFunction` will select and execute the appropriate
  `ConcreteFunction` based on the argument types and values.

  To obtain an individual `ConcreteFunction`, use the
  `GenericFunction.get_concrete_function` method. It can be called with the
  same arguments as `func` and returns a
  `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by
  a single `tf.Graph`:

  >>> @tf.function
  ... def f(x):
  ...   return x + 1
  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
  True

  `ConcreteFunction`s can be executed just like `GenericFunction`s, but their
  input is restricted to the types to which they're specialized.

  ## Retracing

  `ConcreteFunction`s are built (traced) on the fly, as the `GenericFunction`
  is called with new TensorFlow types or shapes, or with new Python values as
  arguments. When `GenericFunction` builds a new trace, it is said that `func`
  is retraced. Retracing is a frequent performance concern for `tf.function`
  as it can be considerably slower than executing a graph that's already been
  traced. It is ideal to minimize the amount of retracing in your code.

  Caution: Passing python scalars or lists as arguments to `tf.function` will
  usually retrace. To avoid this, pass numeric arguments as Tensors whenever
  possible:

  >>> @tf.function
  ... def f(x):
  ...   return tf.abs(x)
  >>> f1 = f.get_concrete_function(1)
  >>> f2 = f.get_concrete_function(2)  # Slow - compiles new graph
  >>> f1 is f2
  False
  >>> f1 = f.get_concrete_function(tf.constant(1))
  >>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
  >>> f1 is f2
  True

  Python numerical arguments should only be used when they take few distinct
  values, such as hyperparameters like the number of layers in a neural
  network.

  ## Input signatures

  For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for
  every unique set of input shapes and datatypes. The example below creates
  two separate `ConcreteFunction`s, each specialized to a different shape:

  >>> @tf.function
  ... def f(x):
  ...   return x + 1
  >>> vector = tf.constant([1.0, 1.0])
  >>> matrix = tf.constant([[3.0]])
  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
  False

  An "input signature" can be optionally provided to `tf.function` to control

  ## Input signatures

  For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for
  every unique set of input shapes and datatypes. The example below creates two
  separate `ConcreteFunction`s, each specialized to a different shape:

  >>> @tf.function
  ... def f(x):
  ...   return x + 1
  >>> vector = tf.constant([1.0, 1.0])
  >>> matrix = tf.constant([[3.0]])
  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
  False

  An "input signature" can be optionally provided to `tf.function` to control
  this process. The input signature specifies the shape and type of each
  Tensor argument to the function using a `tf.TensorSpec` object. More general
  shapes can be used. This ensures only one `ConcreteFunction` is created, and
  restricts the `GenericFunction` to the specified shapes and types. It is
  an effective way to limit retracing when Tensors have dynamic shapes.

  >>> @tf.function(
  ...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  ... def f(x):
  ...   return x + 1
  >>> vector = tf.constant([1.0, 1.0])
  >>> matrix = tf.constant([[3.0]])
  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
  True
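
  A partially-specified shape such as `[None]` works the same way: vectors of
  any length share a single trace. A minimal sketch:

  >>> @tf.function(
  ...     input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  ... def g(x):
  ...   return x + 1
  >>> f1 = g.get_concrete_function(tf.constant([1.0]))
  >>> f2 = g.get_concrete_function(tf.constant([1.0, 2.0, 3.0]))
  >>> f1 is f2
  True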

  ## Variables may only be created once

  `tf.function` only allows creating new `tf.Variable` objects when it is
  called for the first time:

  >>> class MyModule(tf.Module):
  ...   def __init__(self):
  ...     self.v = None
  ...
  ...   @tf.function
  ...   def __call__(self, x):
  ...     if self.v is None:
  ...       self.v = tf.Variable(tf.ones_like(x))
  ...     return self.v * x

  In general, it is recommended to create `tf.Variable`s outside of
  `tf.function`.
  In simple cases, persisting state across `tf.function` boundaries may be
  implemented using a pure functional style in which state is represented by
  `tf.Tensor`s passed as arguments and returned as return values.

  Contrast the two styles below:

  >>> state = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   state.assign_add(x)
  >>> f(tf.constant(2))  # Non-pure functional style
  >>> state
  <tf.Variable ... numpy=3>

  >>> state = tf.constant(1)
  >>> @tf.function
  ... def f(state, x):
  ...   state += x
  ...   return state
  >>> state = f(state, tf.constant(2))  # Pure functional style
  >>> state
  <tf.Tensor: ... numpy=3>

  ## Python operations execute only once per trace

  `func` may contain TensorFlow operations mixed with pure Python operations.
  However, when the function is executed, only the TensorFlow operations will
  run. The Python operations run only once, at trace time. If TensorFlow
  operations depend on results from Python operations, those results will be
  frozen into the graph.

  >>> @tf.function
  ... def f(a, b):
  ...   print('this runs at trace time; a is', a, 'and b is', b)
  ...   return b
  >>> f(1, tf.constant(1))
  this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
  <tf.Tensor: shape=(), dtype=int32, numpy=1>

  >>> f(1, tf.constant(2))
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  >>> f(2, tf.constant(1))
  this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
  <tf.Tensor: shape=(), dtype=int32, numpy=1>

  >>> f(2, tf.constant(2))
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  ## Using type annotations to improve performance

  `experimental_follow_type_hints` can be used along with type annotations to
  reduce retracing by automatically casting any Python values to `tf.Tensor`
  (something that is not done by default, unless you use input signatures).

  >>> @tf.function(experimental_follow_type_hints=True)
  ... def f_with_hints(x: tf.Tensor):
  ...   print('Tracing')
  ...   return x
  >>> @tf.function(experimental_follow_type_hints=False)
  ... def f_no_hints(x: tf.Tensor):
  ...   print('Tracing')
  ...   return x
  >>> f_no_hints(1)
  Tracing
  <tf.Tensor: shape=(), dtype=int32, numpy=1>
  >>> f_no_hints(2)
  Tracing
  <tf.Tensor: shape=(), dtype=int32, numpy=2>
  >>> f_with_hints(1)
  Tracing
  <tf.Tensor: shape=(), dtype=int32, numpy=1>
  >>> f_with_hints(2)
  <tf.Tensor: shape=(), dtype=int32, numpy=2>
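
  ## Compiling with XLA

  Passing `jit_compile=True` requests compilation of the whole function with
  [XLA](https://tensorflow.org/xla); see the `jit_compile` argument below for
  details. A minimal sketch, assuming XLA support is available on the current
  device:

  >>> @tf.function(jit_compile=True)
  ... def f(x):
  ...   return x * x + 1.0
  >>> f(tf.constant(2.0))
  <tf.Tensor: ... numpy=5.0>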

  Args:
    func: The function to be compiled. If `func` is None, `tf.function` returns
      a decorator that can be invoked with a single argument - `func`. In other
      words, `tf.function(input_signature=...)(func)` is equivalent to
      `tf.function(func, input_signature=...)`. The former can be used as a
      decorator.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects
      specifying the shapes and dtypes of the Tensors that will be supplied to
      this function. If `None`, a separate function is instantiated for each
      inferred input signature. If input_signature is specified, every input to
      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
    autograph: Whether autograph should be applied on `func` before tracing a
      graph. Data-dependent Python control flow statements require
      `autograph=True`. For more information, see the
      [tf.function and AutoGraph guide](
      https://www.tensorflow.org/guide/function#autograph_transformations).
    jit_compile: If `True`, compiles the function using
      [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
      such as fusion, and attempts to emit more efficient code. This may
      drastically improve performance. If set to `True`, the whole function
      needs to be compilable by XLA, or an `errors.InvalidArgumentError` is
      thrown.
      If `None` (default), compiles the function with XLA when running on TPU
      and goes through the regular function execution path when running on
      other devices.
      If `False`, executes the function without XLA compilation. Set this value
      to `False` when directly running a multi-device function on TPUs (e.g.
      two TPU cores, one TPU core and its host CPU).
      Not all functions are compilable; see a list of
      [sharp corners](https://tensorflow.org/xla/known_issues).
    reduce_retracing: When True, `tf.function` attempts to reduce the
      amount of retracing, for example by using more generic shapes. This
      can be controlled for user objects by customizing their associated
      `tf.types.experimental.TraceType`.
    experimental_implements: If provided, contains the name of a "known"
      function this implements, for example "mycompany.my_recurrent_cell".
      This is stored as an attribute in the inference function,
      which can then be detected when processing a serialized function.
      See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)  # pylint: disable=line-too-long
      for details. For an example of utilizing this attribute see this
      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
      The linked code automatically detects and substitutes functions that
      implement "embedded_matmul" and allows TFLite to substitute its own
      implementations. For instance, a TensorFlow user can use this
      attribute to mark that their function also implements
      `embedded_matmul` (perhaps more efficiently!)
      by specifying it using this parameter:
      `@tf.function(experimental_implements="embedded_matmul")`.
      This can either be specified as just the string name of the function or
      a NameAttrList corresponding to a list of key-value attributes associated
      with the function name. The name of the function will be in the 'name'
      field of the NameAttrList. To define a formal TF op corresponding to the
      function being implemented, try the experimental
      [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
      project.
    experimental_autograph_options: Optional tuple of
      `tf.autograph.experimental.Feature` values.
    experimental_relax_shapes: Deprecated. Use `reduce_retracing`
      instead.
    experimental_compile: Deprecated alias to 'jit_compile'.
    experimental_follow_type_hints: When True, the function may use type
      annotations from `func` to optimize the tracing performance. For example,
      arguments annotated with `tf.Tensor` will automatically be converted
      to a Tensor.

  Returns:
    If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
    If `func` is None, returns a decorator that, when invoked with a single
    `func` argument, returns a `tf.types.experimental.GenericFunction`.

  Raises:
    `ValueError` when attempting to use `jit_compile=True`, but XLA support is
    not available.
  """
  if experimental_follow_type_hints is None:
    experimental_follow_type_hints = False

  if jit_compile is None and JIT_COMPILE_FUNCTIONS:
    jit_compile = True

  # TODO(b/224808187): Remove after renaming usages.
  if experimental_relax_shapes:
    reduce_retracing = True

  def decorated(inner_function):
    try:
      name = inner_function.__name__
    except AttributeError:
      name = "function"
    return tf_decorator.make_decorator(
        inner_function,
        decorator_name="tf.function",
        decorator_func=Function(
            inner_function,
            name,
            input_signature=input_signature,
            autograph=autograph,
            experimental_autograph_options=experimental_autograph_options,
            reduce_retracing=reduce_retracing,

            # TODO(b/171825496): Update once `experimental_compile` is removed
            # entirely in favor of 'jit_compile'.
            jit_compile=deprecation.deprecated_argument_lookup(
                "jit_compile",
                jit_compile,
                "experimental_compile",
                experimental_compile),
            experimental_implements=experimental_implements,
            experimental_follow_type_hints=experimental_follow_type_hints))

  # This code path is for the `foo = tf.function(foo, ...)` use case
  if func is not None:
    return decorated(func)

  # This code path is for the
  #
  #   @tf.function(...)
  #   def foo(...):
  #     ...
  #
  # use case, which is equivalent to `foo = tf.function(...)(foo)`
  return decorated