1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Reaching definition analysis.
16
This analysis attaches a set of Definition objects to each symbol, one
18for each distinct definition that may reach it. The Definition objects are
19mutable and may be used by subsequent analyses to further annotate data like
20static type and value information.
21The analysis also attaches the set of the symbols defined at the entry of
22control flow statements.
23
24Requires activity analysis.
25"""
26
27import weakref
28
29import gast
30
31from tensorflow.python.autograph.pyct import anno
32from tensorflow.python.autograph.pyct import cfg
33from tensorflow.python.autograph.pyct import transformer
34
35
class Definition(object):
  """A single, unique definition of a variable.

  No custom equality is defined, so two definitions compare equal only when
  they are the very same object. A specialized type may be used by passing a
  factory that builds a subclass of this to resolve.

  Attributes:
    param_of: None by default. The analyzer in this module stores a weak
        reference (weakref.ref) to the owning node here when the definition
        comes from a parameter.
    directives: Dict, optional annotations attached to the definition
  """

  def __init__(self):
    self.directives = {}
    self.param_of = None

  def __repr__(self):
    # The object id distinguishes definitions that are otherwise identical.
    kind = type(self).__name__
    return '%s[%d]' % (kind, id(self))
53
54
55class _NodeState(object):
56  """Abstraction for the state of the CFG walk for reaching definition analysis.
57
58  This is a value type. Only implements the strictly necessary operators.
59
60  Attributes:
61    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
62        their possible definitions
63  """
64
65  def __init__(self, init_from=None):
66    if init_from:
67      if isinstance(init_from, _NodeState):
68        self.value = {
69            s: set(other_infos) for s, other_infos in init_from.value.items()
70        }
71      elif isinstance(init_from, dict):
72        self.value = {s: set((init_from[s],)) for s in init_from}
73      else:
74        assert False, init_from
75    else:
76      self.value = {}
77
78  def __eq__(self, other):
79    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
80      return False
81    ret = all(self.value[s] == other.value[s] for s in self.value)
82    return ret
83
84  def __ne__(self, other):
85    return not self.__eq__(other)
86
87  def __or__(self, other):
88    assert isinstance(other, _NodeState)
89    result = _NodeState(self)
90    for s, other_infos in other.value.items():
91      if s in result.value:
92        result.value[s].update(other_infos)
93      else:
94        result.value[s] = set(other_infos)
95    return result
96
97  def __sub__(self, other):
98    assert isinstance(other, set)
99    result = _NodeState(self)
100    for s in other:
101      result.value.pop(s, None)
102    return result
103
104  def __repr__(self):
105    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
106
107
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level.

  For every CFG node the visitor computes the usual dataflow equation

    out = gen | (in - kill)

  where `in` is the union of the predecessors' `out` states, `gen` holds the
  definitions created by the node and `kill` the symbols it modifies or
  deletes.

  Attributes:
    gen_map: Dict mapping each visited CFG node to the _NodeState holding the
        definitions generated by that node. Entries are created once and then
        reused on subsequent visits of the same node.
  """

  def __init__(self, graph, definition_factory):
    """Creates the analyzer.

    Args:
      graph: the CFG to analyze, forwarded to cfg.GraphVisitor
      definition_factory: Callable[[], Definition], called without arguments
        to create each new definition object
    """
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.gen_map = {}

  def init_state(self, _):
    """Returns the initial (empty) state of a node."""
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the in/out states of a CFG node.

    Args:
      node: the CFG node to process; its `prev` and `ast_node` attributes are
        read

    Returns:
      bool, True if the node's out state differs from the one recorded before
      this visit.
    """
    prev_defs_out = self.out[node]

    # The definitions that reach the node are the union of everything leaving
    # its predecessors.
    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          # A weak reference is stored rather than the node itself.
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      # Computed exactly once: both operators build fresh _NodeState copies.
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      # Nodes without a scope annotation pass the state through unchanged.
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out
163
164
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  While walking the tree, the visitor also runs the dataflow analysis on each
  function node it encounters, accounting for the effect of closures. For
  example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here

  Attributes:
    allow_skips: bool, initialized to False
    definition_factory: Callable[[], Definition], handed to each Analyzer
    graphs: mapping from function nodes to their CFG
    current_analyzer: the Analyzer of the function being visited, or None
      when outside any function
    current_cfg_node: the CFG node of the statement being visited, or None
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.graphs = graphs
    self.definition_factory = definition_factory
    self.current_cfg_node = None
    self.current_analyzer = None

  def visit_FunctionDef(self, node):
    """Analyzes the function's CFG, then annotates its args and body."""
    enclosing_analyzer = self.current_analyzer

    function_analyzer = Analyzer(self.graphs[node], self.definition_factory)
    function_analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    self.current_analyzer = function_analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = enclosing_analyzer

    return node

  def visit_Name(self, node):
    """Attaches the DEFINITIONS annotation to a name inside a function."""
    analyzer = self.current_analyzer
    if analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    cfg_node = self.current_cfg_node
    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    # Reads see what enters the statement; any other context sees what leaves
    # it.
    if isinstance(node.ctx, gast.Load):
      state = analyzer.in_[cfg_node]
    else:
      state = analyzer.out[cfg_node]
    reaching = tuple(state.value.get(qn, ()))
    anno.setanno(node, anno.Static.DEFINITIONS, reaching)

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates node with the symbols defined on exit of its predecessors."""
    analyzer = self.current_analyzer
    predecessors = analyzer.graph.stmt_prev[node]
    symbols = set()
    for pred in predecessors:
      symbols.update(analyzer.out[pred].value.keys())
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(symbols))

  def visit_If(self, node):
    """Records the symbols defined on entry, then visits the children."""
    self._aggregate_predecessors_defined_in(node)
    node = self.generic_visit(node)
    return node

  def visit_For(self, node):
    """Records the symbols defined on entry and visits the loop's parts."""
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the target is visited under the CFG node of the
    # iterated expression.
    saved_cfg_node = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = saved_cfg_node

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    """Records the symbols defined on entry, then visits the children."""
    self._aggregate_predecessors_defined_in(node)
    node = self.generic_visit(node)
    return node

  def visit_Try(self, node):
    """Records the symbols defined on entry, then visits the children."""
    self._aggregate_predecessors_defined_in(node)
    node = self.generic_visit(node)
    return node

  def visit_ExceptHandler(self, node):
    """Records the symbols defined on entry; only the body is visited."""
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits node, tracking the CFG node that it belongs to, if any."""
    saved_cfg_node = self.current_cfg_node

    analyzer = self.current_analyzer
    if analyzer is not None:
      if node in analyzer.graph.index:
        self.current_cfg_node = analyzer.graph.index[node]
    result = super(TreeAnnotator, self).visit(node)

    # Restore the CFG node of the enclosing statement.
    self.current_cfg_node = saved_cfg_node
    return result
273
274
def resolve(node, source_info, graphs, definition_factory=Definition):
  """Annotates an AST with the reaching definitions of each symbol.

  Args:
    node: ast.AST, the tree to annotate
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition], used to create the
      definition objects

  Returns:
    ast.AST, the result of visiting node with a TreeAnnotator
  """
  annotator = TreeAnnotator(source_info, graphs, definition_factory)
  return annotator.visit(node)
289