xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/pyct/ast_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST manipulation utilities."""
16
17import ast
18
19import gast
20
21from tensorflow.python.autograph.pyct import anno
22from tensorflow.python.autograph.pyct import parser
23from tensorflow.python.autograph.pyct import qual_names
24
25
class CleanCopier(object):
  """NodeTransformer-like visitor that deep-copies an AST.

  Fields whose names start with '__' are not copied. Annotations are carried
  over only for the keys listed in `preserve_annos`.
  """

  def __init__(self, preserve_annos):
    super(CleanCopier, self).__init__()
    # Annotation keys to copy onto every new node; a falsy value means none.
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean)."""
    if isinstance(node, list):
      return [self.copy(item) for item in node]
    if isinstance(node, tuple):
      return tuple(self.copy(item) for item in node)
    if not isinstance(node, (gast.AST, ast.AST)):
      # Anything that is neither an AST node nor a list/tuple is treated as a
      # value type and returned as-is.
      return node

    # Rebuild the node from its public fields only; the constructor call
    # produces a fresh node without any of the original's annotations.
    copied_fields = {
        field: self.copy(getattr(node, field))
        for field in node._fields
        if not field.startswith('__') and hasattr(node, field)
    }
    clone = type(node)(**copied_fields)

    for key in self.preserve_annos or ():
      anno.copyanno(node, clone, key)
    return clone
57
58
def copy_clean(node, preserve_annos=None):
  """Returns a deep copy of an AST, leaving out private fields.

  Fields whose names begin with '__' are omitted from the copy, except for
  the annotations whose keys are listed in `preserve_annos`.

  Args:
    node: ast.AST (lists and tuples of nodes are copied element-wise)
    preserve_annos: Optional[Set[Hashable]], annotation keys to carry over
        into the copy
  Returns:
    ast.AST
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)
73
74
class SymbolRenamer(gast.NodeTransformer):
  """Transformer that renames symbols to simple names.

  Name nodes must carry a qual_names annotation (anno.Basic.QN); Attribute
  nodes are renamed only when they carry one.
  """

  def __init__(self, name_map):
    # Maps qual_names.QN -> replacement; replacements are passed through str().
    self.name_map = name_map

  def _process_name_node(self, node):
    qn = anno.getanno(node, anno.Basic.QN)
    if qn not in self.name_map:
      return self.generic_visit(node)

    replacement = gast.Name(
        str(self.name_map[qn]),
        ctx=node.ctx,
        annotation=None,
        type_comment=None)
    # Every annotation of the original node is copied to its replacement.
    for key in anno.keys(node):
      anno.copyanno(node, replacement, key)
    return replacement

  def _process_list_of_strings(self, names):
    # Renames in place (and also returns the same list for convenience).
    for idx, name in enumerate(names):
      name_qn = qual_names.QN(name)
      if name_qn in self.name_map:
        names[idx] = str(self.name_map[name_qn])
    return names

  def _rename_declared_names(self, node):
    # Shared by `global` and `nonlocal` statements, which list bare strings.
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Nonlocal(self, node):
    return self._rename_declared_names(node)

  def visit_Global(self, node):
    return self._rename_declared_names(node)

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    if not anno.hasanno(node, anno.Basic.QN):
      # Renaming attributes is not supported.
      return self.generic_visit(node)
    return self._process_name_node(node)

  def visit_FunctionDef(self, node):
    fn_qn = qual_names.QN(node.name)
    if fn_qn in self.name_map:
      node.name = str(self.name_map[fn_qn])
    return self.generic_visit(node)
124
125
def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: ast.AST, or a list/tuple of nodes
    name_map: Dict[qual_names.QN, Any], maps qualified names to their
        replacements (converted with str())
  Returns:
    The renamed node(s); a list input yields a list, a tuple input a tuple.
  """
  renamer = SymbolRenamer(name_map)
  if not isinstance(node, (list, tuple)):
    return renamer.visit(node)
  renamed = [renamer.visit(item) for item in node]
  return renamed if isinstance(node, list) else tuple(renamed)
