1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Handles control flow statements: while, for, if. 16 17Python 2 compatibility version. Not maintained. 18""" 19 20import gast 21 22from tensorflow.python.autograph.core import converter 23from tensorflow.python.autograph.lang import directives 24from tensorflow.python.autograph.pyct import anno 25from tensorflow.python.autograph.pyct import ast_util 26from tensorflow.python.autograph.pyct import cfg 27from tensorflow.python.autograph.pyct import parser 28from tensorflow.python.autograph.pyct import qual_names 29from tensorflow.python.autograph.pyct import templates 30from tensorflow.python.autograph.pyct.static_analysis import activity 31from tensorflow.python.autograph.pyct.static_analysis import annos 32from tensorflow.python.autograph.pyct.static_analysis import liveness 33from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions 34from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs 35 36 37# TODO(mdan): Refactor functions to make them smaller. 
class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals.

  `if`, `while` and `for` statements are rewritten into calls to
  `ag__.if_stmt`, `ag__.while_stmt` and `ag__.for_stmt` respectively, with the
  statement bodies lifted into generated local functions.
  """

  def _create_cond_branch(self, body_name, aliased_orig_names,
                          aliased_new_names, body, returns):
    """Builds the function definition for one branch of an `if` statement.

    Args:
      body_name: name to give the generated branch function.
      aliased_orig_names: original symbols that the (already renamed) branch
        body accesses through an alias.
      aliased_new_names: the alias names, parallel to `aliased_orig_names`.
      body: list of statements forming the branch body.
      returns: values returned by the generated function. A single value is
        returned as-is; several values are returned as a tuple.

    Returns:
      The list of AST nodes produced by `templates.replace`.
    """
    if len(returns) == 1:
      template = """
        return retval
      """
      return_stmt = templates.replace(template, retval=returns[0])
    else:
      template = """
        return (retvals,)
      """
      return_stmt = templates.replace(template, retvals=returns)

    if aliased_orig_names:
      alias_declarations = []
      for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
        # Initialize each alias from the original symbol. If the original is
        # not defined at runtime, fall back to an `ag__.Undefined` placeholder
        # that records the symbol's name.
        template = """
          try:
            aliased_new_name = aliased_orig_name
          except NameError:
            aliased_new_name = ag__.Undefined(symbol_name)
        """

        alias_declarations.extend(
            templates.replace(
                template,
                aliased_new_name=new_name,
                aliased_orig_name=old_name,
                symbol_name=gast.Constant(str(old_name), kind=None)))

      template = """
        def body_name():
          alias_declarations
          body
          return_stmt
      """
      return templates.replace(
          template,
          alias_declarations=alias_declarations,
          body_name=body_name,
          body=body,
          return_stmt=return_stmt)
    else:
      template = """
        def body_name():
          body
          return_stmt
      """
      return templates.replace(
          template, body_name=body_name, body=body, return_stmt=return_stmt)

  def _create_cond_expr(self, results, test, body_name, orelse_name,
                        state_getter_name, state_setter_name,
                        basic_symbol_names, composite_symbol_names):
    """Builds the `ag__.if_stmt` call for a converted `if` statement.

    Args:
      results: assignment target for the values returned by `ag__.if_stmt`.
        If None, the call is emitted as a bare expression statement.
      test: the condition. `visit_If` passes the name that holds the
        evaluated test.
      body_name: name of the generated true-branch function.
      orelse_name: name of the generated false-branch function.
      state_getter_name: name of the generated state getter function.
      state_setter_name: name of the generated state setter function.
      basic_symbol_names: string constants naming the non-composite symbols
        returned from the conditional.
      composite_symbol_names: string constants naming the composite symbols.

    Returns:
      The list of AST nodes produced by `templates.replace`.
    """
    if results is not None:
      template = """
        results = ag__.if_stmt(test, body_name, orelse_name,
                               state_getter_name, state_setter_name,
                               (basic_symbol_names,),
                               (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          results=results,
          body_name=body_name,
          orelse_name=orelse_name,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)
    else:
      template = """
        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
                     (basic_symbol_names,), (composite_symbol_names,))
      """
      return templates.replace(
          template,
          test=test,
          body_name=body_name,
          orelse_name=orelse_name,
          getter_name=state_getter_name,
          setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names)

  def _fmt_symbols(self, symbol_set):
    """Returns `symbol_set` as a comma-separated string, or 'no variables'."""
    if not symbol_set:
      return 'no variables'
    return ', '.join(map(str, symbol_set))

  def _determine_aliased_symbols(self, scope, node_defined_in):
    """Returns the non-composite symbols modified in `scope` and defined on entry.

    Args:
      scope: the activity scope of a branch.
      node_defined_in: symbols defined before the statement.

    Returns:
      A set of the symbols that need an alias inside the branch function.
    """
    modified_live = scope.modified & node_defined_in
    # Composite symbols are handled elsewhere; see _create_state_functions.
    return {s for s in modified_live if not s.is_composite()}

  def _create_state_functions(self, composites, state_getter_name,
                              state_setter_name):
    """Builds the getter/setter function pair for composite symbols.

    The getter returns a tuple of the current values of `composites`. The
    setter assigns them back from a tuple. When there are no composites, the
    getter returns `()` and the setter does nothing.

    Args:
      composites: iterable of composite symbols (e.g. `self.x`).
      state_getter_name: name of the generated getter function.
      state_setter_name: name of the generated setter function.

    Returns:
      The list of AST nodes produced by `templates.replace`.
    """

    if composites:
      composite_tuple = tuple(composites)

      template = """
        def state_getter_name():
          return composite_tuple,
        def state_setter_name(vals):
          composite_tuple, = vals
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_tuple=composite_tuple)
    else:
      template = """
        def state_getter_name():
          return ()
        def state_setter_name(_):
          pass
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)

    return node

  def _create_loop_options(self, node):
    """Returns a `gast.Dict` holding the `set_loop_options` options of `node`.

    Args:
      node: a loop node, possibly annotated with directives.

    Returns:
      A `gast.Dict` mapping option names (string constants) to the option
      value expressions. The dict is empty when the loop has no
      `set_loop_options` directive, or when that directive carries no options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    if not opts_dict:
      # With no options, `zip(*opts_dict.items())` yields nothing, and the
      # two-name unpacking below would raise ValueError.
      return gast.Dict([], [])
    str_keys, values = zip(*opts_dict.items())
    keys = [gast.Constant(s, kind=None) for s in str_keys]
    values = list(values)  # ast and gast don't play well with tuples.
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Returns `var = ag__.Undefined('var')` statements, one per symbol.

    Args:
      undefined_symbols: iterable of symbols that may be undefined.

    Returns:
      A list of assignment statements.
    """
    assignments = []
    for s in undefined_symbols:
      template = '''
        var = ag__.Undefined(symbol_name)
      '''
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def visit_If(self, node):
    """Converts an `if` statement into a call to `ag__.if_stmt`.

    Each branch becomes a local function. That function returns the
    non-composite symbols that are modified in either branch and live after
    the statement.

    Args:
      node: the `if` node to convert.

    Returns:
      The list of statements that replaces `node`.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out and not s.is_composite():
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        composites.add(s)

    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    all_referenced = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    composites = tuple(composites)

    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Inside each branch function, aliased symbols are returned under their
      # alias names.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composites)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name, basic_symbol_names,
                                       composite_symbol_names)

    if_ast = (
        undefined_assigns + composite_defs + body_def + orelse_def +
        cond_assign + cond_expr)
    return if_ast

  def _get_basic_loop_vars(self, modified_symbols, live_in, live_out):
    """Returns the non-composite modified symbols that are live in or out.

    Args:
      modified_symbols: symbols modified by the loop.
      live_in: symbols live on entry to the loop.
      live_out: symbols live on exit from the loop.

    Returns:
      A frozenset of loop-carried simple symbols (e.g. `x`).
    """
    # The loop variables corresponding to simple symbols (e.g. `x`).
    basic_loop_vars = []
    for s in modified_symbols:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF loop.
        continue
      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue
      basic_loop_vars.append(s)
    return frozenset(basic_loop_vars)

  def _get_composite_loop_vars(self, modified_symbols, live_in):
    """Returns the composite modified symbols whose symbol parents are live in.

    Args:
      modified_symbols: symbols modified by the loop.
      live_in: symbols live on entry to the loop.

    Returns:
      A frozenset of loop-carried composite symbols (e.g. `self.x`).
    """
    # The loop variables corresponding to composite symbols (e.g. `self.x`).
    composite_loop_vars = []
    for s in modified_symbols:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_loop_vars.append(s)
    return frozenset(composite_loop_vars)

  def _get_loop_vars(self, node, modified_symbols):
    """Classifies the symbols that a loop carries across iterations.

    Args:
      node: the loop node. It must carry the body scope, DEFINED_VARS_IN,
        LIVE_VARS_IN and LIVE_VARS_OUT annotations.
      modified_symbols: symbols modified by the loop.

    Returns:
      A tuple `(basic_loop_vars, composite_loop_vars, reserved_symbols,
      undefined_lives)`:
        basic_loop_vars: loop-carried non-composite symbols.
        composite_loop_vars: loop-carried composite symbols.
        reserved_symbols: the symbols referenced in the loop body, to avoid
          when generating new names.
        undefined_lives: the basic loop vars not defined before the loop.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    basic_loop_vars = self._get_basic_loop_vars(
        modified_symbols, live_in, live_out)
    composite_loop_vars = self._get_composite_loop_vars(
        modified_symbols, live_in)

    # Variable that are used or defined inside the loop, but not defined
    # before entering the loop. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    undefined_lives = basic_loop_vars - defined_in

    return (basic_loop_vars, composite_loop_vars, reserved_symbols,
            undefined_lives)

  def _loop_var_constructs(self, basic_loop_vars):
    """Returns the loop vars in the two forms the loop templates need.

    Args:
      basic_loop_vars: iterable of loop-carried symbols.

    Returns:
      A tuple `(loop_vars, loop_vars_ast_tuple)`. `loop_vars` is the single
      symbol when there is exactly one, otherwise a tuple of symbols.
      `loop_vars_ast_tuple` is a `gast.Tuple` of their AST nodes.
    """
    loop_vars = tuple(basic_loop_vars)
    loop_vars_ast_tuple = gast.Tuple([n.ast() for n in loop_vars], None)

    if len(loop_vars) == 1:
      loop_vars = loop_vars[0]

    return loop_vars, loop_vars_ast_tuple

  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    Args:
      node: the `while` node to convert.

    Returns:
      A list of statements: `ag__.Undefined` placeholder assignments for
      possibly-undefined loop vars, followed by the generated state, body and
      test functions and the `ag__.while_stmt` call.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(
         node,
         anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      node = templates.replace(
          template,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=node.test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node

  def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    Args:
      node: the `for` node to convert.

    Returns:
      The list of statements that replaces `node`. It contains the
      `ag__.Undefined` placeholder assignments, the generated state, body and
      optional extra-test functions, and the `ag__.for_stmt` call.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars,
     reserved_symbols, possibly_undefs) = self._get_loop_vars(
         node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
                | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(
        composite_loop_vars, state_getter_name, state_setter_name)

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved_symbols)
      template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_name=extra_test_name,
          loop_vars=loop_vars,
          extra_test_expr=extra_test)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # Workaround for PEP-3113
    # iterates_var holds a single variable with the iterates, which may be a
    # tuple.
    iterates_var_name = self.ctx.namer.new_symbol(
        'iterates', reserved_symbols)
    template = """
      iterates = iterates_var_name
    """
    iterate_expansion = templates.replace(
        template,
        iterates=node.target,
        iterates_var_name=iterates_var_name)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)

    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          loop_vars=loop_vars,
          loop_vars_ast_tuple=loop_vars_ast_tuple,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          basic_symbol_names=basic_symbol_names,
          composite_symbol_names=composite_symbol_names,
          opts=opts)
    else:
      template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
      return templates.replace(
          template,
          undefined_assigns=undefined_assigns,
          iter_=node.iter,
          iterate_expansion=iterate_expansion,
          iterates_var_name=iterates_var_name,
          extra_test_name=extra_test_name,
          extra_test_function=extra_test_function,
          body_name=body_name,
          body=node.body,
          state_functions=state_functions,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_symbol_names=composite_symbol_names,
          opts=opts)


class AnnotatedDef(reaching_definitions.Definition):
  """A reaching definition that also carries a `directives` mapping."""

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    self.directives = {}


def transform(node, ctx):
  """Runs the required static analyses on `node`, then converts control flow.

  Args:
    node: the AST to convert.
    ctx: the conversion context, passed to the analyses and the transformer.

  Returns:
    The converted AST.
  """
  # ControlFlowTransformer reads scope, DEFINED_VARS_IN and LIVE_VARS_IN/OUT
  # annotations, presumably attached by these passes, so they must run
  # first.
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_definitions.resolve(node, ctx, graphs)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  node = ControlFlowTransformer(ctx).visit(node)
  return node