# Source: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/transpiler.py
# (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic source code transformation infrastructure."""

17import inspect
18import threading
19import types
20
21import gast
22
23from tensorflow.python.autograph.pyct import cache
24from tensorflow.python.autograph.pyct import inspect_utils
25from tensorflow.python.autograph.pyct import loader
26from tensorflow.python.autograph.pyct import naming
27from tensorflow.python.autograph.pyct import origin_info
28from tensorflow.python.autograph.pyct import parser
29from tensorflow.python.autograph.pyct import templates
30from tensorflow.python.autograph.pyct import transformer
31from tensorflow.python.autograph.utils import ag_logging as logging
32
33
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  The AST is expected to define some symbol with a name given by `entity_name`.

  The goal is for the transformed entity to see exactly the same lexical
  scoping as the source entity, while still accepting extra parameters. Two
  nested factories achieve this:

   1. The inner factory dynamically creates the entity represented by `nodes`.
   2. The inner factory is parametrized by a custom set of arguments
      (`factory_args`).
   3. The inner factory has a closure identical to that of the transformed
      entity.
   4. The inner factory has local variables named after `factory_args`, which
      `nodes` may use as additional parameters.
   5. The inner factory returns the symbol named by `entity_name`.
   6. The outer factory is niladic (takes no arguments).
   7. The outer factory has no closure.
   8. The outer factory creates the lexical scope needed by the inner factory,
      so that the loaded code has the given configuration for closure/globals.
   9. The outer factory returns the inner factory.

  Roughly speaking, the following code is generated:

      from __future__ import future_feature_1, future_feature_2, ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scope comes from dummy assignments in the body of the outer
  factory. Each assignment makes a closure variable a local of the outer
  factory. As a result, the Python compiler treats its uses inside the inner
  factory as free, non-global variables (that is, it creates a cell slot for
  each symbol). The dummy values are `None` and are not expected to be used.
  Instead, the caller is expected to replace them with the cells of the source
  entity. For more details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate the
      code with. When it is empty (falsy), no `ImportFrom` node is built.

  Returns:
    The AST node(s) that `templates.replace` produces for the generated code.
  """
  # One `<name> = None` placeholder per closure variable; see the docstring for
  # why these placeholders are needed.
  dummy_def_template = """
      var_name = None
    """
  dummy_closure_defs = [
      node for closure_var in closure_vars
      for node in templates.replace(dummy_def_template, var_name=closure_var)
  ]

  future_imports = []
  if future_features:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=feature, asname=None)
               for feature in future_features],
        level=0)

  # Parameters of the inner factory, one per extra local.
  param_nodes = [
      gast.Name(arg_name, ctx=gast.Param(), annotation=None, type_comment=None)
      for arg_name in factory_args
  ]

  factory_template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      factory_template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=param_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)
141
142
143class _PythonFnFactory(object):
144  """Helper object that wraps a Python function factory."""
145
146  def __init__(self, name, freevars, extra_locals):
147    """Creates a new factory for a Python function.
148
149    Args:
150      name: The function name.
151      freevars: The list of non-global free variables for the function.
152      extra_locals: Dict[Text, Any], names and values for custom variables that
153        are accessible to the generated code as local variables.
154    """
155    self._name = name
156    self._freevars = freevars
157    self._extra_locals = extra_locals
158
159    self._unbound_factory = None
160    self.module = None
161    self.source_map = None
162
163  def create(self,
164             nodes,
165             namer,
166             inner_factory_name='inner_factory',
167             outer_factory_name='outer_factory',
168             future_features=()):
169    """Initializes a function."""
170    if self._unbound_factory is not None:
171      raise ValueError('double initialization; create a new object instead')
172
173    inner_factory_name = namer.new_symbol(inner_factory_name, ())
174    outer_factory_name = namer.new_symbol(outer_factory_name, ())
175    nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
176                               outer_factory_name, self._freevars,
177                               self._extra_locals.keys(), future_features)
178
179    module, _, source_map = loader.load_ast(
180        nodes, include_source_map=True)
181    outer_factory = getattr(module, outer_factory_name)
182    self._unbound_factory = outer_factory()
183    self.module = module
184    self.source_map = source_map
185
186  def instantiate(self,
187                  globals_,
188                  closure,
189                  defaults=None,
190                  kwdefaults=None):
191    """Creates a new function instance."""
192    if self._unbound_factory is None:
193      raise ValueError('call create first')
194
195    factory_code = self._unbound_factory.__code__
196    factory_freevars = factory_code.co_freevars
197    closure_map = dict(zip(self._freevars, closure))
198    factory_closure = tuple(
199        closure_map[name] for name in factory_code.co_freevars)
200    if len(factory_closure) != len(closure):
201      raise ValueError(
202          'closure mismatch, requested {}, but source function had {}'.format(
203              self._freevars, factory_freevars))
204
205    bound_factory = types.FunctionType(
206        code=factory_code,
207        globals=globals_,
208        name=self._name,
209        argdefs=(),
210        closure=factory_closure)
211
212    # The lint override is a false positive.
213    new_fn = bound_factory(**self._extra_locals)  # pylint:disable=not-callable
214
215    if defaults:
216      new_fn.__defaults__ = defaults
217    if kwdefaults:
218      new_fn.__kwdefaults__ = kwdefaults
219
220    return new_fn
221
222
class GenericTranspiler(object):
  """A generic transpiler for Python functions.

  The entry point is `transform`, which accepts Python function objects and
  takes care of parsing them internally.

  Subclasses usually customize `transform_ast`. By default, whatever
  `transform_ast` returns is handed back by `transform`, paired with the
  transformation context. Other methods, such as `transform_function`, may be
  overridden as well.

  Example:

      class MyTransformer(GenericTranspiler):

        def transform_ast(self, node, ctx):
          result = <<transform node>>
          return result

      transformer = MyTransformer()

      result = transformer.transform(f, ...)
      # result is the output
  """

  def get_transformed_name(self, node):
    """Returns a name for the output function. Subclasses may override this."""
    if isinstance(node, gast.FunctionDef):
      return node.name
    if isinstance(node, gast.Lambda):
      return 'lam'
    raise ValueError('Unknown node type {}'.format(node))

  def transform_ast(self, node, ctx):
    """Performs the actual transformation of a function's AST.

    Subclasses must implement this method. It is normally invoked by
    `transform_function` rather than called directly.

    Args:
      node: One or more ast.AST nodes representing the AST to be transformed.
      ctx: transformer.Context.
    """
    raise NotImplementedError('subclasses must override this')

  def transform(self, obj, user_context):
    """Transforms a Python object.

    This is the method users normally call.

    Args:
      obj: A Python object, function, type, etc.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      The result of calling transform_function.

    Raises:
      NotImplementedError: if obj is neither a function nor a method.
    """
    is_function_like = inspect.isfunction(obj) or inspect.ismethod(obj)
    if not is_function_like:
      raise NotImplementedError('Non-function: {}'.format(type(obj)))
    return self.transform_function(obj, user_context)

  def _erase_arg_defaults(self, node):
    """Erase arg default expressions, which would otherwise be unbound.

    Every positional and keyword-only default expression is replaced, in
    place, by a fresh `None` expression. Keyword-only arguments that have no
    default (`None` entries in `kw_defaults`) are left as they are.

    Returns:
      The same `node`, mutated.
    """
    arguments = node.args
    arguments.defaults[:] = [
        parser.parse_expression('None') for _ in arguments.defaults
    ]
    arguments.kw_defaults[:] = [
        None if default is None else parser.parse_expression('None')
        for default in arguments.kw_defaults
    ]
    return node

  def transform_module(self, mod, user_context):
    """Transforms a module.

    Subclasses may override this method. The return value is opaque.

    Each member of `mod` that is not itself a module is passed to
    `transform`. Submodules are not processed recursively. Members for which
    `transform` raises NotImplementedError are skipped. Note that this also
    silences a NotImplementedError raised from deeper inside the
    transformation.

    Args:
      mod: A Python module.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      List[Any], the `transform` result for each supported member, in the
      iteration order of `mod.__dict__`. With the default
      `transform_function`, each element is the output of transform_ast
      paired with a `transformer.Context`.
    """
    results = []
    for member in mod.__dict__.values():
      if not inspect.ismodule(member):
        try:
          results.append(self.transform(member, user_context))
        except NotImplementedError:
          pass  # Unsupported member; skip it.
    return results

  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The function is parsed into an AST, its argument defaults are erased, and
    the AST is handed to `transform_ast` together with a fresh
    `transformer.Context`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      Tuple[Any, Any]. By default, the output of transform_ast together with
      the `transformer.Context` describing the transformation.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    # Pick an output name that does not collide with `fn`'s namespace.
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    context = transformer.Context(
        transformer.EntityInfo(
            name=new_name,
            source_code=source,
            source_file='<fragment>',
            future_features=future_features,
            namespace=namespace),
        namer,
        user_context)

    node = self._erase_arg_defaults(node)
    return self.transform_ast(node, context), context
362
363
class PyToPy(GenericTranspiler):
  """A generic Python-to-Python transpiler.

  Its `transform` method offers a function-in, function-out interface.
  Internally, it takes care of parsing, caching and loading of the translated
  code.

  Users typically subclass this, overriding `transform_ast`.

  Instances of this class are usually singletons, because each instance owns
  its own cache. Caching can be controlled by overriding `get_caching_key`.

  Example:

      class MyTransformer(PyToPy):

        def transform_ast(self, node, ctx):
          node = <<transform node, usually using ast.NodeTransformer classes>>
          return node

      transformer = MyTransformer()

      new_f, module, source_map = transformer.transform_function(f, ...)
      # new_f is a function with signature identical to f

  The transformed function has access to the same namespace as the original
  function. To allow access to internal APIs, users may inject additional
  symbols by overriding `get_extra_locals`.
  """

  def __init__(self):
    # Serializes creation of new cache entries. Lookups first try a lock-free
    # check (see `transform_function`).
    self._cache_lock = threading.RLock()
    self._cache = cache.CodeObjectCache()

  def get_extra_locals(self):
    """Returns extra static local variables to make available to transformed code.

    Subclasses must override this.

    Returns:
      extra_locals: A Dict[Text, Any] containing additional variables to make
        available to the transformed code.
    """
    raise NotImplementedError('subclasses must override this')

  def get_caching_key(self, user_context):
    """Returns a unique key to use for caching.

    Subclasses must override this.

    Calls made to `transform_function` with functions that have the same code
    object and caching key will return a cached instance on subsequent
    invocations.

    Args:
      user_context: The context object which was passed to `transform`.

    Returns:
      A hashable value used as the cache subkey for `user_context`.
    """
    raise NotImplementedError('subclasses must override this')

  def _cached_factory(self, fn, cache_subkey):
    """Returns the factory stored for `fn` under `cache_subkey`, logging the hit."""
    factory = self._cache[fn][cache_subkey]
    logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
                factory)
    return factory

  def _convert_and_cache(self, fn, cache_subkey, user_context):
    """Transforms `fn`, loads the result and caches the resulting factory.

    Only called from `transform_function` while `self._cache_lock` is held.

    Returns:
      The newly created `_PythonFnFactory`.
    """
    logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
    # TODO(mdan): Confusing overloading pattern. Fix.
    nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
    entity_name = ctx.info.name

    if isinstance(nodes, gast.Lambda):
      # A lambda is an expression. Bind it to `entity_name` with an assignment
      # so that the factory can return it by name.
      target = gast.Name(
          entity_name, ctx=gast.Store(), annotation=None, type_comment=None)
      nodes = gast.Assign(targets=[target], value=nodes)
    else:
      nodes.name = entity_name

    if logging.has_verbosity(2):
      logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

    factory = _PythonFnFactory(
        entity_name, fn.__code__.co_freevars, self.get_extra_locals())
    factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
    self._cache[fn][cache_subkey] = factory
    return factory

  def transform_function(self, fn, user_context):
    """Transforms a function. See GenericTranspiler.transform_function.

    This overload wraps the parent's `transform_function`. It adds caching and
    the ability to instantiate the output as a Python object. It also makes
    new symbols available to the generated Python code as local variables;
    see `get_extra_locals`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      A tuple:
        * A function or lambda with the same signature and closure as `fn`
        * The temporary module into which the transformed function was loaded
        * The source map as a
            Dict[origin_info.LineLocation, origin_info.OriginInfo]
    """
    subkey = self.get_caching_key(user_context)

    if self._cache.has(fn, subkey):
      # Fast path: lock-free lookup of an existing entry.
      factory = self._cached_factory(fn, subkey)
    else:
      with self._cache_lock:
        # Another thread may have filled the entry while this one waited for
        # the lock, so look again before converting.
        if self._cache.has(fn, subkey):
          factory = self._cached_factory(fn, subkey)
        else:
          factory = self._convert_and_cache(fn, subkey, user_context)

    transformed_fn = factory.instantiate(
        globals_=fn.__globals__,
        closure=fn.__closure__ or (),
        defaults=fn.__defaults__,
        kwdefaults=getattr(fn, '__kwdefaults__', None))
    return transformed_fn, factory.module, factory.source_map
496