# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST manipulation utilities."""

import ast

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names


class CleanCopier(object):
  """NodeTransformer-like visitor that copies an AST.

  Only node fields whose names do not start with '__' are copied. Annotations
  are dropped, except for the keys listed in `preserve_annos`.
  """

  def __init__(self, preserve_annos):
    """Initializes the copier.

    Args:
      preserve_annos: Optional iterable of annotation keys to carry over from
        each original node to its copy. A falsy value (e.g. None) preserves
        no annotations.
    """
    super(CleanCopier, self).__init__()
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Args:
      node: an AST node, or a (possibly nested) list/tuple of them. Any other
        value is treated as a value type and returned as-is (not copied).

    Returns:
      A new structure with the same shape and node types as `node`.
    """
    if isinstance(node, list):
      return [self.copy(n) for n in node]
    elif isinstance(node, tuple):
      return tuple(self.copy(n) for n in node)
    elif not isinstance(node, (gast.AST, ast.AST)):
      # Assuming everything that's not an AST, list or tuple is a value type
      # and may simply be assigned.
      return node

    assert isinstance(node, (gast.AST, ast.AST))

    new_fields = {}
    for f in node._fields:
      # Fields that are not set on the node are left out of the copy.
      if not f.startswith('__') and hasattr(node, f):
        new_fields[f] = self.copy(getattr(node, f))
    new_node = type(node)(**new_fields)

    if self.preserve_annos:
      for k in self.preserve_annos:
        anno.copyanno(node, new_node, k)
    return new_node


def copy_clean(node, preserve_annos=None):
  """Creates a deep copy of an AST.

  The copy will not include fields that are prefixed by '__', with the
  exception of user-specified annotations.

  Args:
    node: ast.AST
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
      copy
  Returns:
    ast.AST
  """
  return CleanCopier(preserve_annos).copy(node)


class SymbolRenamer(gast.NodeTransformer):
  """Transformer that renames symbols to simple names.

  Name nodes, and Attribute nodes that carry a qualified-name annotation, whose
  qualified name is in `name_map` are replaced with a plain Name node. The
  identifier lists of Global/Nonlocal statements and FunctionDef names are
  renamed in place.
  """

  def __init__(self, name_map):
    """Initializes the renamer.

    Args:
      name_map: Dict[qual_names.QN, Any], maps each qualified name to rename
        to its replacement. Replacements are converted with str().
    """
    super(SymbolRenamer, self).__init__()
    self.name_map = name_map

  def _process_name_node(self, node):
    """Returns a renamed Name node if `node`'s qualified name is mapped."""
    # NOTE(review): assumes the node carries an anno.Basic.QN annotation
    # (i.e. qual_names resolution already ran) — confirm for new callers.
    qn = anno.getanno(node, anno.Basic.QN)
    if qn in self.name_map:
      new_node = gast.Name(
          str(self.name_map[qn]),
          ctx=node.ctx,
          annotation=None,
          type_comment=None)
      # All annotations get carried over.
      for k in anno.keys(node):
        anno.copyanno(node, new_node, k)
      return new_node
    return self.generic_visit(node)

  def _process_list_of_strings(self, names):
    """Renames, in place, the entries of a list of identifier strings."""
    for i, name in enumerate(names):
      qn = qual_names.QN(name)
      if qn in self.name_map:
        names[i] = str(self.name_map[qn])
    return names

  def visit_Nonlocal(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Global(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    if anno.hasanno(node, anno.Basic.QN):
      return self._process_name_node(node)
    # Renaming attributes is not supported.
    return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    # The function's own name is a plain string field, not a Name node.
    qn = qual_names.QN(node.name)
    if qn in self.name_map:
      node.name = str(self.name_map[qn])
    return self.generic_visit(node)


def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Nodes may be modified in place; always use the returned value.

  Args:
    node: an AST node, or a list/tuple of AST nodes.
    name_map: Dict[qual_names.QN, Any], see SymbolRenamer.
  Returns:
    The transformed node(s), in the same container type as `node`.
  """
  renamer = SymbolRenamer(name_map)
  if isinstance(node, list):
    return [renamer.visit(n) for n in node]
  elif isinstance(node, tuple):
    return tuple(renamer.visit(n) for n in node)
  return renamer.visit(node)


def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a dict.

  Named keywords become string-keyed entries. `**kwargs` keywords (whose `arg`
  is None) become dict unpacking entries, i.e. `f(a=1, **kw)` is converted to
  `{'a': 1, **kw}`.

  Args:
    keywords: List[gast.keyword]
  Returns:
    gast.Dict
  """
  keys = []
  values = []
  for kw in keywords:
    if kw.arg is None:
      # In a Dict node, a None key denotes `**value` unpacking; a
      # Constant(None) key would instead create a literal None entry.
      keys.append(None)
    else:
      keys.append(gast.Constant(kw.arg, kind=None))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)


class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node.

  After visiting a node, the `matches` attribute is True if and only if the
  visited node matched the pattern. See `matches` for the wildcard syntax.
  """

  def __init__(self, pattern):
    self.pattern = pattern
    # Saves the enclosing patterns while descending into child nodes.
    self.pattern_stack = []
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Matches the fields of `node` against those of `pattern`."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    """Records a failed match. Returns False for use in return statements."""
    self.matches = False
    return False

  def is_wildcard(self, p):
    """Returns True if the pattern value `p` is the wildcard '_'.

    A wildcard is a Name node with id '_', the string '_', or a single-element
    list/tuple holding one of those. The list form matches an entire list of
    any length.
    """
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _compare_values(self, v, p):
    """Matches a single field value `v` against its pattern value `p`."""
    if self.is_wildcard(p):
      return
    if isinstance(v, (list, tuple)):
      if not isinstance(p, (list, tuple)) or len(v) != len(p):
        self.no_match()
        return
      for v_item, p_item in zip(v, p):
        # List items are compared exactly like scalar fields: non-AST items
        # (identifier strings, None entries) are compared by value, AST items
        # are type-checked, and each item may itself be a wildcard.
        self._compare_values(v_item, p_item)
        if not self.matches:
          return
    elif isinstance(v, (gast.AST, ast.AST)):
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        self.no_match()
        return
      self.compare_and_visit(v, p)
    elif v != p:
      # Assume everything else is a value type.
      self.no_match()

  def generic_visit(self, node):
    """Matches the fields of `node` against those of `self.pattern`."""
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        # A field the node doesn't have only matches a missing or falsy
        # pattern field.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        continue
      if not hasattr(pattern, f):
        return self.no_match()

      self._compare_values(getattr(node, f), getattr(pattern, f))
      if not self.matches:
        return False


def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: ast.AST, or a str that is parsed with parser.parse_str
  Returns:
    bool
  """
  if isinstance(pattern, str):
    pattern = parser.parse_str(pattern)

  # Type-check the root the same way PatternMatcher checks child nodes;
  # otherwise root nodes of different types that happen to share field names
  # would match.
  if (isinstance(node, (gast.AST, ast.AST)) and
      not isinstance(node, type(pattern)) and
      not isinstance(pattern, type(node))):
    return False

  matcher = PatternMatcher(pattern)
  matcher.visit(node)
  return matcher.matches


# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
  It tries to break down the unpacking if possible. In effect, it has the same
  effect as passing the assigned values in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  Unpacking is applied recursively, so nested tuple/list targets are broken
  down as well.

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST], should be
      used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
      respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)
  for target in targets:
    if isinstance(target, (gast.Tuple, gast.List)):
      # NOTE(review): starred targets (a, *b = ...) are not special-cased
      # here; each element is paired with the value at the same position.
      for i, target_el in enumerate(target.elts):
        if isinstance(values, (gast.Tuple, gast.List)):
          value_el = values.elts[i]
        else:
          # The value is not a literal sequence, so index into it instead.
          idx = parser.parse_expression(str(i))
          value_el = gast.Subscript(values, idx, ctx=gast.Load())
        apply_to_single_assignments(target_el, value_el, apply_fn)
    else:
      apply_fn(target, values)


def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure. Nodes are visited depth-first,
  with siblings visited last-to-first. Fields that are None in either tree are
  not compared.

  Args:
    node: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
    other: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
  Yields:
    Tuple[ast.AST, ast.AST]
  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  if isinstance(node, (list, tuple)):
    node_stack = list(node)
  else:
    node_stack = [node]

  if isinstance(other, (list, tuple)):
    other_stack = list(other)
  else:
    other_stack = [other]

  # A top-level count mismatch is a structural mismatch too; report it as
  # documented instead of relying on an assert (stripped under -O).
  if len(node_stack) != len(other_stack):
    raise ValueError('inconsistent number of nodes: {} and {}'.format(
        len(node_stack), len(other_stack)))

  while node_stack and other_stack:
    assert len(node_stack) == len(other_stack)
    n = node_stack.pop()
    o = other_stack.pop()

    # Stack entries may be AST nodes, strings (e.g. identifier lists) or
    # None; anything else, or differing class names, is a mismatch.
    if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
        (not isinstance(o, (ast.AST, gast.AST, str)) and o is not None) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))