xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/impl/api.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This module contains the user- and codegen-facing API for AutoGraph."""
16
17import functools
18import importlib
19import inspect
20import os
21import sys
22import textwrap
23import traceback
24
25from tensorflow.python.autograph import operators
26from tensorflow.python.autograph import utils
27from tensorflow.python.autograph.converters import asserts
28from tensorflow.python.autograph.converters import break_statements
29from tensorflow.python.autograph.converters import call_trees
30from tensorflow.python.autograph.converters import conditional_expressions
31from tensorflow.python.autograph.converters import continue_statements
32from tensorflow.python.autograph.converters import control_flow
33from tensorflow.python.autograph.converters import directives
34from tensorflow.python.autograph.converters import functions
35from tensorflow.python.autograph.converters import lists
36from tensorflow.python.autograph.converters import logical_expressions
37from tensorflow.python.autograph.converters import return_statements
38from tensorflow.python.autograph.converters import slices
39from tensorflow.python.autograph.converters import variables
40from tensorflow.python.autograph.core import ag_ctx
41from tensorflow.python.autograph.core import converter
42from tensorflow.python.autograph.core import function_wrappers
43from tensorflow.python.autograph.core import unsupported_features_checker
44from tensorflow.python.autograph.impl import conversion
45from tensorflow.python.autograph.lang import special_functions
46from tensorflow.python.autograph.operators import py_builtins
47from tensorflow.python.autograph.pyct import anno
48from tensorflow.python.autograph.pyct import cfg
49from tensorflow.python.autograph.pyct import error_utils
50from tensorflow.python.autograph.pyct import errors
51from tensorflow.python.autograph.pyct import inspect_utils
52from tensorflow.python.autograph.pyct import origin_info
53from tensorflow.python.autograph.pyct import qual_names
54from tensorflow.python.autograph.pyct import transpiler
55from tensorflow.python.autograph.pyct.static_analysis import activity
56from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
57from tensorflow.python.autograph.utils import ag_logging as logging
58from tensorflow.python.eager import function
59from tensorflow.python.framework import errors_impl
60from tensorflow.python.util import tf_decorator
61from tensorflow.python.util import tf_inspect
62from tensorflow.python.util import tf_stack
63from tensorflow.python.util.tf_export import tf_export
64
65
def is_autograph_strict_conversion_mode():
  """Returns True when the AUTOGRAPH_STRICT_CONVERSION env var is positive."""
  flag = os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')
  return int(flag) > 0
68
69
70#
71# Error handling
72#
73
74
75# TODO(mdan): Export this symbol.
class AutoGraphError(errors.PyCTError):
  """Base class for all AutoGraph exceptions."""
79
80
class ConversionError(AutoGraphError):
  """Raised during the conversion process."""
84
85
class StagingError(AutoGraphError):
  """Raised during the staging (i.e. Python execution) of converted code."""
89
90
class _ErrorMetadata(error_utils.ErrorMetadataBase):
  """AutoGraph-specific error metadata. See base class."""

  def create_exception(self, source_error):
    """Builds a replacement exception carrying this metadata's message.

    Preserves the original exception's type whenever it can be reconstructed
    from a message alone (or, for OpErrors, from a recognized constructor
    signature); otherwise falls back to a generic StagingError.

    Args:
      source_error: Exception, the error caught while running converted code.

    Returns:
      Exception, the error to raise in place of `source_error`.
    """
    preferred_type = type(source_error)
    if issubclass(preferred_type, errors_impl.OpError):
      # Best-effort unpacking of OpError exceptions.
      # TODO(mdan): Use a mechanism that is more future-proof.
      init_argspec = tf_inspect.getfullargspec(preferred_type.__init__)
      message = self.get_message()
      init_args = tuple(init_argspec.args)
      # At the time of this writing, TF errors either take 3 or 4 arguments,
      # the argument '*args' may or may not be used.
      if init_args == ('self', 'node_def', 'op', 'message'):
        return preferred_type(source_error.node_def, source_error.op, message,
                              source_error.experimental_payloads)
      # OpErrors with an unrecognized constructor signature fall through to
      # the generic handling below.

    elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError,
                            StagingError, errors_impl.InaccessibleTensorError,
                            errors_impl.OperatorNotAllowedInGraphError):
      # These types are known to accept a single message argument.
      return preferred_type(self.get_message())

    exc = super(_ErrorMetadata, self).create_exception(source_error)
    if exc is not None:
      return exc

    # Note: While changing an error's message property to change the message it
    # displays will probably work a lot of times, there is no standard way in
    # Python to do that. The safest way is therefore to create a new exception.
    # For user defined exceptions, we could define an interface that allowed
    # them to work under this mechanism.
    return StagingError(self.get_message())
123
124
def _attach_error_metadata(e, f):
  """Augments an error with the metadata necessary for rewrite."""
  # Errors explicitly flagged as pass-through are left untouched.
  if hasattr(e, 'ag_pass_through'):
    return

  existing_metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map

  if existing_metadata is not None:
    message = None
  else:
    logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)

  # Skip the first traceback frame (presumably the conversion machinery's own
  # frame — TODO confirm).
  cause_frames = traceback.extract_tb(sys.exc_info()[2])[1:]

  e.ag_error_metadata = _ErrorMetadata(
      cause_frames, existing_metadata, message, source_map, __file__)
143
144
class StackTraceMapper(tf_stack.StackTraceMapper):
  """Remaps generated code to code it originated from."""

  def __init__(self, converted_fn):
    super().__init__()
    self._source_map = converted_fn.ag_source_map
    # Cached because this may be called repeatedly: once on entry, by the
    # superclass, then by each child context manager.
    self._cached_map = None

  def get_effective_source_map(self):
    """Returns {generated (filename, lineno): original (file, line, fn)}."""
    if self._cached_map is not None:
      return self._cached_map

    parent_map = self.parent.get_effective_source_map()

    # Start with this function's own generated-to-original mapping.
    result = {
        (loc.filename, loc.lineno): (origin.loc.filename, origin.loc.lineno,
                                     origin.function_name)
        for loc, origin in self._source_map.items()
    }

    # Overlay the parent's entries; entries that point into code this mapper
    # knows about are remapped a second time, others pass through unchanged.
    for key, value in parent_map.items():
      filename, lineno, _ = value
      value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
      if value_loc in self._source_map:
        origin = self._source_map[value_loc]
        result[key] = (origin.loc.filename, origin.loc.lineno,
                       origin.function_name)
      else:
        result[key] = value

    self._cached_map = result
    return result
179
180
181#
182# Actual source code transformation
183#
184
185
class PyToTF(transpiler.PyToPy):
  """The TensorFlow AutoGraph transformer."""

  def __init__(self):
    super(PyToTF, self).__init__()
    # Lazily-built namespace exposed to generated code as 'ag__'; see
    # get_extra_locals.
    self._extra_locals = None

  def get_transformed_name(self, node):
    """Prefixes the converted function's name with 'tf__'."""
    return 'tf__' + super(PyToTF, self).get_transformed_name(node)

  def get_extra_locals(self):
    """Returns the extra names made visible to generated code."""
    if self._extra_locals is None:
      # TODO(mdan): Move into core or replace with an actual importable module.
      # Craft a module that exposes the external API as well as certain
      # internal modules.
      module_spec = importlib.machinery.ModuleSpec('autograph', None)
      ag_internal = importlib.util.module_from_spec(module_spec)
      ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
      ag_internal.ConversionOptions = converter.ConversionOptions
      ag_internal.STD = converter.STANDARD_OPTIONS
      ag_internal.Feature = converter.Feature
      ag_internal.utils = utils
      ag_internal.FunctionScope = function_wrappers.FunctionScope
      ag_internal.with_function_scope = function_wrappers.with_function_scope
      # TODO(mdan): Add safeguards against name clashes.
      # We don't want to create a submodule because we want the operators to be
      # accessible as ag__.<operator>
      ag_internal.__dict__.update(special_functions.__dict__)
      ag_internal.__dict__.update(operators.__dict__)

      self._extra_locals = {'ag__': ag_internal}
    return self._extra_locals

  def get_caching_key(self, ctx):
    # Conversion results are cached per set of conversion options.
    return ctx.options

  def initial_analysis(self, node, ctx):
    """Runs the static analyses that the converters below rely on."""
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, graphs)
    # Snapshot the DEFINITIONS annotations under ORIG_DEFINITIONS before any
    # transformation runs.
    anno.dup(
        node,
        {
            anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
        },
    )
    return node

  def transform_ast(self, node, ctx):
    """Applies the AutoGraph converter passes, in a fixed order.

    Note: the ordering below is load-bearing in places (see e.g. the note on
    continue canonicalization).
    """
    unsupported_features_checker.verify(node)
    node = self.initial_analysis(node, ctx)

    node = functions.transform(node, ctx)
    node = directives.transform(node, ctx)
    node = break_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
      node = asserts.transform(node, ctx)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = continue_statements.transform(node, ctx)
    node = return_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.LISTS):
      node = lists.transform(node, ctx)
      node = slices.transform(node, ctx)
    node = call_trees.transform(node, ctx)
    node = control_flow.transform(node, ctx)
    node = conditional_expressions.transform(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = variables.transform(node, ctx)
    return node
258
259
def _convert_actual(entity, program_ctx):
  """Applies AutoGraph to entity."""

  # TODO(mdan): Put these extra fields inside __autograph_info__.
  if not hasattr(entity, '__code__'):
    raise ValueError('Cannot apply autograph to a function that doesn\'t '
                     'expose a __code__ object. If this is a @tf.function,'
                     ' try passing f.python_function instead.')

  transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)

  # The transformed function must not already carry AutoGraph bookkeeping.
  for attr in ('ag_module', 'ag_source_map'):
    assert not hasattr(transformed, attr)
  transformed.ag_module = module
  transformed.ag_source_map = source_map
  return transformed
