# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic source code transformation infrastructure."""

import inspect
import threading
import types

import gast

from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.utils import ag_logging as logging


def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  The AST is expected to define some symbol with a name given by `entity_name`.

  This mechanism ensures that the resulting transformed entity has lexical
  scoping identical to that of the source entity, while allowing extra
  parametrization.

  Two nested factories achieve the following:

   1. The inner factory dynamically creates the entity represented by `nodes`.
   2. The inner factory is parametrized by a custom set of arguments.
   3. The inner factory has a closure identical to that of the transformed
       entity.
   4. The inner factory has local variables named like `factory_args`, which
       `nodes` may use as additional parameters.
   5. The inner factory returns the variable given by `entity_name`.
   6. The outer factory is niladic.
   7. The outer factory has no closure.
   8. The outer factory creates the necessary lexical scope for the inner
       factory, so that the loaded code has the given configuration for
       closure/globals.
   9. The outer factory returns the inner factory.

  Roughly speaking, the following code is generated:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scoping is created using dummy symbol declarations which create
  local variables in the body of the outer factory, so that the Python parser
  correctly marks them as free non-global variables upon load (that is, it
  creates cell slots for each symbol). These symbols are initialized with None,
  but their values are not expected to be used; instead, the caller is expected
  to replace them with the cells of the source entity. For more details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate the
      code with.

  Returns:
    The AST nodes of the generated code, as produced by `templates.replace`.
  """
  # One `<name> = None` assignment per closure variable; this makes each name a
  # local of the outer factory, hence a free (cell) variable of the inner one.
  dummy_closure_defs = []
  for var_name in closure_vars:
    template = """
      var_name = None
    """
    dummy_closure_defs.extend(templates.replace(template, var_name=var_name))

  if future_features:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=name, asname=None) for name in future_features],
        level=0)
  else:
    future_imports = []

  factory_arg_nodes = [
      gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None)
      for name in factory_args
  ]

  template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=factory_arg_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)


class _PythonFnFactory(object):
  """Helper object that wraps a Python function factory."""

  def __init__(self, name, freevars, extra_locals):
    """Creates a new factory for a Python function.

    Args:
      name: The function name.
      freevars: The list of non-global free variables for the function.
      extra_locals: Dict[Text, Any], names and values for custom variables that
        are accessible to the generated code as local variables.
    """
    self._name = name
    self._freevars = freevars
    self._extra_locals = extra_locals

    # Set by `create`; `instantiate` refuses to run until then.
    self._unbound_factory = None
    self.module = None
    self.source_map = None

  def create(self,
             nodes,
             namer,
             inner_factory_name='inner_factory',
             outer_factory_name='outer_factory',
             future_features=()):
    """Loads the factory code for `nodes`. May only be called once.

    Args:
      nodes: AST node(s) defining the entity named like this factory's `name`.
      namer: Namer used to pick non-conflicting names for the two factories.
      inner_factory_name: Text, base name for the inner factory.
      outer_factory_name: Text, base name for the outer factory.
      future_features: Iterable[Text], `__future__` features to enable in the
        generated code.

    Raises:
      ValueError: if called more than once on the same object.
    """
    if self._unbound_factory is not None:
      raise ValueError('double initialization; create a new object instead')

    inner_factory_name = namer.new_symbol(inner_factory_name, ())
    outer_factory_name = namer.new_symbol(outer_factory_name, ())
    nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
                               outer_factory_name, self._freevars,
                               self._extra_locals.keys(), future_features)

    module, _, source_map = loader.load_ast(
        nodes, include_source_map=True)
    outer_factory = getattr(module, outer_factory_name)
    # Calling the niladic outer factory yields the inner factory, whose closure
    # cells are placeholders that `instantiate` replaces.
    self._unbound_factory = outer_factory()
    self.module = module
    self.source_map = source_map

  def instantiate(self,
                  globals_,
                  closure,
                  defaults=None,
                  kwdefaults=None):
    """Creates a new function instance.

    Args:
      globals_: Dict, the globals the new function will see.
      closure: Tuple of cells, the closure of the source function, ordered like
        the `freevars` given to the constructor.
      defaults: Optional value to assign to the new function's `__defaults__`.
      kwdefaults: Optional value to assign to the new function's
        `__kwdefaults__`.

    Returns:
      The newly created function.

    Raises:
      ValueError: if `create` was not called, or if the free variables of the
        generated factory do not match those of the source function.
    """
    if self._unbound_factory is None:
      raise ValueError('call create first')

    factory_code = self._unbound_factory.__code__
    factory_freevars = factory_code.co_freevars
    closure_map = dict(zip(self._freevars, closure))
    # Check for unknown names up front so that a mismatch is always reported
    # as a ValueError rather than a bare KeyError from the lookup below.
    has_unknown = any(name not in closure_map for name in factory_freevars)
    if has_unknown or len(factory_freevars) != len(closure):
      raise ValueError(
          'closure mismatch, requested {}, but source function had {}'.format(
              factory_freevars, self._freevars))
    factory_closure = tuple(closure_map[name] for name in factory_freevars)

    bound_factory = types.FunctionType(
        code=factory_code,
        globals=globals_,
        name=self._name,
        argdefs=(),
        closure=factory_closure)

    # The lint override is a false positive.
    new_fn = bound_factory(**self._extra_locals)  # pylint:disable=not-callable

    if defaults:
      new_fn.__defaults__ = defaults
    if kwdefaults:
      new_fn.__kwdefaults__ = kwdefaults

    return new_fn


class GenericTranspiler(object):
  """A generic transpiler for Python functions.

  Its interface is the `transform` API, which can process Python function
  objects. Internally, it handles parsing.

  Users typically subclass this, customizing the `transform_ast` method. The
  output of `transform_ast` is returned by `transform`, paired with the
  `transformer.Context` used during the transformation. Existing methods like
  `transform_function` may also be overloaded.

  Example:

      class MyTransformer(GenericTranspiler):

        def transform_ast(self, node, ctx):
          result = <<transform node>>
          return result

      transformer = MyTransformer()

      result, ctx = transformer.transform(f, ...)
      # result is the output of transform_ast
  """

  def get_transformed_name(self, node):
    """Returns a name for the output function. Subclasses may override this."""
    if isinstance(node, gast.Lambda):
      return 'lam'
    elif isinstance(node, gast.FunctionDef):
      return node.name
    raise ValueError('Unknown node type {}'.format(node))

  def transform_ast(self, node, ctx):
    """Performs an actual transformation of a function's AST.

    Subclasses must implement this method, and do not usually call it.

    Args:
      node: One or more ast.AST nodes representing the AST to be transformed.
      ctx: transformer.Context.
    """
    raise NotImplementedError('subclasses must override this')

  def transform(self, obj, user_context):
    """Transforms a Python object.

    Users typically call this method.

    Args:
      obj: A Python object, function, type, etc.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      The result of calling transform_function.

    Raises:
      NotImplementedError: if the type of obj is not handled.
    """
    if inspect.isfunction(obj) or inspect.ismethod(obj):
      return self.transform_function(obj, user_context)

    raise NotImplementedError('Non-function: {}'.format(type(obj)))

  def _erase_arg_defaults(self, node):
    """Erase arg default expressions, which would otherwise be unbound.

    The defaults are replaced in place with `None` expressions; the source
    function's actual default values are attached to the generated function
    later (see `_PythonFnFactory.instantiate`).
    """
    args = node.args
    # A fresh node per default, so that no AST node is shared between slots.
    args.defaults[:] = [
        parser.parse_expression('None') for _ in args.defaults
    ]
    # `None` entries in kw_defaults mean "no default" and must stay `None`.
    args.kw_defaults[:] = [
        None if d is None else parser.parse_expression('None')
        for d in args.kw_defaults
    ]
    return node

  def transform_module(self, mod, user_context):
    """Transforms a module.

    Subclasses may override this method. The return value is opaque.

    The method receives the module object. Each of its members, other than
    modules, is passed to `transform`; members that `transform` does not
    support are skipped.

    Args:
      mod: A Python module.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      List[Any]. By default, the result of `transform` for each supported
      member (for this class, a tuple of the transform_ast output and a
      `transformer.Context` containing information about the transformation
      process).
    """
    result = []
    for member in mod.__dict__.values():
      if inspect.ismodule(member):
        continue  # Not transforming modules recursively.
      try:
        result.append(self.transform(member, user_context))
      except NotImplementedError:
        pass  # Skip unsupported elements.
    return result

  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the function object, parses it, and passes the
    resulting AST to `transform_ast`. The result is passed as-is to the
    output of `transform`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      Tuple[Any, Any]. By default it returns the output of transform_ast,
      together with a `transformer.Context` containing information about the
      transformation process.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    entity_info = transformer.EntityInfo(
        name=new_name,
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = transformer.Context(entity_info, namer, user_context)

    node = self._erase_arg_defaults(node)
    result = self.transform_ast(node, context)

    return result, context


class PyToPy(GenericTranspiler):
  """A generic Python-to-Python transpiler.

  Its `transform` method offers a function-in, function-out interface.
  Internally, it takes care of parsing, caching and loading of the translated
  code.

  Users typically subclass this, overriding `transform_ast`.

  Usually, instances of this class are singletons, since each instance manages
  its own cache. The caching can be controlled by overriding `get_caching_key`.

  Example:

      class MyTransformer(PyToPy):

        def transform_ast(self, node, ctx):
          node = <<transform node, usually using ast.NodeTransformer classes>>
          return node

      transformer = MyTransformer()

      new_f, module, source_map = transformer.transform_function(f, ...)
      # new_f is a function with signature identical to f

  The transformed function has access to the same namespace as the original
  function. To allow access to internal APIs, users may inject additional
  symbols by overriding `get_extra_locals`.
  """

  def __init__(self):
    self._cache_lock = threading.RLock()
    self._cache = cache.CodeObjectCache()

  def get_extra_locals(self):
    """Returns extra static local variables to make available to transformed code.

    Subclasses must override this.

    Returns:
      extra_locals: A Dict[Text, Any] containing additional variables to make
        available to the transformed code.
    """
    raise NotImplementedError('subclasses must override this')

  def get_caching_key(self, user_context):
    """Returns a unique key to use for caching.

    Subclasses must override this.

    Calls made to `transform_function` with functions that have the same code
    object and caching key will return a cached instance on subsequent
    invocations.

    Args:
      user_context: The context object which was passed to `transform`.

    Returns:
      A hashable.
    """
    raise NotImplementedError('subclasses must override this')

  def _cached_factory(self, fn, cache_subkey):
    """Returns the cached `_PythonFnFactory` for `fn` and `cache_subkey`."""
    cached_factory = self._cache[fn][cache_subkey]
    logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
                cached_factory)
    return cached_factory

  def transform_function(self, fn, user_context):
    """Transforms a function. See GenericTranspiler.transform_function.

    This overload wraps the parent's `transform_function`, adding caching and
    facilities to instantiate the output as a Python object. It also
    adds facilities to make new symbols available to the generated Python code,
    visible as local variables - see `get_extra_locals`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      A tuple:
        * A function or lambda with the same signature and closure as `fn`
        * The temporary module into which the transformed function was loaded
        * The source map as a
            Dict[origin_info.LineLocation, origin_info.OriginInfo]
    """
    cache_subkey = self.get_caching_key(user_context)

    if self._cache.has(fn, cache_subkey):
      # Fast path: use a lock-free check.
      factory = self._cached_factory(fn, cache_subkey)

    else:
      with self._cache_lock:
        # Check again under lock.
        if self._cache.has(fn, cache_subkey):
          factory = self._cached_factory(fn, cache_subkey)

        else:
          logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
          # TODO(mdan): Confusing overloading pattern. Fix.
          nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)

          # Lambdas are expressions; bind them to the new name so the factory
          # can return them like a named function definition.
          if isinstance(nodes, gast.Lambda):
            nodes = gast.Assign(
                targets=[
                    gast.Name(
                        ctx.info.name,
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=nodes)
          else:
            nodes.name = ctx.info.name

          if logging.has_verbosity(2):
            logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

          factory = _PythonFnFactory(
              ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals())
          factory.create(
              nodes, ctx.namer, future_features=ctx.info.future_features)
          self._cache[fn][cache_subkey] = factory

    # Instantiation is per call: the cached factory is bound to this fn's
    # globals, closure cells and default values.
    transformed_fn = factory.instantiate(
        globals_=fn.__globals__,
        closure=fn.__closure__ or (),
        defaults=fn.__defaults__,
        kwdefaults=getattr(fn, '__kwdefaults__', None))
    return transformed_fn, factory.module, factory.source_map