# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains the user- and codegen-facing API for AutoGraph."""

import functools
import importlib
# `import importlib` alone does not guarantee that these submodules are loaded;
# PyToTF.get_extra_locals uses both, so import them explicitly.
import importlib.machinery
import importlib.util
import inspect
import os
import sys
import textwrap
import traceback

from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import conditional_expressions
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import error_utils
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.python.util.tf_export import tf_export


def is_autograph_strict_conversion_mode():
  """Returns whether strict conversion mode is enabled.

  Strict mode is on when the `AUTOGRAPH_STRICT_CONVERSION` environment variable
  parses to a positive integer. In this module, strict mode makes
  `converted_call` re-raise conversion errors instead of falling back to
  calling the unconverted function.

  Returns:
    bool, True if `AUTOGRAPH_STRICT_CONVERSION` is set to an integer > 0.
  """
  return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0


#
# Error handling
#


# TODO(mdan): Export this symbol.
class AutoGraphError(errors.PyCTError):
  """Base class for all AutoGraph exceptions."""


class ConversionError(AutoGraphError):
  """Raised during the conversion process."""


class StagingError(AutoGraphError):
  """Raised during the staging (i.e. Python execution) of converted code."""


class _ErrorMetadata(error_utils.ErrorMetadataBase):
  """AutoGraph-specific error metadata. See base class."""

  def create_exception(self, source_error):
    """Builds the exception that replaces `source_error`.

    Args:
      source_error: Exception, the error that was originally raised.

    Returns:
      An exception carrying the message from `self.get_message()`. The type of
      `source_error` is preserved for the recognized OpError signature and for
      the explicitly listed error types; otherwise the base class is consulted
      and, if it returns None, a `StagingError` is returned.
    """
    preferred_type = type(source_error)
    if issubclass(preferred_type, errors_impl.OpError):
      # Best-effort unpacking of OpError exceptions.
      # TODO(mdan): Use a mechanism that is more future-proof.
      init_argspec = tf_inspect.getfullargspec(preferred_type.__init__)
      message = self.get_message()
      init_args = tuple(init_argspec.args)
      # At the time of this writing, TF errors either take 3 or 4 arguments,
      # the argument '*args' may or may not be used.
      if init_args == ('self', 'node_def', 'op', 'message'):
        return preferred_type(source_error.node_def, source_error.op, message,
                              source_error.experimental_payloads)

    elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError,
                            StagingError, errors_impl.InaccessibleTensorError,
                            errors_impl.OperatorNotAllowedInGraphError):
      return preferred_type(self.get_message())

    exc = super().create_exception(source_error)
    if exc is not None:
      return exc

    # Note: While changing an error's message property to change the message it
    # displays will probably work a lot of times, there is no standard way in
    # Python to do that. The safest way is therefore to create a new exception.
    # For user defined exceptions, we could define an interface that allowed
    # them to work under this mechanism.
    return StagingError(self.get_message())


def _attach_error_metadata(e, f):
  """Augments an error with the metadata necessary for rewrite.

  Must be called while `e` is being handled, because the traceback is read
  from `sys.exc_info()`.

  Args:
    e: Exception, the error being handled. Left untouched if it has an
      `ag_pass_through` attribute; otherwise its `ag_error_metadata` attribute
      is (re)assigned.
    f: The converted function that raised; its `ag_source_map` is recorded in
      the metadata.
  """
  if hasattr(e, 'ag_pass_through'):
    return

  metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map

  if metadata is None:
    logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)
  else:
    message = None

  # The first traceback entry is dropped.
  # NOTE(review): presumably it is the frame in this module that invoked the
  # converted function -- confirm.
  cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]

  e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map,
                                       __file__)


