1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles control flow statements: while, for, if.
16
17Python 2 compatibility version. Not maintained.
18"""
19
20import gast
21
22from tensorflow.python.autograph.core import converter
23from tensorflow.python.autograph.lang import directives
24from tensorflow.python.autograph.pyct import anno
25from tensorflow.python.autograph.pyct import ast_util
26from tensorflow.python.autograph.pyct import cfg
27from tensorflow.python.autograph.pyct import parser
28from tensorflow.python.autograph.pyct import qual_names
29from tensorflow.python.autograph.pyct import templates
30from tensorflow.python.autograph.pyct.static_analysis import activity
31from tensorflow.python.autograph.pyct.static_analysis import annos
32from tensorflow.python.autograph.pyct.static_analysis import liveness
33from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
34from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
35
36
37# TODO(mdan): Refactor functions to make them smaller.
38
39
class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals."""

  def _create_cond_branch(self, body_name, aliased_orig_names,
                          aliased_new_names, body, returns):
    """Builds the local function that holds one branch of a conditional.

    Args:
      body_name: name to give the generated branch function.
      aliased_orig_names: symbols that must be aliased inside the function.
        When empty, no alias declarations are emitted.
      aliased_new_names: the alias names, paired by position with
        `aliased_orig_names`.
      body: the statements of the branch.
      returns: sequence of values that the branch function returns.

    Returns:
      The result of `templates.replace` for the branch function definition.
    """
    if len(returns) == 1:
      template = """
        return retval
      """
      return_stmt = templates.replace(template, retval=returns[0])
    else:
      template = """
        return (retvals,)
      """
      return_stmt = templates.replace(template, retvals=returns)

    if aliased_orig_names:
      # Each alias is initialized from the original symbol; if that symbol is
      # not bound yet (NameError), the alias is set to an Undefined marker.
      alias_template = """
          try:
            aliased_new_name = aliased_orig_name
          except NameError:
            aliased_new_name = ag__.Undefined(symbol_name)
        """
      alias_declarations = []
      for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
        alias_declarations.extend(
            templates.replace(
                alias_template,
                aliased_new_name=new_name,
                aliased_orig_name=old_name,
                symbol_name=gast.Constant(str(old_name), kind=None)))

      template = """
        def body_name():
          alias_declarations
          body
          return_stmt
      """
      return templates.replace(
          template,
          alias_declarations=alias_declarations,
          body_name=body_name,
          body=body,
          return_stmt=return_stmt)
    else:
      template = """
        def body_name():
          body
          return_stmt
      """
      return templates.replace(
          template, body_name=body_name, body=body, return_stmt=return_stmt)

  def _create_cond_expr(self, results, test, body_name, orelse_name,
                        state_getter_name, state_setter_name,
                        basic_symbol_names, composite_symbol_names):
    """Builds the `ag__.if_stmt` call that replaces the conditional.

    Args:
      results: assignment target for the value of the call, or None to emit
        the call as a bare expression statement.
      test: the condition passed to `ag__.if_stmt`.
      body_name: name of the function holding the true branch.
      orelse_name: name of the function holding the false branch.
      state_getter_name: name of the generated state getter function.
      state_setter_name: name of the generated state setter function.
      basic_symbol_names: string constants naming the returned symbols.
      composite_symbol_names: string constants naming the composite symbols.

    Returns:
      The result of `templates.replace` for the call statement.
    """
    if results is not None:
      template = """
        results = ag__.if_stmt(test, body_name, orelse_name,
                               state_getter_name, state_setter_name,
                               (basic_symbol_names,),
                               (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          results=results,
          body_name=body_name,
          orelse_name=orelse_name,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)
    else:
      template = """
        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
                     (basic_symbol_names,), (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          body_name=body_name,
          orelse_name=orelse_name,
          getter_name=state_getter_name,
          setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)

  def _fmt_symbols(self, symbol_set):
    """Formats symbols as a comma-separated string, or 'no variables'."""
    if not symbol_set:
      return 'no variables'
    return ', '.join(map(str, symbol_set))

  def _determine_aliased_symbols(self, scope, node_defined_in):
    """Returns the non-composite symbols in `scope.modified & node_defined_in`."""
    modified_live = scope.modified & node_defined_in
    # Composite symbols are handled elsewhere see _create_state_functions
    return {s for s in modified_live if not s.is_composite()}

  def _create_state_functions(self, composites, state_getter_name,
                              state_setter_name):
    """Builds the getter/setter function pair for composite symbols.

    Args:
      composites: iterable of composite symbols. When empty, the getter
        returns `()` and the setter does nothing.
      state_getter_name: name to give the generated getter function.
      state_setter_name: name to give the generated setter function.

    Returns:
      The result of `templates.replace` for the two function definitions.
    """

    if composites:
      composite_tuple = tuple(composites)

      template = """
        def state_getter_name():
          return composite_tuple,
        def state_setter_name(vals):
          composite_tuple, = vals
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_tuple=composite_tuple)
    else:
      template = """
        def state_getter_name():
          return ()
        def state_setter_name(_):
          pass
        """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)

    return node

  def _create_loop_options(self, node):
    """Builds a `gast.Dict` from the loop's `set_loop_options` directive.

    Args:
      node: the loop node, possibly annotated with `anno.Basic.DIRECTIVES`.

    Returns:
      A `gast.Dict` whose keys are string constants and whose values are the
      directive's values. The dict is empty when the node has no directives
      annotation, no `set_loop_options` directive, or a directive without
      options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    # Build keys and values straight from the dict rather than unpacking
    # `zip(*opts_dict.items())`, which raises ValueError on an empty dict.
    # ast and gast don't play well with tuples, so both are plain lists.
    keys = [gast.Constant(s, kind=None) for s in opts_dict]
    values = [opts_dict[s] for s in opts_dict]
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Builds `var = ag__.Undefined(name)` assignments, one per symbol."""
    template = '''
        var = ag__.Undefined(symbol_name)
      '''
    assignments = []
    for s in undefined_symbols:
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def visit_If(self, node):
    """Rewrites an `if` statement into a call to `ag__.if_stmt`.

    Each branch is moved into a local function. Symbols modified by the
    branches are returned from those functions and assigned from the result
    of the `ag__.if_stmt` call.

    Args:
      node: the `if` node, annotated with body/orelse scopes, defined-in
        variables and live-out variables.

    Returns:
      A list of statements that replaces the original node.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out and not s.is_composite():
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        composites.add(s)

    # Symbols returned from the cond that had no definition before it.
    newly_created = returned_from_cond - defined_in
    created_in_body = body_scope.modified & newly_created
    created_in_orelse = orelse_scope.modified & newly_created

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # The names generated below are all bound in the scope enclosing both
    # branch functions (and the cond variable is read from inside both of
    # them), so they must not collide with symbols referenced by either
    # branch.
    all_referenced = body_scope.referenced | orelse_scope.referenced
    cond_var_name = self.ctx.namer.new_symbol('cond', all_referenced)
    body_name = self.ctx.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', all_referenced)
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    composites = tuple(composites)

    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composites)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name, basic_symbol_names,
                                       composite_symbol_names)

    if_ast = (
        undefined_assigns + composite_defs + body_def + orelse_def +
        cond_assign + cond_expr)
    return if_ast

  def _get_basic_loop_vars(self, modified_symbols, live_in, live_out):
    """Returns the simple (non-composite) loop variables as a frozenset.

    Args:
      modified_symbols: symbols modified by the loop.
      live_in: symbols live into the loop.
      live_out: symbols live out of the loop.

    Returns:
      A frozenset of the non-composite modified symbols that are live into
      or out of the loop.
    """
    # The loop variables corresponding to simple symbols (e.g. `x`).
    basic_loop_vars = []
    for s in modified_symbols:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF loop.
        continue
      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue
      basic_loop_vars.append(s)
    return frozenset(basic_loop_vars)

  def _get_composite_loop_vars(self, modified_symbols, live_in):
    """Returns the composite loop variables as a frozenset.

    Args:
      modified_symbols: symbols modified by the loop.
      live_in: symbols live into the loop.

    Returns:
      A frozenset of the composite modified symbols whose support-set
      symbols are all live into the loop.
    """
    # The loop variables corresponding to composite symbols (e.g. `self.x`).
    composite_loop_vars = []
    for s in modified_symbols:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_loop_vars.append(s)
    return frozenset(composite_loop_vars)

  def _get_loop_vars(self, node, modified_symbols):
    """Collects the variable sets needed to rewrite a loop.

    Args:
      node: the loop node, annotated with its body scope, defined-in
        variables and live-in/live-out variables.
      modified_symbols: symbols modified by the loop.

    Returns:
      A tuple `(basic_loop_vars, composite_loop_vars, reserved_symbols,
      undefined_lives)`, where `reserved_symbols` is the set of symbols
      referenced by the loop body and `undefined_lives` are the basic loop
      variables that are not in the defined-in set.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    basic_loop_vars = self._get_basic_loop_vars(
        modified_symbols, live_in, live_out)
    composite_loop_vars = self._get_composite_loop_vars(
        modified_symbols, live_in)

    # Variable that are used or defined inside the loop, but not defined
    # before entering the loop. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    undefined_lives = basic_loop_vars - defined_in

    return (basic_loop_vars, composite_loop_vars, reserved_symbols,
            undefined_lives)

  def _loop_var_constructs(self, basic_loop_vars):
    """Builds the template arguments that represent the loop variables.

    Args:
      basic_loop_vars: iterable of the simple loop variable symbols.

    Returns:
      A pair `(loop_vars, loop_vars_ast_tuple)`. `loop_vars` is a tuple of
      the symbols, or the symbol itself when there is exactly one.
      `loop_vars_ast_tuple` is always a `gast.Tuple` of their AST nodes.
    """
    loop_vars = tuple(basic_loop_vars)
    loop_vars_ast_tuple = gast.Tuple([n.ast() for n in loop_vars], None)

    if len(loop_vars) == 1:
      loop_vars = loop_vars[0]

    return loop_vars, loop_vars_ast_tuple

  def visit_While(self, node):
    """Rewrites a `while` loop into a call to `ag__.while_stmt`.

    The loop body and test are moved into local functions that take the loop
    variables as parameters.

    Args:
      node: the `while` node, annotated with its body scope and liveness
        information.

    Returns:
      A list of statements that replaces the original node.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(
         node,
         anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node

  def visit_For(self, node):
    """Rewrites a `for` loop into a call to `ag__.for_stmt`.

    The loop body is moved into a local function that receives the current
    iterate and the loop variables. If the node carries an
    `anno.Basic.EXTRA_LOOP_TEST` annotation, that expression is wrapped in a
    local function and passed along; otherwise `None` is passed in its place.

    Args:
      node: the `for` node, annotated with its body and iterate scopes and
        liveness information.

    Returns:
      The result of `templates.replace` for the statements that replace the
      original node.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars,
     reserved_symbols, possibly_undefs) = self._get_loop_vars(
         node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
                | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved_symbols)
      template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_name=extra_test_name,
          loop_vars=loop_vars,
          extra_test_expr=extra_test)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # Workaround for PEP-3113
    # iterates_var holds a single variable with the iterates, which may be a
    # tuple.
    iterates_var_name = self.ctx.namer.new_symbol(
        'iterates', reserved_symbols)
    template = """
      iterates = iterates_var_name
    """
    iterate_expansion = templates.replace(
        template,
        iterates=node.target,
        iterates_var_name=iterates_var_name)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
618
class AnnotatedDef(reaching_definitions.Definition):
  """A reaching definition that carries a `directives` dict.

  The dict starts out empty on every new instance.
  """

  def __init__(self):
    # Two-argument super() keeps this module usable under Python 2.
    super(AnnotatedDef, self).__init__()
    self.directives = dict()
624
625
def transform(node, ctx):
  """Runs the static analyses on `node`, then converts its control flow.

  Args:
    node: the AST node to transform.
    ctx: the context handed to each analysis and to the transformer.

  Returns:
    The node produced by `ControlFlowTransformer(ctx).visit`.
  """
  graphs = cfg.build(node)

  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  # The remaining analyses all take the CFGs built above; order is preserved.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    node = analysis.resolve(node, ctx, graphs)

  transformer = ControlFlowTransformer(ctx)
  return transformer.visit(node)
636