xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/rewriter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import ast
4import inspect
5import textwrap
6import copy
7import functools
8from types import FunctionType
9from typing import cast, Union, Callable, Dict, Optional, Any
10from torch.fx._symbolic_trace import Tracer
11from torch.fx.graph import Graph
12from torch._sources import normalize_source_lines
13import torch
14
15class AST_Rewriter(ast.NodeTransformer):
16    """
17    Take a FunctionType object representing a `forward` method, then
18    perform an AST rewrite to swap out nodes that are not symbolically
19    traceable with a callsite to the FX alternative.
20
21    To support swapping out an AST node, define a new `visit` method on
22    that node. For more details, see:
23    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
24    """
25
26    # This function checks for new keys added in the globals dict. TorchDynamo
27    # can insert new keys in the global dict and upset the check. Therefore, put
28    # a disable here. This function is an optimization pass and not really
29    # suitable for dynamo tracing anyways.
30    @torch._dynamo.disable
31    def rewrite(self, fn: FunctionType):
32
33        # Normalize the source lines
34        sourcelines, _ = inspect.getsourcelines(fn)
35        sourcelines = normalize_source_lines(sourcelines)
36        source = ''.join(sourcelines)
37        normalized_str = textwrap.dedent(source)
38
39        # Rewrite the original AST
40        source_ast = ast.parse(normalized_str)
41        dest_ast = ast.fix_missing_locations(self.visit(source_ast))
42
43        # Pull out the compiled function from the newly-created Module
44        code = compile(dest_ast, "", "exec")
45        globals_dict = copy.copy(fn.__globals__)
46        keys_before = set(globals_dict.keys())
47        exec(code, globals_dict)
48        new_keys = list(set(globals_dict.keys()) - keys_before)
49        assert len(new_keys) == 1
50        fn_compiled = globals_dict[new_keys[0]]
51
52        # return the compiled function with the original globals
53        def change_func_globals(f, globals):
54            """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
55            # __globals__ is a private member of the function class
56            # so we have to copy the function, f, all of its member, except f.__globals__
57            g = FunctionType(
58                f.__code__,
59                globals,
60                name=f.__name__,
61                argdefs=f.__defaults__,
62                closure=f.__closure__,
63            )
64            g = functools.update_wrapper(g, f)
65            g.__kwdefaults__ = copy.copy(f.__kwdefaults__)  # type:ignore[attr-defined]
66            return g
67        # Return the correct FunctionType object
68        return change_func_globals(fn_compiled, globals=fn.__globals__)
69
70    def visit_Assert(self, node):
71        """
72        Swap out the Assert node (Python's `assert`) with a callsite to the
73        symbolically-traceable torch._assert function
74        """
75        # Create the Call node
76        n = ast.parse('torch._assert()', mode='eval')
77        assert isinstance(n, ast.Expression)
78        call_node = n.body
79        assert isinstance(call_node, ast.Call)
80        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
81        call_node.args = [node.test, msg]
82
83        # Ensure that the new node conforms to the Python AST grammar
84        expr_wrapper = ast.Expr(value=call_node)
85
86        # Return the new Call node to signify that we want to use it as
87        # a replacement for the original _assert node
88        return ast.copy_location(expr_wrapper, node)
89
90    def visit_AnnAssign(self, node):
91        """
92        Swap out Python's AnnAssign with an Assign node where the annotation function is called.
93        Example:
94             Original:
95             y: Tensor_Type(1,2,3, Dyn) = f2(x)
96            Output:
97             y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
98        """
99        return ast.Assign(targets=[node.target], value=ast.Call(
100            func=ast.Name(id='annotate', ctx=ast.Load()),
101            args=[node.value, node.annotation], keywords=[]))
102
103
class RewritingTracer(Tracer):
    """A Tracer that first AST-rewrites its target (swapping out constructs
    like `assert` for traceable equivalents) and then traces the result."""

    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        # Rewrite first, then hand the traceable version to the base tracer.
        rewritten_root = _rewrite(root)
        return super().trace(rewritten_root, concrete_args)
107
108
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """AST-rewrite `fn` (a free function or an nn.Module hierarchy) so its
    `forward`(s) are symbolically traceable, returning the rewritten form."""
    # Free function: rewrite it directly.
    if not isinstance(fn, torch.nn.Module):
        return AST_Rewriter().rewrite(cast(FunctionType, fn))

    # Module: rewrite this module's `forward` as well as the `forward`s of
    # all of its recursive descendants, returning a new, rewritten module
    # hierarchy.
    def _clone_with_rewritten_forward(mod: torch.nn.Module):
        class RewrittenModule(torch.nn.Module):
            def __init__(self, orig):
                super().__init__()
                for attr_name, attr_val in orig.__dict__.items():
                    if isinstance(attr_val, torch.nn.Module):
                        # Recurse into child modules so the whole tree is rewritten.
                        self.__dict__[attr_name] = copy.copy(_clone_with_rewritten_forward(attr_val))
                    else:
                        self.__dict__[attr_name] = copy.copy(attr_val)

        RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, mod.forward))
        return RewrittenModule(mod)

    return _clone_with_rewritten_forward(fn)
129