class StackTraceMapper(tf_stack.StackTraceMapper):
  """Remaps generated code to code it originated from."""

  def __init__(self, converted_fn):
    """Creates a mapper backed by `converted_fn.ag_source_map`."""
    super().__init__()
    self._source_map = converted_fn.ag_source_map
    # This may be called repeatedly: once on entry, by the superclass, then by
    # each child context manager.
    self._cached_map = None

  def get_effective_source_map(self):
    """Returns the source map of this mapper composed with its parent's.

    The result is computed once and cached on the instance.

    Returns:
      Dict mapping `(filename, lineno)` to
      `(origin filename, origin lineno, origin function name)`.
    """
    if self._cached_map is not None:
      return self._cached_map

    parent_map = self.parent.get_effective_source_map()

    effective_source_map = {}
    for loc, origin in self._source_map.items():
      effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename,
                                                          origin.loc.lineno,
                                                          origin.function_name)

    # Parent entries that point into code covered by this mapper's source map
    # are redirected to that code's origin; all others are copied through.
    for key, value in parent_map.items():
      filename, lineno, _ = value
      value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
      if value_loc in self._source_map:
        origin = self._source_map[value_loc]
        effective_source_map[key] = (origin.loc.filename, origin.loc.lineno,
                                     origin.function_name)
      else:
        effective_source_map[key] = value

    self._cached_map = effective_source_map
    return effective_source_map


#
# Actual source code transformation
#


class PyToTF(transpiler.PyToPy):
  """The TensorFlow AutoGraph transformer."""

  def __init__(self):
    super().__init__()
    # Lazily built by get_extra_locals.
    self._extra_locals = None

  def get_transformed_name(self, node):
    """Returns the base class' transformed name, prefixed with `tf__`."""
    return 'tf__' + super().get_transformed_name(node)

  def get_extra_locals(self):
    """Returns the extra symbols exposed to generated code.

    Returns:
      Dict with a single key, `ag__`, mapped to a synthetic module that
      aggregates this module's namespace, selected converter and utility
      symbols, and the contents of `special_functions` and `operators`. The
      dict is built on first use and cached.
    """
    if self._extra_locals is None:
      # TODO(mdan): Move into core or replace with an actual importable module.
      # Craft a module that exposes the external API as well as certain
      # internal modules.
      module_spec = importlib.machinery.ModuleSpec('autograph', None)
      ag_internal = importlib.util.module_from_spec(module_spec)
      ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
      ag_internal.ConversionOptions = converter.ConversionOptions
      ag_internal.STD = converter.STANDARD_OPTIONS
      ag_internal.Feature = converter.Feature
      ag_internal.utils = utils
      ag_internal.FunctionScope = function_wrappers.FunctionScope
      ag_internal.with_function_scope = function_wrappers.with_function_scope
      # TODO(mdan): Add safeguards against name clashes.
      # We don't want to create a submodule because we want the operators to be
      # accessible as ag__.<operator>
      ag_internal.__dict__.update(special_functions.__dict__)
      ag_internal.__dict__.update(operators.__dict__)

      self._extra_locals = {'ag__': ag_internal}
    return self._extra_locals

  def get_caching_key(self, ctx):
    """Returns the caching key for `ctx`: its conversion options."""
    return ctx.options

  def initial_analysis(self, node, ctx):
    """Runs the static analyses on `node` and returns the annotated node."""
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, graphs)
    anno.dup(
        node,
        {
            anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
        },
    )
    return node

  def transform_ast(self, node, ctx):
    """Applies the AutoGraph converters to `node`, in order.

    Args:
      node: The AST node to transform.
      ctx: The transformation context; `ctx.user.options` selects the optional
        converters (assert statements, lists).

    Returns:
      The transformed AST node.
    """
    unsupported_features_checker.verify(node)
    node = self.initial_analysis(node, ctx)

    node = functions.transform(node, ctx)
    node = directives.transform(node, ctx)
    node = break_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
      node = asserts.transform(node, ctx)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = continue_statements.transform(node, ctx)
    node = return_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.LISTS):
      node = lists.transform(node, ctx)
      node = slices.transform(node, ctx)
    node = call_trees.transform(node, ctx)
    node = control_flow.transform(node, ctx)
    node = conditional_expressions.transform(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = variables.transform(node, ctx)
    return node


def _convert_actual(entity, program_ctx):
  """Applies AutoGraph to entity.

  Args:
    entity: The function to convert; must expose a `__code__` attribute.
    program_ctx: converter.ProgramContext, passed to the transpiler.

  Returns:
    The transformed function, with `ag_module` and `ag_source_map` attributes
    attached.

  Raises:
    ValueError: If `entity` has no `__code__` attribute.
  """

  # TODO(mdan): Put these extra fields inside __autograph_info__.
  if not hasattr(entity, '__code__'):
    raise ValueError('Cannot apply autograph to a function that doesn\'t '
                     'expose a __code__ object. If this is a @tf.function,'
                     ' try passing f.python_function instead.')

  transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)

  assert not hasattr(transformed, 'ag_module')
  assert not hasattr(transformed, 'ag_source_map')
  transformed.ag_module = module
  transformed.ag_source_map = source_map
  return transformed


