xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/transformer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A node transformer that includes utilities for SCT."""
16
17import collections
18import enum
19
20import gast
21
22from tensorflow.python.autograph.pyct import anno
23from tensorflow.python.autograph.pyct import parser
24from tensorflow.python.autograph.pyct import pretty_printer
25from tensorflow.python.autograph.pyct import templates
26
27
class AnalysisLevel(enum.IntEnum):
  """Integer-valued enumeration of analysis levels.

  Members are numbered consecutively from 0 in declaration order, so, being
  an IntEnum, they compare as integers: NONE < ACTIVITY < DEFINEDNESS <
  LIVENESS.
  """

  # A single unpacking assignment numbers the members 0..3 in order.
  NONE, ACTIVITY, DEFINEDNESS, LIVENESS = range(4)
34
35
36# TODO(znado): Use namedtuple.
class Context(object):
  """Mutable record describing a source code transformation in progress.

  The object is updated while conversion runs and is not thread safe.

  Attributes:
    info: EntityInfo, immutable.
    namer: naming.Namer.
    current_origin: origin_info.OriginInfo of the last AST node that was
      processed successfully; None until a node has been processed. Useful
      for error handling.
    user: A user-supplied context object. It is opaque to the infrastructure
      and is passed through to all custom transformations.
  """

  def __init__(self, info, namer, user_context):
    self.info, self.namer = info, namer
    # No node has been processed yet, so there is no origin to report.
    self.current_origin, self.user = None, user_context
56
57
58# TODO(mdan): Move to a standalone file.
class EntityInfo(
    collections.namedtuple(
        'EntityInfo',
        'name source_code source_file future_features namespace')):
  """Immutable description of a Python entity, e.g. a function or a class.

  Attributes:
    name: The name that identifies this entity.
    source_code: The entity's source code.
    source_file: The entity's source file.
    future_features: Tuple[Text], the future features that this entity was
      compiled with. See
      https://docs.python.org/2/reference/simple_stmts.html#future.
    namespace: Dict[str, ], containing symbols visible to the entity (excluding
      parameters).
  """
81
82
83class _StateStack(object):
84  """Templated context manager.
85
86  This class provides syntactic sugar for a stack of objects of known
87  type. It allows accessing attributes of the object at the top of the stack
88  directly against this object, which allows for very terse syntax.
89
90  For example, this code:
91
92    stack = _StateStack(Foo)
93    stack.enter()
94    stack.bar
95
96  Is equivalent to:
97
98    stack = []
99    stack.append(Foo())
100    foo = stack[-1]
101    foo.bar
102
103  See _State for more on how this is used.
104
105  Attributes:
106    type: Any, the type of objects that this stack holds
107    level: int, the current stack depth
108    stack: List[Any], the actual stack
109    value: Any, the instance of the object at the top of the stack
110  """
111
112  def __init__(self, type_):
113    # Because we override __setattr__, we need to attach these attributes using
114    # the superclass' setattr.
115    object.__setattr__(self, 'type', type_)
116    object.__setattr__(self, '_stack', [])
117    if not hasattr(type_, 'no_root'):
118      self.enter()
119
120  def __enter__(self):
121    self.enter()
122    return self
123
124  def __exit__(self, exc_type, exc_value, traceback):
125    self.exit()
126
127  def enter(self):
128    self._stack.append(self.type())
129
130  def exit(self):
131    self._stack.pop()
132
133  @property
134  def stack(self):
135    return self._stack
136
137  @property
138  def level(self):
139    return len(self._stack)
140
141  @property
142  def value(self):
143    return self._stack[-1]
144
145  def __iter__(self):
146    return iter(self._stack)
147
148  def __getattr__(self, key):
149    return getattr(self._stack[-1], key)
150
151  def __setattr__(self, key, value):
152    setattr(self._stack[-1], key, value)
153
154
155class _State(object):
156  """Syntactic sugar for accessing an instance of a StateStack context manager.
157
158  This structure offers syntactic sugar over a dict of stacks of objects
159  of known type. These structures are useful to keep state during AST walks.
160  Multiple different scopes can be tracked in parallel. For example:
161
162    s = _State()
163
164    s[foo].enter()
165    s[bar].enter()  # this will not affect s[foo]
166
167  Element access has special semantics:
168    * keys are a data type
169    * element values are _StateStack(type=key) objects
170    * missing elements are automatically added, similarly to defaultdict
171
172  For example, the following block :
173
174    _State s
175    s[Foo]
176
177  Is equivalent to:
178
179    s = {}
180    if Foo not in s:
181      s[Foo] = Foo()
182    s[Foo]
183
184  See Base for how it's used.
185  """
186
187  def __init__(self):
188    self._value = {}
189
190  def __getitem__(self, key):
191    if key not in self._value:
192      self._value[key] = _StateStack(key)
193    return self._value[key]
194
195
class NodeStateTracker(object):
  """Mixin with state tracking and block helpers for AST visitors.

  The class does not define `visit` itself; it is meant to be combined with a
  node visitor or node transformer that does (see the example below).

  State can be kept across calls to the visit_* methods, local to arbitrary
  nodes and their descendants, using the `self.state` attribute. Multiple
  independent scopes are allowed and are constructed automatically on first
  access.

  For example, to keep track of the `If` node that encloses any `Name` node,
  one can write:

  ```
    class FooType(object):

      def __init__(self):
        self.foo_property = None

    class DummyTransformer(NodeStateTracker, ast.NodeTransformer):

      def visit_If(self, node):
        self.state[FooType].enter()
        self.state[FooType].foo_property = node
        node = self.generic_visit(node)
        self.state[FooType].exit()
        return node

      def visit_Name(self, node):
        self.state[FooType].foo_property  # will hold the innermost enclosing if
  ```

  Alternatively, the `enter()`/`exit()` calls can be managed by a `with`
  statement:

  ```
      def visit_If(self, node):
        with self.state[FooType] as foo:
          foo.foo_property = node
          return self.generic_visit(node)
  ```
  """

  # TODO(mdan): Document all extra features.

  def __init__(self, ctx):
    """Initialize the transformer.

    Subclasses should call this.

    Args:
      ctx: A Context object.
    """
    self._lineno = self._col_offset = 0
    self.ctx = ctx

    # Scoped storage that survives across visit_* calls. Each key owns an
    # independent stack of scopes; a scope pushed while visiting a node is the
    # one seen by that node's descendants until it is popped again.
    self.state = _State()

  def debug_print(self, node):
    """Debugging helper: prints the AST of `node` and returns `node`."""
    if __debug__:
      rendered = pretty_printer.fmt(node)
      print(rendered)
    return node

  def debug_print_src(self, node):
    """Debugging helper: prints `node` as source code and returns `node`."""
    if __debug__:
      source = parser.unparse(node)
      print(source)
    return node

  def visit_block(self, nodes, before_visit=None, after_visit=None):
    """Visits a block of statements, optionally rerouting the results.

    An example of a block is the body of an if statement.

    Every item of `nodes` is passed to `self.visit`. A falsy result drops the
    item, a list or tuple result is spliced in, and any other result is
    appended as is.

    The `after_visit` callback can additionally move nodes to a new
    destination, by returning a non-None second value, e.g.
    `return new_node, new_destination`. The node returned alongside the new
    destination is still placed in the current destination; every node
    processed after it goes to the new one.

    For example, a transformer could turn this:

        foo()
        bar()
        baz()

    into this:

        foo()
        if cond:
          bar()
          baz()

    using a postprocessor of this kind:

        def after_visit(node):
          if node_is_function_call(bar):
            new_container_node = build_cond()
            new_container_node.body.append(node)
            return new_container_node, new_container_node.body
          else:
            # Once we set a new destination, all subsequent items will be
            # moved to it, so we don't need to explicitly handle baz.
            return node, None

    Args:
      nodes: enumerable of AST node objects. If None, the function returns None.
      before_visit: optional callable, invoked without arguments before each
        item in nodes is visited.
      after_visit: optional callable that takes in an AST node and returns a
        tuple (new_node, new_destination). It is called after visiting each
        item in nodes whose visit produced a truthy result. new_node is used
        the same way as the return value of the visit_* methods: it replaces
        the node. If not None, new_destination must be a list, and subsequent
        nodes will be placed in this list instead of the list returned by
        visit_block.

    Returns:
      A list of AST node objects containing the transformed items from nodes,
      except those nodes that have been relocated using after_visit.
    """
    if nodes is None:
      return None

    output = []
    sink = output  # The list currently receiving the visited nodes.
    for item in nodes:
      if before_visit:
        # TODO(mdan): We can modify node here too, if ever needed.
        before_visit()

      visited = self.visit(item)

      redirect = None
      if after_visit and visited:
        visited, redirect = after_visit(visited)

      if visited:
        if isinstance(visited, (list, tuple)):
          sink.extend(visited)
        else:
          sink.append(visited)

      # The switch happens only after the current node has been placed, so
      # just the nodes that follow land in the new destination.
      if redirect is not None:
        sink = redirect
    return output
353
354
355# TODO(mdan): Rename to PythonCodeTransformer.
class Base(NodeStateTracker, gast.NodeTransformer):
  """Base class for general-purpose Python-to-Python code transformation.

  This is an extension of ast.NodeTransformer that provides the additional
  functions offered by NodeStateTracker.
  """

  def create_assignment(self, target, expression):
    """Builds `target = expression` from a template.

    Args:
      target: substituted for the assignment target in the template.
      expression: substituted for the assigned value in the template.

    Returns:
      Whatever `templates.replace` returns for the filled-in template.
    """
    template = """
      target = expression
    """
    return templates.replace(template, target=target, expression=expression)

  # TODO(mdan): Remove.
  def apply_to_single_assignments(self, targets, values, apply_fn):
    """Applies a function to each individual assignment.

    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
    It tries to break down the unpacking if possible. In effect, it has the same
    effect as passing the assigned values in SSA form to apply_fn.

    Examples:

    The following will result in apply_fn(a, c), apply_fn(b, d):

        a, b = c, d

    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

        a, b = c

    The following will result in apply_fn(a, (b, c)):

        a = b, c

    It uses the visitor pattern to allow subclasses to process single
    assignments individually.

    Args:
      targets: list, tuple of or individual AST node. Should be used with the
        targets field of an ast.Assign node.
      values: an AST node.
      apply_fn: a function of two arguments, which will be called with the
        respective nodes of each single assignment. The signature is
        apply_fn(target, value), no return value.
    """
    if not isinstance(targets, (list, tuple)):
      targets = (targets,)
    # Whether the right-hand side can be broken down element by element.
    values_are_unpacked = isinstance(values, (gast.Tuple, gast.List))
    for target in targets:
      if not isinstance(target, (gast.Tuple, gast.List)):
        # TODO(mdan): Look into allowing to rewrite the AST here.
        apply_fn(target, values)
        continue
      for i, target_el in enumerate(target.elts):
        if values_are_unpacked:
          value_el = values.elts[i]
        else:
          # Synthesize `values[i]`. The subscript is the value being assigned,
          # i.e. it is read, so it takes a Load context, and its index has to
          # be an AST node rather than a bare Python int.
          value_el = gast.Subscript(
              values, gast.Constant(value=i, kind=None), ctx=gast.Load())
        self.apply_to_single_assignments(target_el, value_el, apply_fn)

  def visit(self, node):
    """Visits `node`, maintaining origin information along the way.

    While the node is being visited, `self.ctx.current_origin` holds the
    node's own ORIGIN annotation, if it has one; the previous value is
    restored afterwards, even when the visit raises.

    Args:
      node: a gast.AST node. To visit lists of nodes, use visit_block.

    Returns:
      `node` itself if it is annotated with SKIP_PROCESSING; otherwise the
      result of the regular visitor dispatch. If the visited node was an Expr
      and the result is an Expr whose value was replaced by a list, tuple,
      Assign or AugAssign, that value is returned instead of the Expr.

    Raises:
      ValueError: if `node` is not a gast.AST instance.
    """
    if not isinstance(node, gast.AST):
      # This is not that uncommon a mistake: various node bodies are lists, for
      # example, posing a land mine for transformers that need to recursively
      # call `visit`. Fail early, before touching any origin state, because
      # the code below assumes that `node` is, in fact, a node.
      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
             ' visit lists of nodes, use "visit_block" instead').format(
                 type(node))
      raise ValueError(msg)

    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      return node

    parent_origin = self.ctx.current_origin
    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
      processing_expr_node = isinstance(node, gast.Expr)
      if processing_expr_node:
        entry_expr_value = node.value

      result = super(Base, self).visit(node)

      # Adjust for consistency: replacing the value of an Expr with
      # an Assign node removes the need for the Expr node.
      if (processing_expr_node and isinstance(result, gast.Expr) and
          (result.value is not entry_expr_value)):
        # When the replacement is a list, it is assumed that the list came
        # from a template that contained a number of statements, which
        # themselves are standalone and don't require an enclosing Expr.
        if isinstance(result.value,
                      (list, tuple, gast.Assign, gast.AugAssign)):
          result = result.value

      # By default, all replacements receive the origin info of the replaced
      # node.
      if result is not node and result is not None:
        inherited_origin = anno.getanno(
            node, anno.Basic.ORIGIN, default=parent_origin)
        if inherited_origin is not None:
          if isinstance(result, (list, tuple)):
            nodes_to_adjust = result
          else:
            nodes_to_adjust = (result,)
          for n in nodes_to_adjust:
            # Replacements that already carry an origin keep it.
            if not anno.hasanno(n, anno.Basic.ORIGIN):
              anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
    finally:
      self.ctx.current_origin = parent_origin

    return result
472
473
class CodeGenerator(NodeStateTracker, gast.NodeVisitor):
  """Base class for general-purpose Python-to-string code transformation.

  Similar to Base, but outputs arbitrary strings instead of a Python AST.

  This uses the same visitor mechanism that the standard NodeVisitor uses,
  meaning that subclasses write handlers for the different kinds of nodes.
  New code is generated using the emit method, which appends to a code buffer
  that can be afterwards obtained from code_buffer.

  While visiting, `source_map` is filled with entries that map the
  (start, end) offsets of the text emitted for a node to that node's origin
  information.

  Example:

    class SimpleCodeGen(CodeGenerator):

      def visit_If(self, node):
        self.emit('if ')
        self.visit(node.test)
        self.emit(' { ')
        for stmt in node.body:
          self.visit(stmt)
        self.emit(' } else { ')
        for stmt in node.orelse:
          self.visit(stmt)
        self.emit(' } ')

    node = ast.parse(...)
    gen = SimpleCodeGen(ctx)
    gen.visit(node)
    # gen.code_buffer contains the resulting code
  """

  def __init__(self, ctx):
    super(CodeGenerator, self).__init__(ctx)

    self._output_code = ''
    self.source_map = {}

  def emit(self, code):
    """Appends `code` to the output buffer."""
    self._output_code += code

  @property
  def code_buffer(self):
    """The text emitted so far."""
    return self._output_code

  def visit(self, node):
    """Visits `node`, recording which origin its emitted text belongs to.

    Nodes annotated with SKIP_PROCESSING are ignored and yield None.
    `self.ctx.current_origin` is set to the node's ORIGIN annotation (if any)
    for the duration of the visit and restored afterwards.

    Args:
      node: the AST node to generate code for.

    Returns:
      The value returned by the regular visitor dispatch, or None if the node
      is skipped.
    """
    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      return

    saved_origin = self.ctx.current_origin
    start = len(self._output_code)
    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
      outcome = super(CodeGenerator, self).visit(node)

      # Any text emitted during the visit is attributed to the node's origin,
      # falling back to the origin that was current before the visit.
      end = len(self._output_code)
      if start != end:
        origin = anno.getanno(node, anno.Basic.ORIGIN, default=saved_origin)
        if origin is not None:
          self.source_map[(start, end)] = origin
      return outcome
    finally:
      self.ctx.current_origin = saved_origin
538      self.ctx.current_origin = parent_origin
539