1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import ast 4import inspect 5import textwrap 6import copy 7import functools 8from types import FunctionType 9from typing import cast, Union, Callable, Dict, Optional, Any 10from torch.fx._symbolic_trace import Tracer 11from torch.fx.graph import Graph 12from torch._sources import normalize_source_lines 13import torch 14 15class AST_Rewriter(ast.NodeTransformer): 16 """ 17 Take a FunctionType object representing a `forward` method, then 18 perform an AST rewrite to swap out nodes that are not symbolically 19 traceable with a callsite to the FX alternative. 20 21 To support swapping out an AST node, define a new `visit` method on 22 that node. For more details, see: 23 https://docs.python.org/3/library/ast.html#ast.NodeTransformer 24 """ 25 26 # This function checks for new keys added in the globals dict. TorchDynamo 27 # can insert new keys in the global dict and upset the check. Therefore, put 28 # a disable here. This function is an optimization pass and not really 29 # suitable for dynamo tracing anyways. 30 @torch._dynamo.disable 31 def rewrite(self, fn: FunctionType): 32 33 # Normalize the source lines 34 sourcelines, _ = inspect.getsourcelines(fn) 35 sourcelines = normalize_source_lines(sourcelines) 36 source = ''.join(sourcelines) 37 normalized_str = textwrap.dedent(source) 38 39 # Rewrite the original AST 40 source_ast = ast.parse(normalized_str) 41 dest_ast = ast.fix_missing_locations(self.visit(source_ast)) 42 43 # Pull out the compiled function from the newly-created Module 44 code = compile(dest_ast, "", "exec") 45 globals_dict = copy.copy(fn.__globals__) 46 keys_before = set(globals_dict.keys()) 47 exec(code, globals_dict) 48 new_keys = list(set(globals_dict.keys()) - keys_before) 49 assert len(new_keys) == 1 50 fn_compiled = globals_dict[new_keys[0]] 51 52 # return the compiled function with the original globals 53 def change_func_globals(f, globals): 54 """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" 55 # __globals__ is a private member of the function class 56 # so we have to copy the function, f, all of its member, except f.__globals__ 57 g = FunctionType( 58 f.__code__, 59 globals, 60 name=f.__name__, 61 argdefs=f.__defaults__, 62 closure=f.__closure__, 63 ) 64 g = functools.update_wrapper(g, f) 65 g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] 66 return g 67 # Return the correct FunctionType object 68 return change_func_globals(fn_compiled, globals=fn.__globals__) 69 70 def visit_Assert(self, node): 71 """ 72 Swap out the Assert node (Python's `assert`) with a callsite to the 73 symbolically-traceable torch._assert function 74 """ 75 # Create the Call node 76 n = ast.parse('torch._assert()', mode='eval') 77 assert isinstance(n, ast.Expression) 78 call_node = n.body 79 assert isinstance(call_node, ast.Call) 80 msg = node.msg if node.msg else ast.Constant(value="", kind=None) 81 call_node.args = [node.test, msg] 82 83 # Ensure that the new node conforms to the Python AST grammar 84 expr_wrapper = ast.Expr(value=call_node) 85 86 # Return the new Call node to signify that we want to use it as 87 # a replacement for the original _assert node 88 return ast.copy_location(expr_wrapper, node) 89 90 def visit_AnnAssign(self, node): 91 """ 92 Swap out Python's AnnAssign with an Assign node where the annotation function is called. 93 Example: 94 Original: 95 y: Tensor_Type(1,2,3, Dyn) = f2(x) 96 Output: 97 y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) 98 """ 99 return ast.Assign(targets=[node.target], value=ast.Call( 100 func=ast.Name(id='annotate', ctx=ast.Load()), 101 args=[node.value, node.annotation], keywords=[])) 102 103 104class RewritingTracer(Tracer): 105 def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: 106 return super().trace(_rewrite(root), concrete_args) 107 108 109def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: 110 if isinstance(fn, torch.nn.Module): 111 # Rewrite this module's `forward` as well as the `forward`s of 112 # all of this module's recursive descendents. Return the new, 113 # rewritten module hierarchy. 114 def rewrite_module(m : torch.nn.Module): 115 class RewrittenModule(torch.nn.Module): 116 def __init__(self, orig): 117 super().__init__() 118 for k, v in orig.__dict__.items(): 119 if isinstance(v, torch.nn.Module): 120 self.__dict__[k] = copy.copy(rewrite_module(v)) 121 else: 122 self.__dict__[k] = copy.copy(v) 123 RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) 124 return RewrittenModule(m) 125 return rewrite_module(fn) 126 else: 127 # Rewrite this single free function 128 return AST_Rewriter().rewrite(cast(FunctionType, fn)) 129