#
# Generated code support
#


def autograph_artifact(entity, extras=None):
  """Marks `entity` as an AutoGraph artifact.

  The mark is the `autograph_info__` attribute. For bound methods it is set on
  the underlying function (`entity.__func__`).

  Args:
    entity: The object to mark.
    extras: Optional value stored in the `autograph_info__` attribute.

  Returns:
    `entity`, for convenient chaining.
  """
  if inspect.ismethod(entity):
    setattr(entity.__func__, 'autograph_info__', extras)
  else:
    setattr(entity, 'autograph_info__', extras)
  return entity


def is_autograph_artifact(entity):
  """Returns True if `entity` carries the `autograph_info__` mark."""
  return hasattr(entity, 'autograph_info__')


def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
  """Converts a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.

  Raises:
    ValueError: If both `options` and `caller_fn_scope` are None.
  """
  logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
              kwargs)

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  if conversion.is_in_allowlist_cache(f, options):
    logging.log(2, 'Allowlisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if f is globals:
      return py_builtins.globals_in_original_context(caller_fn_scope)
    if f is locals:
      return py_builtins.locals_in_original_context(caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  if conversion.is_unsupported(f):
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_allowlisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # Bind target_entity up front: the except handler below logs it, and the
  # inspection calls may raise before any branch gets to assign it. Without
  # this, the handler itself would fail with UnboundLocalError and mask the
  # real error.
  target_entity = f
  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently allowed: %s: native binding', target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  try:
    program_ctx = converter.ProgramContext(options=options)
    converted_f = _convert_actual(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      # Errors raised by the converted code are annotated, then re-raised
      # unchanged.
      _attach_error_metadata(e, converted_f)
      raise

  return result


def _call_unconverted(f, args, kwargs, options, update_cache=True):
  """Calls the original function without converting with AutoGraph.

  Args:
    f: The callable to invoke.
    args: Tuple, positional arguments for `f`.
    kwargs: Optional[Dict], keyword arguments for `f`.
    options: converter.ConversionOptions, used when caching `f` as allowlisted.
    update_cache: bool, whether to record `f` in the allowlist cache.

  Returns:
    The result of calling `f`.
  """
  if update_cache:
    conversion.cache_allowlisted(f, options)

  if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
    return f.__self__.call(args, kwargs)

  if kwargs is not None:
    return f(*args, **kwargs)
  return f(*args)


def _fall_back_unconverted(f, args, kwargs, options, exc):
  """Falls back to calling the function unconverted, in case of error.

  A warning is logged first, except where the branches below deliberately stay
  silent.

  Args:
    f: The callable that could not be converted.
    args: Tuple, positional arguments for `f`.
    kwargs: Optional[Dict], keyword arguments for `f`.
    options: converter.ConversionOptions.
    exc: Exception, the error that prevented the conversion.

  Returns:
    The result of calling `f` unconverted.
  """
  # TODO(mdan): Consider adding an internal metric.
  warning_template = (
      'AutoGraph could not transform %s and will run it as-is.\n'
      '%s'
      'Cause: %s\n'
      'To silence this warning, decorate the function with'
      ' @tf.autograph.experimental.do_not_convert')
  if isinstance(exc, errors.InaccessibleSourceCodeError):
    # No warning is issued when ag_ctx.INSPECT_SOURCE_SUPPORTED is false.
    if ag_ctx.INSPECT_SOURCE_SUPPORTED:
      logging.warning(warning_template, f, '', exc)
  elif isinstance(exc, errors.UnsupportedLanguageElementError):
    # No warning is issued for functions already in the allowlist cache.
    if not conversion.is_in_allowlist_cache(f, options):
      logging.warning(warning_template, f, '', exc)
  else:
    file_bug_message = (
        'Please report this to the TensorFlow team. When filing the bug, set'
        ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
        ' attach the full output.\n')
    logging.warning(warning_template, f, file_bug_message, exc)

  return _call_unconverted(f, args, kwargs, options)


#
# TensorFlow integration
#


@tf_export('__internal__.autograph.tf_convert', v1=[])
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
  """Decorator that applies AutoGraph to a function.

  Use in internal APIs.

  This API is suitable for high order functions internal to the TensorFlow API,
  and more generally any function to which AutoGraph is not applied.

  Guidance: `convert` was a decorator meant for use directly by developers, but
  most of today's uses go through `tf.function`. `tf_convert` is to be called
  from high order functions internal to TF. By default, all the internal
  TensorFlow functions are skipped when AutoGraph processes the code. This may
  lead to user-supplied functions to be incorrectly skipped as well.
  `tf_convert` helps avoid that. See the following example for more details.

  ```
  =====tf_internal_module.py=====

  def unconverted(input_fn):
    return input_fn()

  def converted(input_fn):
    return tf.__internal__.autograph.tf_convert(
       input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

  ======user_module.py======

  @tf.function
  def foo(input_fn):
    return unconverted(input_fn)

  @tf.function
  def bar(input_fn):
    return converted(input_fn)

  @tf.function(autograph=False)
  def baz(input_fn):
    return converted(input_fn)
  ```

  The `foo` method above will execute the `input_fn` without autograph
  conversion, while the `bar` method will run an autographed `input_fn`. The
  `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
  the control status context.

  Note that both methods in `tf_internal_module` are skipped by autograph when
  tracing the `tf.function`. The configuration of whether a module/package
  should be skipped by autograph is controlled in
  tensorflow/python/autograph/core/config.py.

  Args:
    f: Callable.
    ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
    convert_by_default: bool, whether to use AutoGraph when the context doesn't
      specify.
    user_requested: bool, whether to ignore the conversion allowlist. See
      ConversionOptions.user_requested.

  Returns:
    Either `f` or the converted version of `f`.
  """

  if is_autograph_artifact(f):
    return f
  f_wrapper = f
  decorators, f = tf_decorator.unwrap(f)

  # TODO(mdan): Grab features from context.
  # Note: we pass the original context through to convert to properly handle the
  # following scenario, which can be used inside TF implementations:
  #
  #   ctx = ag_ctx.control_status_ctx()
  #   @function(autograph=False)  # Low-level graph code
  #   def inner_fn():
  #     # The context is disabled here, but should be enabled in user_fn
  #     tf_convert(user_fn, ctx=ctx)
  if ctx.status == ag_ctx.Status.ENABLED:
    wrapper_factory = convert(
        recursive=True, user_requested=user_requested, conversion_ctx=ctx)
  elif ctx.status == ag_ctx.Status.DISABLED:
    wrapper_factory = do_not_convert
  elif ctx.status == ag_ctx.Status.UNSPECIFIED:
    if convert_by_default:
      wrapper_factory = convert(
          recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    else:
      wrapper_factory = call_with_unspecified_conversion_status
  else:
    # Raised explicitly rather than via `assert False`, which is stripped under
    # `python -O` and would leave wrapper_factory unbound.
    raise AssertionError('This switch contains all possible cases!')
  wrapper = wrapper_factory(f)

  if decorators:
    wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)

  return autograph_artifact(wrapper)


def call_with_unspecified_conversion_status(func):
  """Decorator that resets the conversion context to the unspecified status.

  Args:
    func: Callable to wrap.

  Returns:
    A wrapper, marked as an AutoGraph artifact, that calls `func` inside a
    `ControlStatusCtx` with status `UNSPECIFIED`.
  """

  def wrapper(*args, **kwargs):
    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
      return func(*args, **kwargs)

  if inspect.isfunction(func) or inspect.ismethod(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)


def _log_callargs(f, args, kwargs):
  """Logging helper: logs the defaults of `f` and the bound call arguments."""
  logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
  logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)

  if kwargs is not None:
    callargs = tf_inspect.getcallargs(f, *args, **kwargs)
  else:
    callargs = tf_inspect.getcallargs(f, *args)

  formatted_callargs = '\n'.join(
      ' {}: {}'.format(k, v) for k, v in callargs.items())
  logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)


#
# Public API
#


@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
  """Decorator that suppresses the conversion of a function.

  Args:
    func: function to decorate.

  Returns:
    If `func` is not None, returns a `Callable` which is equivalent to
    `func`, but is not converted by AutoGraph.
    If `func` is None, returns a decorator that, when invoked with a
    single `func` argument, returns a `Callable` equivalent to the
    above case.
  """
  if func is None:
    return do_not_convert

  def wrapper(*args, **kwargs):
    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
      return func(*args, **kwargs)

  if inspect.isfunction(func) or inspect.ismethod(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)


# TODO(mdan): Make private.
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
  """Decorator that compiles a function to use TensorFlow ops.

  The decorator is dynamic - it recompiles the target whenever the decorated
  function is called. This means the parameter values are known at conversion.
  It also means that repeated calls with different types of parameters will be
  correctly processed.

  Args:
    recursive: bool, whether to recursively convert any functions or classes
      that the converted function may use.
    optional_features: converter.Feature, allows toggling optional or
      experimental features. When set to None, only the core features are
      enabled.
    user_requested: bool, whether this is a function that the user explicitly
      asked to be converted. See ConversionOptions.user_requested.
    conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
      which `f` is used.

  Returns:
    Callable, a decorator that converts the given function into an equivalent
    function that uses TensorFlow ops.
  """

  def decorator(f):
    """Decorator implementation."""

    def wrapper(*args, **kwargs):
      """Wrapper that calls the converted version of f."""
      options = converter.ConversionOptions(
          recursive=recursive,
          user_requested=user_requested,
          optional_features=optional_features)
      try:
        with conversion_ctx:
          return converted_call(f, args, kwargs, options=options)
      except Exception as e:  # pylint:disable=broad-except
        # Errors annotated by converted_call are re-raised as the exception
        # built from their metadata; all others propagate unchanged.
        if hasattr(e, 'ag_error_metadata'):
          raise e.ag_error_metadata.to_exception(e)
        else:
          raise

    if inspect.isfunction(f) or inspect.ismethod(f):
      wrapper = functools.update_wrapper(wrapper, f)

    decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
    return autograph_artifact(decorated_wrapper)

  return decorator


# pylint:disable=line-too-long
@tf_export('autograph.to_graph', v1=[])
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_f is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features))
    return autograph_artifact(_convert_actual(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    # Chain explicitly so the original error is recorded as the direct cause.
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e))) from e


@tf_export(v1=['autograph.to_graph'])
def to_graph_v1(entity,
                recursive=True,
                arg_values=None,
                arg_types=None,
                experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  del arg_types
  del arg_values
  return to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)


@tf_export(v1=['autograph.to_code'])
def to_code_v1(entity,
               recursive=True,
               arg_values=None,
               arg_types=None,
               indentation=' ',
               experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    indentation: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  del arg_values
  del arg_types
  del indentation
  return to_code(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)


@tf_export('autograph.to_code', v1=[])
def to_code(entity, recursive=True, experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  source = tf_inspect.getsource(
      to_graph(
          entity,
          recursive=recursive,
          experimental_optional_features=experimental_optional_features))
  return textwrap.dedent(source)


# Module-level transpiler instance used by _convert_actual.
_TRANSPILER = PyToTF()