# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reaching definition analysis.

This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches the set of the symbols defined at the entry of
control flow statements.

Requires activity analysis.
"""

import weakref

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer


class Definition(object):
  """Definition objects describe a unique definition of a variable.

  Subclasses of this may be used by passing an appropriate factory function to
  resolve.

  Definitions deliberately keep the default identity-based equality and
  hashing: two Definition objects are the same definition only if they are the
  same object.

  Attributes:
    param_of: Optional reference to the function parameter this definition
        belongs to. Analyzer stores a `weakref.ref` to the parameter's AST
        node here; it is None for non-parameter definitions.
    directives: Dict, optional definition annotations
  """

  def __init__(self):
    self.param_of = None
    self.directives = {}

  def __repr__(self):
    return '%s[%d]' % (self.__class__.__name__, id(self))


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
        their possible definitions
  """

  def __init__(self, init_from=None):
    """Creates a state, optionally initialized from existing data.

    Args:
      init_from: Optional. Either another _NodeState, whose per-symbol sets are
        copied so the new state can be mutated independently, or a dict
        mapping each symbol to a single Definition, which becomes a
        one-element set. A falsy value produces an empty state.
    """
    if init_from:
      if isinstance(init_from, _NodeState):
        self.value = {
            s: set(other_infos) for s, other_infos in init_from.value.items()
        }
      elif isinstance(init_from, dict):
        self.value = {s: {def_} for s, def_ in init_from.items()}
      else:
        assert False, init_from
    else:
      self.value = {}

  def __eq__(self, other):
    # Equal iff both states define the same symbols, each with the same set of
    # Definition objects. Dict equality checks exactly that: identical key
    # sets and equal per-key values.
    return self.value == other.value

  def __ne__(self, other):
    return not self.__eq__(other)

  def __or__(self, other):
    """Returns a new state holding the union of definitions from both states."""
    assert isinstance(other, _NodeState)
    # Copy first so neither operand is mutated.
    result = _NodeState(self)
    for s, other_infos in other.value.items():
      if s in result.value:
        result.value[s].update(other_infos)
      else:
        result.value[s] = set(other_infos)
    return result

  def __sub__(self, other):
    """Returns a new state without the symbols in the set `other`."""
    assert isinstance(other, set)
    result = _NodeState(self)
    for s in other:
      result.value.pop(s, None)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level."""

  def __init__(self, graph, definition_factory):
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    # Maps a CFG node to the _NodeState of Definitions that the node generates.
    # Cached so every visit of the same node reuses the same Definition objects.
    self.gen_map = {}

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    """Computes the input and output reaching-definition states of `node`.

    Args:
      node: cfg.Node

    Returns:
      bool, True if the output state differs from the one recorded before this
      visit.
    """
    prev_defs_out = self.out[node]

    # Definitions reaching this node are those leaving any of its predecessors.
    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      # Gen/kill transfer function: incoming definitions of symbols this node
      # modifies or deletes are dropped, then the node's own definitions are
      # added.
      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.definition_factory = definition_factory
    self.graphs = graphs
    # Analyzer for the innermost function being visited, or None outside any
    # function.
    self.current_analyzer = None
    # CFG node of the statement currently being visited; maintained by visit().
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    """Runs the analysis on the function's CFG, then annotates its body."""
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    analyzer = Analyzer(subgraph, self.definition_factory)
    analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    self.current_analyzer = analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = parent_analyzer

    return node

  def visit_Name(self, node):
    """Attaches the Definitions of the name's symbol as a static annotation.

    Loaded names get the definitions reaching the enclosing statement; stored
    or deleted names get those leaving it.
    """
    if self.current_analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    analyzer = self.current_analyzer
    cfg_node = self.current_cfg_node

    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Load):
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.in_[cfg_node].value.get(qn, ())))
    else:
      anno.setanno(node, anno.Static.DEFINITIONS,
                   tuple(analyzer.out[cfg_node].value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates `node` with the symbols defined on exit of its predecessors."""
    preds = self.current_analyzer.graph.stmt_prev[node]
    node_defined_in = frozenset(
        s for p in preds for s in self.current_analyzer.out[p].value)
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, node_defined_in)

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the loop target is annotated using the CFG node
    # of the iterated expression.
    parent = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = parent

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_Try(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_ExceptHandler(self, node):
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits `node`, tracking the CFG node of the enclosing statement."""
    parent = self.current_cfg_node

    # Only AST nodes that have a CFG node of their own switch the context;
    # everything else inherits the enclosing statement's CFG node.
    if (self.current_analyzer is not None and
        node in self.current_analyzer.graph.index):
      self.current_cfg_node = self.current_analyzer.graph.index[node]
    node = super(TreeAnnotator, self).visit(node)

    self.current_cfg_node = parent
    return node


def resolve(node, source_info, graphs, definition_factory=Definition):
  """Resolves reaching definitions for each symbol.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]
  Returns:
    ast.AST
  """
  visitor = TreeAnnotator(source_info, graphs, definition_factory)
  node = visitor.visit(node)
  return node