134
135
def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a dict.

  A named keyword `name=value` becomes the entry `'name': value`. A
  `**kwargs` keyword (whose `arg` is None) becomes a dict-unpacking entry,
  i.e. its key is None, so the resulting node is equivalent to `{**kwargs}`.

  Args:
    keywords: Iterable[ast.keyword], e.g. the `keywords` field of a Call node
  Returns:
    gast.Dict
  """
  keys = []
  values = []
  for kw in keywords:
    if kw.arg is None:
      # A bare None key (not a Constant(None) node) marks dict unpacking;
      # a Constant would instead produce the literal entry {None: kwargs}.
      keys.append(None)
    else:
      keys.append(gast.Constant(kw.arg, kind=None))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)
144
145
class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node."""

  def __init__(self, pattern):
    # Pattern node that the node currently being visited is compared against.
    self.pattern = pattern
    # Outer patterns, restored as compare_and_visit unwinds.
    self.pattern_stack = []
    # Set to False at the first mismatch; never reset.
    self.matches = True

  def compare_and_visit(self, node, pattern):
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    self.matches = False
    return False

  def is_wildcard(self, p):
    # A single-element list/tuple holding a wildcard is itself a wildcard,
    # so e.g. `f(_)` matches calls with any argument list.
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _compare(self, v, p):
    """Compares a node value `v` with its pattern counterpart `p`.

    Applies the same rules to field values and to list elements.

    Returns:
      True if `v` matches `p`; otherwise records the mismatch and returns
      False.
    """
    if self.is_wildcard(p):
      return True
    if isinstance(v, (list, tuple)):
      if not isinstance(p, (list, tuple)) or len(v) != len(p):
        return self.no_match()
      # Elements are checked recursively rather than visited directly, so
      # non-AST elements (e.g. strings in Global.names, None keys in `{**x}`)
      # are compared by value and each element gets wildcard/type checks.
      return all(self._compare(v_item, p_item)
                 for v_item, p_item in zip(v, p))
    if isinstance(v, (gast.AST, ast.AST)):
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        return self.no_match()
      self.compare_and_visit(v, p)
      return self.matches
    # Assume everything else is a value type.
    if v != p:
      return self.no_match()
    return True

  def generic_visit(self, node):
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        # A field missing from the node only matches an absent/falsy one.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        continue
      if not hasattr(pattern, f):
        return self.no_match()

      if not self._compare(getattr(node, f), getattr(pattern, f)):
        return False
208
209
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: Union[ast.AST, str]; a string is first parsed with
        parser.parse_str
  Returns:
    bool
  """
  if not isinstance(pattern, str):
    pattern_node = pattern
  else:
    pattern_node = parser.parse_str(pattern)

  matcher = PatternMatcher(pattern_node)
  matcher.visit(node)
  return matcher.matches
229
230
231# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  Handles possibly-unpacked assignments (e.g. `a, b = c, d`) by breaking the
  unpacking down where possible and descending recursively into nested
  tuple/list targets. The effect is the same as passing the assigned values
  in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST], should be
        used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
        respective nodes of each single assignment
  """
  if isinstance(targets, (list, tuple)):
    target_list = targets
  else:
    target_list = (targets,)

  values_are_literal = isinstance(values, (gast.Tuple, gast.List))
  for target in target_list:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue
    for i, target_el in enumerate(target.elts):
      if values_are_literal:
        value_el = values.elts[i]
      else:
        # Values that are not a tuple/list literal are indexed explicitly.
        value_el = gast.Subscript(
            values, parser.parse_expression(str(i)), ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
277
278
def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure. The traversal is pre-order
  and stack-based, so among siblings the last field / list element is
  visited first.

  Fields that are None or absent on either side are skipped rather than
  compared, so a None-vs-node difference in a field is not reported.

  Args:
    node: Union[ast.AST, Iterable[ast.AST]]
    other: Union[ast.AST, Iterable[ast.AST]]
  Yields:
    Tuple[ast.AST, ast.AST], corresponding node pairs. Elements of list
    fields that are str or None are yielded as pairs too.
  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]
  if len(node_stack) != len(other_stack):
    # Reported explicitly so a top-level length mismatch raises the
    # documented ValueError instead of tripping (or, under -O, skipping) the
    # assert below.
    raise ValueError('inconsistent number of nodes: {} and {}'.format(
        len(node_stack), len(other_stack)))

  while node_stack and other_stack:
    # Children are always pushed in equal-length batches, keeping the
    # stacks in lockstep.
    assert len(node_stack) == len(other_stack)
    n = node_stack.pop()
    o = other_stack.pop()

    # Each side must be an AST node, a str or None, and both sides must have
    # the same class name. Comparing names (not classes) lets an ast node
    # pair with the gast node of the same name.
    if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
        (not isinstance(o, (ast.AST, gast.AST, str)) and o is not None) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        # o_child's kind is validated when the pair is popped.
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))
345