276
277
278#
279# Generated code support
280#
281
282
def autograph_artifact(entity, extras=None):
  """Marks `entity` as AutoGraph-generated, attaching optional extras."""
  # Bound methods can't take new attributes; tag the underlying function.
  target = entity.__func__ if inspect.ismethod(entity) else entity
  setattr(target, 'autograph_info__', extras)
  return entity
289
290
def is_autograph_artifact(entity):
  """Checks whether `entity` carries the AutoGraph artifact marker."""
  try:
    entity.autograph_info__
  except AttributeError:
    return False
  return True
293
294
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
  """Converts a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.
  """
  logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,
              kwargs)

  # Conversion options default to those of the calling scope.
  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  # Fast paths: calls previously determined (or contextually known) to run
  # unconverted.
  if conversion.is_in_allowlist_cache(f, options):
    logging.log(2, 'Allowlisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  # Builtins that are sensitive to their calling context (eval, super, globals,
  # locals) are redirected to AutoGraph's context-aware overloads; all other
  # builtins go through overload_of.
  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if f is globals:
      return py_builtins.globals_in_original_context(caller_fn_scope)
    if f is locals:
      return py_builtins.locals_in_original_context(caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  if conversion.is_unsupported(f):
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_allowlisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # Normalize the callable into (target_entity, effective_args): plain
  # functions/methods are used as-is (with bound `self` re-injected as the
  # first argument); callable objects are converted via their class's __call__.
  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently allowed: %s: native binding', target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  # Perform the actual conversion; on failure, fall back to the unconverted
  # call unless strict mode is on.
  try:
    program_ctx = converter.ProgramContext(options=options)
    converted_f = _convert_actual(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  # Run the converted function, remapping tracebacks of any error back to the
  # original user code.
  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      _attach_error_metadata(e, converted_f)
      raise

  return result
447
448
449def _call_unconverted(f, args, kwargs, options, update_cache=True):
450  """Calls the original function without converting with AutoGraph."""
451  if update_cache:
452    conversion.cache_allowlisted(f, options)
453
454  if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
455    return f.__self__.call(args, kwargs)
456
457  if kwargs is not None:
458    return f(*args, **kwargs)
459  return f(*args)
460
461
def _fall_back_unconverted(f, args, kwargs, options, exc):
  """Falls back to calling the function unconverted, in case of error."""
  # TODO(mdan): Consider adding an internal metric.
  warning_template = (
      'AutoGraph could not transform %s and will run it as-is.\n'
      '%s'
      'Cause: %s\n'
      'To silence this warning, decorate the function with'
      ' @tf.autograph.experimental.do_not_convert')

  extra_message = ''
  if isinstance(exc, errors.InaccessibleSourceCodeError):
    # Only warn when source inspection was expected to work.
    should_warn = ag_ctx.INSPECT_SOURCE_SUPPORTED
  elif isinstance(exc, errors.UnsupportedLanguageElementError):
    # Warn just once per allowlist-cached function.
    should_warn = not conversion.is_in_allowlist_cache(f, options)
  else:
    should_warn = True
    extra_message = (
        'Please report this to the TensorFlow team. When filing the bug, set'
        ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
        ' attach the full output.\n')

  if should_warn:
    logging.warning(warning_template, f, extra_message, exc)

  return _call_unconverted(f, args, kwargs, options)
485
486
487#
488# TensorFlow integration
489#
490
491
@tf_export('__internal__.autograph.tf_convert', v1=[])
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
  """Decorator that applies AutoGraph to a function.

  Use in internal APIs.

  This API is suitable for high order functions internal to the TensorFlow API,
  and more generally any function to which AutoGraph is not applied.

  Guidance: `convert` was a decorator meant for use directly by developers, but
  most of today's uses go through `tf.function`. `tf_convert` is to be called
  from high order functions internal to TF. By default, all the internal
  TensorFlow functions are skipped when AutoGraph processes the code. This may
  lead to user-supplied functions being incorrectly skipped as well.
  `tf_convert` helps avoid that. See the following example for more details.

  ```
  =====tf_internal_module.py=====

  def unconverted(input_fn):
    return input_fn()

  def converted(input_fn):
    return tf.__internal__.autograph.tf_convert(
       input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

  ======user_module.py======

  @tf.function
  def foo(input_fn):
    return unconverted(input_fn)

  @tf.function
  def bar(input_fn):
    return converted(input_fn)

  @tf.function(autograph=False)
  def baz(input_fn):
    return converted(input_fn)
  ```

  The `foo` method above will execute the `input_fn` without autograph
  conversion, while the `bar` method will run an autographed `input_fn`. The
  `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
  the control status context.

  Note that both methods in `tf_internal_module` are skipped by autograph when
  tracing the `tf.function`. The configuration of whether a module/package
  should be skipped by autograph is controlled in
  tensorflow/python/autograph/core/config.py.

  Args:
    f: Callable.
    ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
    convert_by_default: bool, whether to use AutoGraph when the context doesn't
      specify.
    user_requested: bool, whether to ignore the conversion allowlist. See
      ConversionOptions.user_requested.

  Returns:
    Either `f` or the converted version of `f`.
  """

  # Already-processed functions are returned as-is.
  if is_autograph_artifact(f):
    return f
  f_wrapper = f
  decorators, f = tf_decorator.unwrap(f)

  # TODO(mdan): Grab features from context.
  # Note: we pass the original context through to convert to properly handle the
  # following scenario, which can be used inside TF implementations:
  #
  #   ctx = ag_ctx.control_status_ctx()
  #   @function(autograph=False)  # Low-level graph code
  #   def inner_fn():
  #     # The context is disabled here, but should be enabled in user_fn
  #     tf_convert(user_fn, ctx=ctx)
  # Pick the wrapping strategy based on the context's conversion status.
  if ctx.status == ag_ctx.Status.ENABLED:
    wrapper_factory = convert(
        recursive=True, user_requested=user_requested, conversion_ctx=ctx)
  elif ctx.status == ag_ctx.Status.DISABLED:
    wrapper_factory = do_not_convert
  elif ctx.status == ag_ctx.Status.UNSPECIFIED:
    if convert_by_default:
      wrapper_factory = convert(
          recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    else:
      wrapper_factory = call_with_unspecified_conversion_status
  else:
    assert False, 'This switch contains all possible cases!'
  wrapper = wrapper_factory(f)

  # Reapply any decorators that were stripped above.
  if decorators:
    wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)

  return autograph_artifact(wrapper)
588
589
def call_with_unspecified_conversion_status(func):
  """Decorator that resets the conversion context to the unspecified status."""

  def wrapper(*args, **kwargs):
    unspecified_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
    with unspecified_ctx:
      return func(*args, **kwargs)

  # Only plain functions/methods can have their metadata copied over.
  if inspect.isfunction(func) or inspect.ismethod(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)
601
602
def _log_callargs(f, args, kwargs):
  """Logging helper."""
  logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
  logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)

  if kwargs is None:
    callargs = tf_inspect.getcallargs(f, *args)
  else:
    callargs = tf_inspect.getcallargs(f, *args, **kwargs)

  formatted = '\n'.join(
      '    {}: {}'.format(name, value) for name, value in callargs.items())
  logging.log(2, 'Calling %s with\n%s\n', f, formatted)
616
617
618#
619# Public API
620#
621
622
@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
  """Decorator that suppresses the conversion of a function.

  Args:
    func: function to decorate.

  Returns:
    If `func` is not None, returns a `Callable` which is equivalent to
    `func`, but is not converted by AutoGraph.
    If `func` is None, returns a decorator that, when invoked with a
    single `func` argument, returns a `Callable` equivalent to the
    above case.
  """
  # Supports both bare @do_not_convert and @do_not_convert() usage.
  if func is None:
    return do_not_convert

  def wrapper(*args, **kwargs):
    disabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
    with disabled_ctx:
      return func(*args, **kwargs)

  # Only plain functions/methods can have their metadata copied over.
  if inspect.isfunction(func) or inspect.ismethod(func):
    wrapper = functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)
648
649
650# TODO(mdan): Make private.
# TODO(mdan): Make private.
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
  """Decorator that compiles a function to use TensorFlow ops.

  The decorator is dynamic - it recompiles the target whenever the decorated
  function is called. This means the parameter values are known at conversion.
  It also means that repeated calls with different types of parameters will be
  correctly processed.

  Args:
    recursive: bool, whether to recursively convert any functions or classes
      that the converted function may use.
    optional_features: converted.Feature, allows toggling optional or
      experimental features. When set to None, only the core features are
      enabled.
    user_requested: bool, whether this is a function that the user explicitly
      asked to be converted. See ConversionOptions.user_requested.
    conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
      which `f` is used. Note: the default NullCtx instance is created once, at
      function definition time, and shared across calls.

  Returns:
    Callable, a decorator that converts the given function into an equivalent
    function that uses TensorFlow ops.
  """

  def decorator(f):
    """Decorator implementation."""

    def wrapper(*args, **kwargs):
      """Wrapper that calls the converted version of f."""
      # Options are rebuilt per call; caching of the conversion itself is
      # handled downstream (see converted_call).
      options = converter.ConversionOptions(
          recursive=recursive,
          user_requested=user_requested,
          optional_features=optional_features)
      try:
        with conversion_ctx:
          return converted_call(f, args, kwargs, options=options)
      except Exception as e:  # pylint:disable=broad-except
        # Errors carrying AutoGraph metadata are re-raised as rewritten
        # exceptions that point at the original user code.
        if hasattr(e, 'ag_error_metadata'):
          raise e.ag_error_metadata.to_exception(e)
        else:
          raise

    # Only plain functions/methods can have their metadata copied over.
    if inspect.isfunction(f) or inspect.ismethod(f):
      wrapper = functools.update_wrapper(wrapper, f)

    decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
    return autograph_artifact(decorated_wrapper)

  return decorator
703
704
705# pylint:disable=line-too-long
@tf_export('autograph.to_graph', v1=[])
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_foo is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features))
    return autograph_artifact(_convert_actual(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    # Chain explicitly so the original error is reported as the direct cause
    # rather than as an error raised "during handling" of the conversion.
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e))) from e
776
777
@tf_export(v1=['autograph.to_graph'])
def to_graph_v1(entity,
                recursive=True,
                arg_values=None,
                arg_types=None,
                experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound function that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  # The deprecated arguments are accepted for API compatibility but ignored.
  del arg_values, arg_types
  return to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
847
848
@tf_export(v1=['autograph.to_code'])
def to_code_v1(entity,
               recursive=True,
               arg_values=None,
               arg_types=None,
               indentation='  ',
               experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    indentation: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  # Deprecated arguments are kept in the signature for compatibility and
  # discarded here.
  del arg_values, arg_types, indentation
  code = to_code(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
  return code
901
902
@tf_export('autograph.to_code', v1=[])
def to_code(entity, recursive=True, experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.
  """
  # Convert first, then read back the generated function's source text.
  converted = to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
  # Dedent so the returned snippet starts at column zero regardless of the
  # nesting level the generated code was emitted at.
  return textwrap.dedent(tf_inspect.getsource(converted))
946
947
# Single shared PyToTF transpiler instance, constructed once at module import
# time. NOTE(review): presumably reused by this module's conversion entry
# points to avoid rebuilding the transpiler per call — confirm against the
# PyToTF definition earlier in this file.
_TRANSPILER = PyToTF()
949