xref: /aosp_15_r20/external/executorch/exir/tracer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport copy
10*523fa7a6SAndroid Build Coastguard Workerimport json
11*523fa7a6SAndroid Build Coastguard Workerimport traceback
12*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import contextmanager
13*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import asdict, dataclass
14*523fa7a6SAndroid Build Coastguard Workerfrom typing import (
15*523fa7a6SAndroid Build Coastguard Worker    Any,
16*523fa7a6SAndroid Build Coastguard Worker    Callable,
17*523fa7a6SAndroid Build Coastguard Worker    Dict,
18*523fa7a6SAndroid Build Coastguard Worker    Generator,
19*523fa7a6SAndroid Build Coastguard Worker    Iterable,
20*523fa7a6SAndroid Build Coastguard Worker    List,
21*523fa7a6SAndroid Build Coastguard Worker    Optional,
22*523fa7a6SAndroid Build Coastguard Worker    Set,
23*523fa7a6SAndroid Build Coastguard Worker    Tuple,
24*523fa7a6SAndroid Build Coastguard Worker    Union,
25*523fa7a6SAndroid Build Coastguard Worker)
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Workerimport executorch.extension.pytree as ex_pytree
28*523fa7a6SAndroid Build Coastguard Workerimport torch
29*523fa7a6SAndroid Build Coastguard Workerimport torch._dynamo as torchdynamo
30*523fa7a6SAndroid Build Coastguard Workerimport torch.fx as fx
31*523fa7a6SAndroid Build Coastguard Worker
32*523fa7a6SAndroid Build Coastguard Workerimport torch.fx._pytree as fx_pytree
33*523fa7a6SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.common import (
36*523fa7a6SAndroid Build Coastguard Worker    extract_out_arguments,
37*523fa7a6SAndroid Build Coastguard Worker    format_schema_name,
38*523fa7a6SAndroid Build Coastguard Worker    no_dispatch,
39*523fa7a6SAndroid Build Coastguard Worker    setting_python_recursive_limit,
40*523fa7a6SAndroid Build Coastguard Worker)
41*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType, InternalError
42*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import LeafValue
43*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.operator.convert import is_out_variant
44*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.types import ValueSpec
45*523fa7a6SAndroid Build Coastguard Worker
46*523fa7a6SAndroid Build Coastguard Workerfrom torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass  # @manual
47*523fa7a6SAndroid Build Coastguard Workerfrom torch._decomp import get_decompositions
48*523fa7a6SAndroid Build Coastguard Workerfrom torch._dynamo.guards import Guard
49*523fa7a6SAndroid Build Coastguard Workerfrom torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
50*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import default_decompositions
51*523fa7a6SAndroid Build Coastguard Workerfrom torch.func import functionalize
52*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import normalize_function
53*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._pytree import TreeSpec
54*523fa7a6SAndroid Build Coastguard Worker
55*523fa7a6SAndroid Build Coastguard Workerfrom typing_extensions import TypeAlias
56*523fa7a6SAndroid Build Coastguard Worker
57*523fa7a6SAndroid Build Coastguard Worker
# A traced value: either a single LeafValue (e.g. a tensor or scalar) or an
# arbitrarily nested tuple/list/dict container of such values.
Value: TypeAlias = Union[
    LeafValue,
    Tuple["Value", ...],
    List["Value"],
    Dict[str, "Value"],
]

# Global flag read in PythonTensor.__torch_dispatch__ to decide whether to
# record real FX proxy nodes even when no DispatchTracer is active.
# Presumably toggled elsewhere when torchdynamo-based tracing is in progress.
torchdynamo_enabled = False
67*523fa7a6SAndroid Build Coastguard Worker
def get_stacktrace() -> List[Dict[str, str]]:
    """
    Get the current stacktrace (between trace() and __torch_dispatch__())
    Include the filename, function name, line number, and source code from the
    start of the function to the given instruction.

    Return:
        A list of stacktraces for each instruction along with the source code
        context surrounding each instruction
    """

    stacktrace = traceback.extract_stack()

    # The stacktrace typically looks like this:
    #
    #   1. I stack frames from the top level runner (e.g., the
    #      test suite runner)
    #   2. J frames in executorch/exir/tracer.py setting up tracing
    #      (call this INIT_EXIR)
    #   3. K frames in user model code (this is what we want to save!)
    #   4. 1 frame in executorch/exir/tracer.py __torch_function__
    #      returning to tracer (call this TRACE_EXIR)
    #   5. H frames in executorch/exir/tracer.py AND torch/_tensor.py
    #      doing all of the internal tracer handling
    #
    # The PyE tests assert that executorch/exir/tracer.py never shows
    # up in the user provided stack traces, so we must oblige them.
    #
    # Assumptions:
    #   - Reentrant tracing is not a thing.  Thus, the first time
    #     executorch/exir/tracer.py shows up in the trace, we know
    #     THAT is the point at which we start tracing.  (An alternative
    #     is that the tracer entry point could record the stack trace
    #     at this time, but I didn't do this.)
    #
    # Our plan is to do a miniature stack machine traversing these
    # stack machines.

    # Remove parts before the trace function and parts after entering
    # __torch_dispatch__.  Defaults to returning the entire stack trace.
    init_exir_end = 0
    trace_exir_start = None
    # A miniature state machine, referring to the frame segments described
    # above.  The locations are closed-open interval.
    FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3)
    state = FIND_INIT_EXIR_START
    for i, frame in enumerate(stacktrace):
        if state == FIND_INIT_EXIR_START:
            # Still in segment 1 (runner frames); wait for tracer.py to appear.
            if "executorch/exir/tracer.py" in frame.filename:
                state = FIND_INIT_EXIR_END
        elif state == FIND_INIT_EXIR_END:
            # In segment 2 (INIT_EXIR); the first non-tracer frame begins
            # the user-code segment we want to keep.
            if "executorch/exir/tracer.py" not in frame.filename:
                init_exir_end = i
                state = FIND_TRACE_EXIR_START
        elif state == FIND_TRACE_EXIR_START:
            # In segment 3 (user code); the next tracer.py frame is the
            # re-entry into the tracer (TRACE_EXIR) -- stop there.
            if "executorch/exir/tracer.py" in frame.filename:
                trace_exir_start = i
                break

    # Keep only the user-code frames.  If the state machine never completed,
    # this slice is [0:None], i.e. the whole stack, per the default above.
    stacktrace = stacktrace[init_exir_end:trace_exir_start]

    # Get the source code from the errored line to it
    contexts: List[str] = []
    for s in stacktrace:
        try:
            with open(s.filename) as file:
                # pyre-fixme[6]: For 1st param expected `Union[SupportsTrunc, bytes,
                #  str, SupportsInt, SupportsIndex]` but got `Optional[int]`.
                # NOTE(review): FrameSummary.lineno is Optional[int]; int(None)
                # would raise TypeError here -- assumed always set in practice.
                lineno = int(s.lineno)
                # Get the source code 5 lines above/below the current instruction
                # Each line is prefixed with its 1-based line number (no
                # separator between the number and the source text).
                file_contents = [
                    str(index + 1) + line for index, line in enumerate(file.readlines())
                ]
                file_contents_above = "".join(
                    file_contents[max(lineno - 5, 0) : lineno]
                )
                file_contents_below = "".join(
                    file_contents[lineno : min(lineno + 5, len(file_contents))]
                )
                # The caret marker separates the lines up to and including the
                # current instruction from the lines that follow it.
                context = (
                    file_contents_above
                    + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
                    + file_contents_below
                )
                contexts.append(context)
        except FileNotFoundError:
            # Frames with no source file on disk still get a placeholder so
            # `contexts` stays index-aligned with `stacktrace`.
            contexts.append("<unknown file: unknown line>")

    # torch.fx stack preservation logic expects strings to
    # be passed around. Working with dictionary is lot easier
    # to convert to string and vice versa.
    frames: List[Dict[str, str]] = []
    for i, frame in enumerate(stacktrace):
        frames.append(
            {
                "filename": str(frame.filename),
                "lineno": str(frame.lineno),
                "name": str(frame.name),
                "line": str(frame.line),
                "context": contexts[i],
            }
        )

    return frames
172*523fa7a6SAndroid Build Coastguard Worker
173*523fa7a6SAndroid Build Coastguard Worker
def unwrap_functional(t: torch.Tensor) -> torch.Tensor:
    """Strip any functionalization wrapper from ``t``.

    Delegates to functorch's helper with ``reapply_views=False``; tensors
    that carry no functional wrapper come back from it untouched.
    """
    assert isinstance(t, torch.Tensor)
    unwrapped = _maybe_unwrap_functional_tensor(t, reapply_views=False)
    return unwrapped
177*523fa7a6SAndroid Build Coastguard Worker
178*523fa7a6SAndroid Build Coastguard Worker
def unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
    """Map a leaf value to the FX proxy tracked for it, when one exists.

    Non-tensor leaves pass straight through.  Tensors are first stripped of
    any functionalization wrapper; if what remains is a PythonTensor, its
    recorded proxy is returned, otherwise the tensor itself is.
    """
    if not isinstance(t, torch.Tensor):
        return t
    inner = unwrap_functional(t)
    if isinstance(inner, PythonTensor):
        return inner.proxy
    return inner
184*523fa7a6SAndroid Build Coastguard Worker
185*523fa7a6SAndroid Build Coastguard Worker
def single_return(
    output: LeafValue,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
) -> LeafValue:
    """Pair a single tensor output with its proxy via ``wrapper``.

    Anything that is not a tensor is returned unchanged.
    """
    if not isinstance(output, torch.Tensor):
        return output
    return wrapper(output, proxy)
195*523fa7a6SAndroid Build Coastguard Worker
196*523fa7a6SAndroid Build Coastguard Worker
def tree_return(
    outputs: Value,
    proxy: torch.fx.Proxy,
    wrapper: Callable[..., LeafValue],
    meta_type: Callable[..., Iterable[ValueSpec]] = tuple,
) -> Value:
    """Wrap every leaf of ``outputs`` with its positionally-matched sub-proxy.

    Leaves are visited in pytree order; the k-th leaf is paired with
    ``proxy[k]`` and handed to single_return (which only wraps tensors).
    ``meta_type`` is accepted for interface compatibility but not consulted
    in this body.
    """
    position = [0]  # mutable cell standing in for a nonlocal counter

    def _wrap_leaf(leaf: LeafValue) -> LeafValue:
        sub_proxy = proxy[position[0]]
        position[0] += 1
        return single_return(leaf, sub_proxy, wrapper)

    return pytree.tree_map(_wrap_leaf, outputs)
212*523fa7a6SAndroid Build Coastguard Worker
213*523fa7a6SAndroid Build Coastguard Worker
class DummyProxy:
    """Stand-in for ``torch.fx.Proxy`` used when no tracer is active.

    It exposes just enough surface for the dispatch path to proceed without
    recording graph nodes: a ``node`` attribute carrying an empty ``meta``
    dict, and indexing, which hands back a fresh DummyProxy.
    """

    class _DummyNode:
        """Minimal node mimic: only the ``meta`` dictionary is provided."""

        def __init__(self) -> None:
            self.meta: Dict[str, Any] = {}

    def __init__(self) -> None:
        # Hoisted nested class: the original defined DummyNode inside
        # __init__, re-creating the class object on every instantiation.
        self.node = DummyProxy._DummyNode()

    def __getitem__(self, key: Union[str, int]) -> "DummyProxy":
        # Fix: the key was annotated `str`, but tree_return indexes proxies
        # positionally (proxy[i]) with integers, so both must be accepted.
        return DummyProxy()
224*523fa7a6SAndroid Build Coastguard Worker
225*523fa7a6SAndroid Build Coastguard Worker
class PythonTensor(torch.Tensor):
    """
    A wrapper tensor subclass used in the DispatchTracer to keep track of
    proxies to construct the FX graph.

    Wrapping something in PythonTensor implicitly detaches gradients.  If
    something required grad, we will collect it as if it were a leaf.  A
    consequence of detaching in this way is you need to maintain a parameter
    cache when translating tensors into PythonTensor, so you don't create
    multiple copies of a gradient (they are aliased, but they would count as
    independent leaves).  An alternate strategy would be to avoid implicitly
    detaching and instead "catch" gradients as they exit the PythonTensor
    boundary.
    """

    # proxy: the fx.Proxy whose node produced this tensor's value.
    # is_immutable: when True, __torch_dispatch__ raises if this tensor is
    # passed as an out= argument (i.e. would be mutated by an out variant).
    __slots__ = ["proxy", "is_immutable"]

    @staticmethod
    def __new__(
        cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False
    ) -> torch.Tensor:
        """Wrap ``elem`` as a PythonTensor carrying ``proxy``.

        ``requires_grad`` is copied over from ``elem``.
        """
        # assert not elem.requires_grad or not torch.is_grad_enabled()

        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        assert isinstance(r, PythonTensor)
        r.is_immutable: bool = is_immutable
        r.update_proxy(proxy)
        return r

    def update_proxy(self, proxy: torch.fx.Proxy) -> None:
        # Rebind the tracked proxy, e.g. after an in-place or out-variant op
        # creates a new graph node for the same underlying tensor.
        self.proxy = proxy

    def __repr__(self, *, tensor_contents: None = None) -> str:
        # no_dispatch() keeps the tensor ops used to render the repr from
        # re-entering __torch_dispatch__ and polluting the traced graph.
        with no_dispatch():
            return f"PythonTensor({self.as_subclass(torch.Tensor)})"

    @classmethod
    def __torch_function__(
        cls,
        # pyre-ignore: Missing parameter annotation [2]
        func,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        """Intercept torch-level calls; mostly defers to default behavior.

        Under inference mode, layer_norm and linear are rerouted straight to
        ``__torch_dispatch__`` as their aten overloads.
        """
        if kwargs is None:
            kwargs = {}
        if torch.is_inference_mode_enabled():
            # NOTE(review): presumably these composite ops would not otherwise
            # reach __torch_dispatch__ under inference mode -- confirm.
            if func is torch.nn.functional.layer_norm:
                args, kwargs = normalize_function(func, args, kwargs)  # pyre-fixme[23]
                input, normalized_shape = args
                normalized_shape = list(normalized_shape)
                return cls.__torch_dispatch__(
                    torch.ops.aten.layer_norm.default,
                    types,
                    (input, normalized_shape),
                    kwargs,
                )
            elif func is torch.nn.functional.linear:
                return cls.__torch_dispatch__(
                    torch.ops.aten.linear.default, types, args, kwargs
                )
        with DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(  # noqa: C901
        cls,
        func_overload: torch._ops.OpOverload,
        # pyre-ignore: Missing parameter annotation [2]
        types,
        args: Tuple[Value, ...] = (),
        kwargs: Optional[Dict[str, Value]] = None,
    ) -> Value:
        """
        This function is invoked every time an aten operation is called.

        Args:
            func_overload: The function that was called that invoked this
                torch_dispatch call
            types:
            args: Arguments that were passed into the function. Each argument
                has type PythonTensor.
            kwargs: Keyword arguments that were passed into the function. Each
                argument has type PythonTensor.
        """
        func = func_overload.overloadpacket

        kwargs = kwargs or {}
        # Out variants mutate their out= arguments; refuse to trace one that
        # targets a tensor marked immutable.
        if is_out_variant(func._qualified_op_name, func_overload._overloadname):
            out_args = extract_out_arguments(func_overload._schema, kwargs)
            out_args_iter = [out_args] if not isinstance(out_args, list) else out_args
            for out_arg_name, out_arg_val in out_args_iter:
                if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable:
                    raise RuntimeError(
                        "Immutable tensor `{}` is potentially getting modified by {}".format(
                            out_arg_name, format_schema_name(func_overload._schema)
                        )
                    )

        # Swap every PythonTensor in args/kwargs for its fx proxy so the call
        # below records a graph node instead of computing values.
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_args = ex_pytree.tree_map(unwrap_proxy, args)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`.
        proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs)

        # Get the output of the function
        # Temporarily re-enable __torch_function__ so calling the overload on
        # proxies goes through fx's proxy machinery.
        g = _EnableTorchFunction()
        try:
            proxy_out = (
                func_overload(*proxy_args, **proxy_kwargs)
                if DispatchTracer.get() or torchdynamo_enabled
                # Disable node creation when no tracer is active.
                else DummyProxy()
            )
        finally:
            del g

        # Run the real computation on the underlying tensors, without
        # re-entering this dispatch handler.
        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        # (trailing underscore, e.g. add_, but not a dunder/private name).
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            if isinstance(args[0], PythonTensor):
                args[0].proxy = proxy_out

        # Record a JSON stack trace on the new node unless fx is already
        # preserving node metadata for us.
        if not torch.fx.traceback.has_preserved_node_meta():
            proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace())

        # Wrap the output tensors with the PythonTensor subclass to propagate to
        # future tracing
        def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue:
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())

            if isinstance(e, torch.Tensor):
                return PythonTensor(e, proxy)

            # Inplace and out-variant ops may return one of their arguments, which is already
            # a PythonTensor. In this case, we need to update the PythonTensor's associated
            # proxy to the newly created proxy.
            if isinstance(e, PythonTensor):
                e.update_proxy(proxy)
                return e

            return e

        retval = None
        if not isinstance(real_out, (list, tuple)):
            retval = single_return(real_out, proxy_out, wrap_with_proxy)
        else:
            retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out))
        return retval
383*523fa7a6SAndroid Build Coastguard Worker
384*523fa7a6SAndroid Build Coastguard Worker
@contextmanager
def using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]:
    """Install ``tracer`` as the global TRACER for the scope of the block.

    Several constructs we want to capture (e.g. cond() and shape()) never
    actually trap into the dispatcher, so a user-space singleton tracer is
    maintained alongside the Dispatcher to drive graph capture.  The previous
    tracer is restored on exit, even if the body raises.
    """
    global TRACER
    previous = TRACER
    TRACER = tracer
    try:
        yield
    finally:
        TRACER = previous
402*523fa7a6SAndroid Build Coastguard Worker
403*523fa7a6SAndroid Build Coastguard Worker
404*523fa7a6SAndroid Build Coastguard Workerclass DispatchTracer(fx.Tracer):
405*523fa7a6SAndroid Build Coastguard Worker    def __init__(self) -> None:
406*523fa7a6SAndroid Build Coastguard Worker        super().__init__()
407*523fa7a6SAndroid Build Coastguard Worker        self.root: torch.nn.Module = torch.nn.Module()
408*523fa7a6SAndroid Build Coastguard Worker        self.tensor_attrs: Dict[torch.Tensor, str] = {}
409*523fa7a6SAndroid Build Coastguard Worker        self.submodules: Dict[fx.GraphModule, str] = {}
410*523fa7a6SAndroid Build Coastguard Worker
411*523fa7a6SAndroid Build Coastguard Worker    def call_module(
412*523fa7a6SAndroid Build Coastguard Worker        self,
413*523fa7a6SAndroid Build Coastguard Worker        m: torch.nn.Module,
414*523fa7a6SAndroid Build Coastguard Worker        forward: Callable[..., Value],
415*523fa7a6SAndroid Build Coastguard Worker        args: Tuple[Value, ...],
416*523fa7a6SAndroid Build Coastguard Worker        kwargs: Dict[str, Value],
417*523fa7a6SAndroid Build Coastguard Worker    ) -> Value:
418*523fa7a6SAndroid Build Coastguard Worker        return forward(*args, **kwargs)
419*523fa7a6SAndroid Build Coastguard Worker
420*523fa7a6SAndroid Build Coastguard Worker    def _module_getattr(
421*523fa7a6SAndroid Build Coastguard Worker        self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor]
422*523fa7a6SAndroid Build Coastguard Worker    ) -> Value:
423*523fa7a6SAndroid Build Coastguard Worker        if isinstance(attr_val, torch.nn.Parameter):
424*523fa7a6SAndroid Build Coastguard Worker            for n, p in self.root.named_parameters():
425*523fa7a6SAndroid Build Coastguard Worker                if attr_val is p:
426*523fa7a6SAndroid Build Coastguard Worker                    if n not in parameter_proxy_cache:
427*523fa7a6SAndroid Build Coastguard Worker                        proxy = self.create_proxy("get_attr", n, (), {})
428*523fa7a6SAndroid Build Coastguard Worker                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
429*523fa7a6SAndroid Build Coastguard Worker                    return parameter_proxy_cache[n]
430*523fa7a6SAndroid Build Coastguard Worker            return attr_val
431*523fa7a6SAndroid Build Coastguard Worker        return attr_val
432*523fa7a6SAndroid Build Coastguard Worker
433*523fa7a6SAndroid Build Coastguard Worker    def create_arg(self, a: Value) -> torch.fx.Node:  # noqa: C901
434*523fa7a6SAndroid Build Coastguard Worker        if isinstance(a, torch.nn.Parameter):
435*523fa7a6SAndroid Build Coastguard Worker            for n, p in self.root.named_parameters():
436*523fa7a6SAndroid Build Coastguard Worker                if a is p:
437*523fa7a6SAndroid Build Coastguard Worker                    return self.create_node("get_attr", n, (), {})
438*523fa7a6SAndroid Build Coastguard Worker            qualname: Optional[str] = None
439*523fa7a6SAndroid Build Coastguard Worker
440*523fa7a6SAndroid Build Coastguard Worker            if not qualname:
441*523fa7a6SAndroid Build Coastguard Worker                i = 0
442*523fa7a6SAndroid Build Coastguard Worker                while True:
443*523fa7a6SAndroid Build Coastguard Worker                    qualname = f"_param_constant{i}"
444*523fa7a6SAndroid Build Coastguard Worker                    if not hasattr(self.root, qualname):
445*523fa7a6SAndroid Build Coastguard Worker                        break
446*523fa7a6SAndroid Build Coastguard Worker                    i += 1
447*523fa7a6SAndroid Build Coastguard Worker                setattr(self.root, qualname, a)
448*523fa7a6SAndroid Build Coastguard Worker
449*523fa7a6SAndroid Build Coastguard Worker            return self.create_node("get_attr", qualname, (), {})
450*523fa7a6SAndroid Build Coastguard Worker
451*523fa7a6SAndroid Build Coastguard Worker        if isinstance(a, torch.Tensor):
452*523fa7a6SAndroid Build Coastguard Worker            qualname: Optional[str] = self.tensor_attrs.get(a)
453*523fa7a6SAndroid Build Coastguard Worker
454*523fa7a6SAndroid Build Coastguard Worker            if not qualname:
455*523fa7a6SAndroid Build Coastguard Worker                i = 0
456*523fa7a6SAndroid Build Coastguard Worker                while True:
457*523fa7a6SAndroid Build Coastguard Worker                    qualname = f"_tensor_constant{i}"
458*523fa7a6SAndroid Build Coastguard Worker                    if not hasattr(self.root, qualname):
459*523fa7a6SAndroid Build Coastguard Worker                        break
460*523fa7a6SAndroid Build Coastguard Worker                    i += 1
461*523fa7a6SAndroid Build Coastguard Worker                self.tensor_attrs[a] = qualname
462*523fa7a6SAndroid Build Coastguard Worker                self.root.register_buffer(qualname, a)
463*523fa7a6SAndroid Build Coastguard Worker
464*523fa7a6SAndroid Build Coastguard Worker            return self.create_node("get_attr", qualname, (), {})
465*523fa7a6SAndroid Build Coastguard Worker
466*523fa7a6SAndroid Build Coastguard Worker        # higher-order operator
467*523fa7a6SAndroid Build Coastguard Worker        if isinstance(a, fx.GraphModule):
468*523fa7a6SAndroid Build Coastguard Worker            if a not in self.submodules:
469*523fa7a6SAndroid Build Coastguard Worker                name_submodule = f"submodule_{len(self.submodules)}"
470*523fa7a6SAndroid Build Coastguard Worker                self.root.add_module(name_submodule, a)
471*523fa7a6SAndroid Build Coastguard Worker                self.submodules[a] = name_submodule
472*523fa7a6SAndroid Build Coastguard Worker            return self.create_node("get_attr", self.submodules[a], (), {})
473*523fa7a6SAndroid Build Coastguard Worker
474*523fa7a6SAndroid Build Coastguard Worker        return super().create_arg(a)  # pyre-fixme[7]
475*523fa7a6SAndroid Build Coastguard Worker
476*523fa7a6SAndroid Build Coastguard Worker    @staticmethod
477*523fa7a6SAndroid Build Coastguard Worker    def get() -> "DispatchTracer":
478*523fa7a6SAndroid Build Coastguard Worker        return TRACER
479*523fa7a6SAndroid Build Coastguard Worker
480*523fa7a6SAndroid Build Coastguard Worker    def trace(  # pyre-fixme[14,15]
481*523fa7a6SAndroid Build Coastguard Worker        self,
482*523fa7a6SAndroid Build Coastguard Worker        root: Callable[..., Value],
483*523fa7a6SAndroid Build Coastguard Worker        concrete_args: Tuple[Value, ...] = (),
484*523fa7a6SAndroid Build Coastguard Worker        in_spec: Optional[TreeSpec] = None,
485*523fa7a6SAndroid Build Coastguard Worker    ) -> Value:
486*523fa7a6SAndroid Build Coastguard Worker        """
487*523fa7a6SAndroid Build Coastguard Worker        Traces the given graph module.
488*523fa7a6SAndroid Build Coastguard Worker        """
489*523fa7a6SAndroid Build Coastguard Worker        with using_tracer(self):
490*523fa7a6SAndroid Build Coastguard Worker            return self._trace(root, concrete_args=concrete_args, in_spec=in_spec)
491*523fa7a6SAndroid Build Coastguard Worker
492*523fa7a6SAndroid Build Coastguard Worker    def _trace(
493*523fa7a6SAndroid Build Coastguard Worker        self,
494*523fa7a6SAndroid Build Coastguard Worker        root: Callable[..., Value],
495*523fa7a6SAndroid Build Coastguard Worker        concrete_args: Tuple[Value, ...],
496*523fa7a6SAndroid Build Coastguard Worker        in_spec: Optional[TreeSpec],
497*523fa7a6SAndroid Build Coastguard Worker    ) -> Value:
498*523fa7a6SAndroid Build Coastguard Worker        self.root = torch.nn.Module()
499*523fa7a6SAndroid Build Coastguard Worker        root_fn = root
500*523fa7a6SAndroid Build Coastguard Worker
501*523fa7a6SAndroid Build Coastguard Worker        tracer_cls = getattr(self, "__class__", None)
502*523fa7a6SAndroid Build Coastguard Worker        self.graph = fx.Graph(tracer_cls=tracer_cls)
503*523fa7a6SAndroid Build Coastguard Worker        # Don't support module, so tensor_attrs is always empty
504*523fa7a6SAndroid Build Coastguard Worker        self.tensor_attrs = {}
505*523fa7a6SAndroid Build Coastguard Worker
506*523fa7a6SAndroid Build Coastguard Worker        # Wrap all inputs as a PythonTensor subclass and insert them into the FX
507*523fa7a6SAndroid Build Coastguard Worker        # graph as placeholder nodes
508*523fa7a6SAndroid Build Coastguard Worker        def wrap(arg: Value, i: int) -> Value:
509*523fa7a6SAndroid Build Coastguard Worker            placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {})
510*523fa7a6SAndroid Build Coastguard Worker            if isinstance(arg, torch.Tensor):
511*523fa7a6SAndroid Build Coastguard Worker                return PythonTensor(arg, placeholder, is_immutable=True)
512*523fa7a6SAndroid Build Coastguard Worker            else:
513*523fa7a6SAndroid Build Coastguard Worker                # torch._assert(
514*523fa7a6SAndroid Build Coastguard Worker                #     placeholder == arg,
515*523fa7a6SAndroid Build Coastguard Worker                #     f"ph_{i} has been specialized to have value {arg}",
516*523fa7a6SAndroid Build Coastguard Worker                # )
517*523fa7a6SAndroid Build Coastguard Worker                return arg
518*523fa7a6SAndroid Build Coastguard Worker
519*523fa7a6SAndroid Build Coastguard Worker        tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)]
520*523fa7a6SAndroid Build Coastguard Worker        if in_spec:
521*523fa7a6SAndroid Build Coastguard Worker            tree_args = pytree.tree_unflatten(tree_args, in_spec)
522*523fa7a6SAndroid Build Coastguard Worker
523*523fa7a6SAndroid Build Coastguard Worker        tree_out = root_fn(*tree_args)
524*523fa7a6SAndroid Build Coastguard Worker
525*523fa7a6SAndroid Build Coastguard Worker        out_args, _ = pytree.tree_flatten(tree_out)
526*523fa7a6SAndroid Build Coastguard Worker
527*523fa7a6SAndroid Build Coastguard Worker        def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]:
528*523fa7a6SAndroid Build Coastguard Worker            # it's legit for a model to return a list of items some of which
529*523fa7a6SAndroid Build Coastguard Worker            # are None
530*523fa7a6SAndroid Build Coastguard Worker            if out is None:
531*523fa7a6SAndroid Build Coastguard Worker                return None
532*523fa7a6SAndroid Build Coastguard Worker            if not isinstance(out, torch.Tensor):
533*523fa7a6SAndroid Build Coastguard Worker                raise TypeError(
534*523fa7a6SAndroid Build Coastguard Worker                    f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})."
535*523fa7a6SAndroid Build Coastguard Worker                )
536*523fa7a6SAndroid Build Coastguard Worker            return unwrap_proxy(out)
537*523fa7a6SAndroid Build Coastguard Worker
538*523fa7a6SAndroid Build Coastguard Worker        returns = [unwrap(out) for out in out_args]
539*523fa7a6SAndroid Build Coastguard Worker
540*523fa7a6SAndroid Build Coastguard Worker        return_annotation = None
541*523fa7a6SAndroid Build Coastguard Worker        # some ops like torch.sub doesn't have annotations
542*523fa7a6SAndroid Build Coastguard Worker        if hasattr(root_fn, "__annotations__"):
543*523fa7a6SAndroid Build Coastguard Worker            return_annotation = root_fn.__annotations__.get("return", None)
544*523fa7a6SAndroid Build Coastguard Worker
545*523fa7a6SAndroid Build Coastguard Worker        self.create_proxy(
546*523fa7a6SAndroid Build Coastguard Worker            "output",
547*523fa7a6SAndroid Build Coastguard Worker            "output",
548*523fa7a6SAndroid Build Coastguard Worker            (returns,),
549*523fa7a6SAndroid Build Coastguard Worker            {},
550*523fa7a6SAndroid Build Coastguard Worker            type_expr=return_annotation,
551*523fa7a6SAndroid Build Coastguard Worker        )
552*523fa7a6SAndroid Build Coastguard Worker
553*523fa7a6SAndroid Build Coastguard Worker        self.submodule_paths = None
554*523fa7a6SAndroid Build Coastguard Worker
555*523fa7a6SAndroid Build Coastguard Worker        return tree_out
556*523fa7a6SAndroid Build Coastguard Worker
557*523fa7a6SAndroid Build Coastguard Worker
558*523fa7a6SAndroid Build Coastguard WorkerTRACER: Optional[DispatchTracer] = None
559*523fa7a6SAndroid Build Coastguard WorkerTORCHDYNAMO_ENABLED: bool = False
560*523fa7a6SAndroid Build Coastguard Worker
561*523fa7a6SAndroid Build Coastguard Worker
562*523fa7a6SAndroid Build Coastguard Worker@contextmanager
563*523fa7a6SAndroid Build Coastguard Workerdef using_dynamo(val: bool) -> Generator[None, None, None]:
564*523fa7a6SAndroid Build Coastguard Worker    global TORCHDYNAMO_ENABLED
565*523fa7a6SAndroid Build Coastguard Worker    TORCHDYNAMO_ENABLED, prev = val, TORCHDYNAMO_ENABLED
566*523fa7a6SAndroid Build Coastguard Worker    try:
567*523fa7a6SAndroid Build Coastguard Worker        yield
568*523fa7a6SAndroid Build Coastguard Worker    finally:
569*523fa7a6SAndroid Build Coastguard Worker        TORCHDYNAMO_ENABLED = prev
570*523fa7a6SAndroid Build Coastguard Worker
571*523fa7a6SAndroid Build Coastguard Worker
572*523fa7a6SAndroid Build Coastguard Workerdef flattened_dispatch_trace(
573*523fa7a6SAndroid Build Coastguard Worker    f: Callable[..., Value],
574*523fa7a6SAndroid Build Coastguard Worker    args: Tuple[LeafValue, ...],
575*523fa7a6SAndroid Build Coastguard Worker    guards: Set[Guard],
576*523fa7a6SAndroid Build Coastguard Worker    in_spec: Optional[TreeSpec] = None,
577*523fa7a6SAndroid Build Coastguard Worker    enable_functionalization: bool = True,
578*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.fx.GraphModule, Value]:
579*523fa7a6SAndroid Build Coastguard Worker    if not isinstance(args, tuple):
580*523fa7a6SAndroid Build Coastguard Worker        raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}")
581*523fa7a6SAndroid Build Coastguard Worker
582*523fa7a6SAndroid Build Coastguard Worker    tracer = DispatchTracer()
583*523fa7a6SAndroid Build Coastguard Worker
584*523fa7a6SAndroid Build Coastguard Worker    if enable_functionalization:
585*523fa7a6SAndroid Build Coastguard Worker        f = functionalize(f, remove="mutations_and_views")
586*523fa7a6SAndroid Build Coastguard Worker    tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec)
587*523fa7a6SAndroid Build Coastguard Worker
588*523fa7a6SAndroid Build Coastguard Worker    name = type(f).__name__ if isinstance(f, torch.nn.Module) else f.__name__
589*523fa7a6SAndroid Build Coastguard Worker    gm = torch.fx.GraphModule(tracer.root, tracer.graph, name)
590*523fa7a6SAndroid Build Coastguard Worker
591*523fa7a6SAndroid Build Coastguard Worker    return (gm, tree_out)
592*523fa7a6SAndroid Build Coastguard Worker
593*523fa7a6SAndroid Build Coastguard Worker
594*523fa7a6SAndroid Build Coastguard Worker@dataclass
595*523fa7a6SAndroid Build Coastguard Workerclass ExirDynamoConfig:
596*523fa7a6SAndroid Build Coastguard Worker    """
597*523fa7a6SAndroid Build Coastguard Worker    Manage Exir-specific configurations of Dynamo.
598*523fa7a6SAndroid Build Coastguard Worker    """
599*523fa7a6SAndroid Build Coastguard Worker
600*523fa7a6SAndroid Build Coastguard Worker    allow_rnn: bool = True
601*523fa7a6SAndroid Build Coastguard Worker    verbose: bool = True
602*523fa7a6SAndroid Build Coastguard Worker    assume_static_by_default: bool = False
603*523fa7a6SAndroid Build Coastguard Worker
604*523fa7a6SAndroid Build Coastguard Worker
605*523fa7a6SAndroid Build Coastguard Workerdef flatten_output(gm: torch.fx.GraphModule) -> None:
606*523fa7a6SAndroid Build Coastguard Worker    """
607*523fa7a6SAndroid Build Coastguard Worker    Modifies the output nodes in the submodules to return the result
608*523fa7a6SAndroid Build Coastguard Worker    as a flattened list. This keeps it consistent with the result of
609*523fa7a6SAndroid Build Coastguard Worker    EXIR's tracer
610*523fa7a6SAndroid Build Coastguard Worker    """
611*523fa7a6SAndroid Build Coastguard Worker    for node in reversed(gm.graph.nodes):
612*523fa7a6SAndroid Build Coastguard Worker        if node.op == "output":
613*523fa7a6SAndroid Build Coastguard Worker            assert len(node.args) == 1
614*523fa7a6SAndroid Build Coastguard Worker            outputs = node.args[0]
615*523fa7a6SAndroid Build Coastguard Worker            returns, _ = pytree.tree_flatten(outputs)
616*523fa7a6SAndroid Build Coastguard Worker            node.args = (returns,)
617*523fa7a6SAndroid Build Coastguard Worker            return
618*523fa7a6SAndroid Build Coastguard Worker    raise RuntimeError(f"Could not find an output node in {gm.graph}")
619*523fa7a6SAndroid Build Coastguard Worker
620*523fa7a6SAndroid Build Coastguard Worker
621*523fa7a6SAndroid Build Coastguard Workerdef _default_decomposition_table(
622*523fa7a6SAndroid Build Coastguard Worker    _use_old_decomp_table=False,
623*523fa7a6SAndroid Build Coastguard Worker) -> Dict[torch._ops.OpOverload, Callable[..., Value]]:
624*523fa7a6SAndroid Build Coastguard Worker    if _use_old_decomp_table:
625*523fa7a6SAndroid Build Coastguard Worker        decomp_opset = [
626*523fa7a6SAndroid Build Coastguard Worker            torch.ops.aten.log_sigmoid_forward,
627*523fa7a6SAndroid Build Coastguard Worker            torch.ops.aten.ones,
628*523fa7a6SAndroid Build Coastguard Worker            torch.ops.aten.arange.default,
629*523fa7a6SAndroid Build Coastguard Worker            torch.ops.aten.arange.start,
630*523fa7a6SAndroid Build Coastguard Worker            torch.ops.aten.transpose,
631*523fa7a6SAndroid Build Coastguard Worker        ]
632*523fa7a6SAndroid Build Coastguard Worker        # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
633*523fa7a6SAndroid Build Coastguard Worker        return get_decompositions(decomp_opset)
634*523fa7a6SAndroid Build Coastguard Worker    # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
635*523fa7a6SAndroid Build Coastguard Worker    return default_decompositions()
636*523fa7a6SAndroid Build Coastguard Worker
637*523fa7a6SAndroid Build Coastguard Worker
638*523fa7a6SAndroid Build Coastguard Workerdef dynamo_trace(
639*523fa7a6SAndroid Build Coastguard Worker    f: Callable[..., Value],
640*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
641*523fa7a6SAndroid Build Coastguard Worker    args: Tuple[Any, ...],
642*523fa7a6SAndroid Build Coastguard Worker    aten_graph: bool,
643*523fa7a6SAndroid Build Coastguard Worker    tracing_mode: str = "real",
644*523fa7a6SAndroid Build Coastguard Worker    dynamo_config: Optional[ExirDynamoConfig] = None,
645*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
646*523fa7a6SAndroid Build Coastguard Worker    dynamic_shapes: Optional[List[Any]] = None,
647*523fa7a6SAndroid Build Coastguard Worker    _use_old_decomp_table: bool = False,
648*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.fx.GraphModule, Set[Guard]]:
649*523fa7a6SAndroid Build Coastguard Worker    """
650*523fa7a6SAndroid Build Coastguard Worker    TODO: Once we fully migrate to torchdynamo frontend, we will remove
651*523fa7a6SAndroid Build Coastguard Worker    this config option alltogether.  For now, it helps with quick
652*523fa7a6SAndroid Build Coastguard Worker    experiments with playing around with TorchDynamo
653*523fa7a6SAndroid Build Coastguard Worker    """
654*523fa7a6SAndroid Build Coastguard Worker    if dynamo_config is None:
655*523fa7a6SAndroid Build Coastguard Worker        dynamo_config = ExirDynamoConfig()
656*523fa7a6SAndroid Build Coastguard Worker
657*523fa7a6SAndroid Build Coastguard Worker    with torchdynamo.config.patch(
658*523fa7a6SAndroid Build Coastguard Worker        asdict(dynamo_config)
659*523fa7a6SAndroid Build Coastguard Worker    ), setting_python_recursive_limit(2000):
660*523fa7a6SAndroid Build Coastguard Worker        torchdynamo.reset()
661*523fa7a6SAndroid Build Coastguard Worker        try:
662*523fa7a6SAndroid Build Coastguard Worker            # TODO merge executorch functionalization with official
663*523fa7a6SAndroid Build Coastguard Worker            # functionalization
664*523fa7a6SAndroid Build Coastguard Worker            # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got
665*523fa7a6SAndroid Build Coastguard Worker            #  `ExportResult`.
666*523fa7a6SAndroid Build Coastguard Worker            return torchdynamo.export(
667*523fa7a6SAndroid Build Coastguard Worker                f,
668*523fa7a6SAndroid Build Coastguard Worker                aten_graph=aten_graph,
669*523fa7a6SAndroid Build Coastguard Worker                tracing_mode=tracing_mode,
670*523fa7a6SAndroid Build Coastguard Worker                assume_static_by_default=dynamo_config.assume_static_by_default,
671*523fa7a6SAndroid Build Coastguard Worker                decomposition_table=(
672*523fa7a6SAndroid Build Coastguard Worker                    _default_decomposition_table(_use_old_decomp_table)
673*523fa7a6SAndroid Build Coastguard Worker                    if aten_graph
674*523fa7a6SAndroid Build Coastguard Worker                    else None
675*523fa7a6SAndroid Build Coastguard Worker                ),
676*523fa7a6SAndroid Build Coastguard Worker                dynamic_shapes=dynamic_shapes,
677*523fa7a6SAndroid Build Coastguard Worker            )(
678*523fa7a6SAndroid Build Coastguard Worker                *copy.deepcopy(args),
679*523fa7a6SAndroid Build Coastguard Worker            )
680*523fa7a6SAndroid Build Coastguard Worker        except torchdynamo.exc.Unsupported as exc:
681*523fa7a6SAndroid Build Coastguard Worker            raise ExportError(
682*523fa7a6SAndroid Build Coastguard Worker                ExportErrorType.NOT_SUPPORTED,
683*523fa7a6SAndroid Build Coastguard Worker                "The user code is using a feature we don't support. "
684*523fa7a6SAndroid Build Coastguard Worker                "Please try torchdynamo.explain() to get possible the reasons",
685*523fa7a6SAndroid Build Coastguard Worker            ) from exc
686*523fa7a6SAndroid Build Coastguard Worker        except Exception as exc:
687*523fa7a6SAndroid Build Coastguard Worker            raise InternalError(
688*523fa7a6SAndroid Build Coastguard Worker                "torchdynamo internal error occured. Please see above stacktrace"
689*523fa7a6SAndroid Build Coastguard Worker            ) from exc
690*523fa7a6SAndroid Build Coastguard Worker
691*523fa7a6SAndroid Build Coastguard Worker
692*523fa7a6SAndroid Build Coastguard Workerdef dispatch_trace(
693*523fa7a6SAndroid Build Coastguard Worker    f: Callable[..., Value],
694*523fa7a6SAndroid Build Coastguard Worker    args: Tuple[Value, ...],
695*523fa7a6SAndroid Build Coastguard Worker) -> torch.fx.GraphModule:
696*523fa7a6SAndroid Build Coastguard Worker    """
697*523fa7a6SAndroid Build Coastguard Worker    Executes a given callable `f` with a given tuple of arguments. During
698*523fa7a6SAndroid Build Coastguard Worker    execution, Tensor operations are recorded in a fx.GraphModule, which is then
699*523fa7a6SAndroid Build Coastguard Worker    returned.
700*523fa7a6SAndroid Build Coastguard Worker
701*523fa7a6SAndroid Build Coastguard Worker    Args:
702*523fa7a6SAndroid Build Coastguard Worker        f: A `nn.Module` or a Python function that implements an ML program.
703*523fa7a6SAndroid Build Coastguard Worker        args: A tuple of arguments of any type to be used as inputs for the tracing run.
704*523fa7a6SAndroid Build Coastguard Worker
705*523fa7a6SAndroid Build Coastguard Worker    Returns:
706*523fa7a6SAndroid Build Coastguard Worker        EXIR contained in a fx.GraphModule
707*523fa7a6SAndroid Build Coastguard Worker    """
708*523fa7a6SAndroid Build Coastguard Worker    trace_func = f
709*523fa7a6SAndroid Build Coastguard Worker    guards = set()
710*523fa7a6SAndroid Build Coastguard Worker    if TORCHDYNAMO_ENABLED:
711*523fa7a6SAndroid Build Coastguard Worker        # Copying args is safer in case downstream implementations of trace_func mutate them
712*523fa7a6SAndroid Build Coastguard Worker        trace_func, guards = dynamo_trace(trace_func, args, False)
713*523fa7a6SAndroid Build Coastguard Worker
714*523fa7a6SAndroid Build Coastguard Worker    # Copying args is safer in case downstream implementations of trace_func mutate them
715*523fa7a6SAndroid Build Coastguard Worker    trace_args, in_spec = pytree.tree_flatten(args)
716*523fa7a6SAndroid Build Coastguard Worker
717*523fa7a6SAndroid Build Coastguard Worker    in_args = copy.deepcopy(tuple(trace_args))
718*523fa7a6SAndroid Build Coastguard Worker    gm, tree_out = flattened_dispatch_trace(
719*523fa7a6SAndroid Build Coastguard Worker        trace_func,
720*523fa7a6SAndroid Build Coastguard Worker        in_args,
721*523fa7a6SAndroid Build Coastguard Worker        guards,
722*523fa7a6SAndroid Build Coastguard Worker        in_spec,
723*523fa7a6SAndroid Build Coastguard Worker        enable_functionalization=False,
724*523fa7a6SAndroid Build Coastguard Worker    )
725*523fa7a6SAndroid Build Coastguard Worker
726*523fa7a6SAndroid Build Coastguard Worker    _, out_spec = pytree.tree_flatten(tree_out)
727*523fa7a6SAndroid Build Coastguard Worker
728*523fa7a6SAndroid Build Coastguard Worker    # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
729*523fa7a6SAndroid Build Coastguard Worker    gm.in_spec = in_spec
730*523fa7a6SAndroid Build Coastguard Worker    # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
731*523fa7a6SAndroid Build Coastguard Worker    gm.out_spec = out_spec
732*523fa7a6SAndroid Build Coastguard Worker
733*523fa7a6SAndroid Build Coastguard Worker    # TODO (tmanlaibaatar) This is bit clowny, but our
734*523fa7a6SAndroid Build Coastguard Worker    # dispatch_trace sometimes creates unused node that
735*523fa7a6SAndroid Build Coastguard Worker    # breaks functionalization. it seems too much trouble
736*523fa7a6SAndroid Build Coastguard Worker    # to fix it properly since dispatch_trace will be deprecated soon.
737*523fa7a6SAndroid Build Coastguard Worker    # Basically dispatch_trace struggles on:
738*523fa7a6SAndroid Build Coastguard Worker    # def f(x: torch.Tensor) -> torch.Tensor:
739*523fa7a6SAndroid Build Coastguard Worker    #    return torch.ones(6, dtype=x.dtype)
740*523fa7a6SAndroid Build Coastguard Worker    changed = gm.graph.eliminate_dead_code()
741*523fa7a6SAndroid Build Coastguard Worker    if changed:
742*523fa7a6SAndroid Build Coastguard Worker        gm.recompile()
743*523fa7a6SAndroid Build Coastguard Worker
744*523fa7a6SAndroid Build Coastguard Worker    in_args = copy.deepcopy(tuple(trace_args))
745*523fa7a6SAndroid Build Coastguard Worker    assert callable(gm)
746*523fa7a6SAndroid Build Coastguard Worker
747*523fa7a6SAndroid Build Coastguard Worker    # This wrapper is used for preserving the stacktrace
748*523fa7a6SAndroid Build Coastguard Worker    # during second round of tracing.
749*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
750*523fa7a6SAndroid Build Coastguard Worker    def graph_with_interpreter(*args):
751*523fa7a6SAndroid Build Coastguard Worker        try:
752*523fa7a6SAndroid Build Coastguard Worker            args = fx_pytree.tree_flatten_spec(args, gm.in_spec)  # type: ignore[assignment]
753*523fa7a6SAndroid Build Coastguard Worker        except Exception:
754*523fa7a6SAndroid Build Coastguard Worker            _, received_spec = pytree.tree_flatten(args)
755*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError(
756*523fa7a6SAndroid Build Coastguard Worker                "Trying to flatten user inputs with exported input tree spec: \n"
757*523fa7a6SAndroid Build Coastguard Worker                f"{gm.in_spec}\n"
758*523fa7a6SAndroid Build Coastguard Worker                "but actually got inputs with tree spec of: \n"
759*523fa7a6SAndroid Build Coastguard Worker                f"{received_spec}"
760*523fa7a6SAndroid Build Coastguard Worker            )
761*523fa7a6SAndroid Build Coastguard Worker        with torch.fx.traceback.preserve_node_meta():
762*523fa7a6SAndroid Build Coastguard Worker            res = gm(*args)
763*523fa7a6SAndroid Build Coastguard Worker
764*523fa7a6SAndroid Build Coastguard Worker        if gm.out_spec is not None:
765*523fa7a6SAndroid Build Coastguard Worker            try:
766*523fa7a6SAndroid Build Coastguard Worker                res = pytree.tree_unflatten(res, gm.out_spec)
767*523fa7a6SAndroid Build Coastguard Worker            except Exception:
768*523fa7a6SAndroid Build Coastguard Worker                _, received_spec = pytree.tree_flatten(res)
769*523fa7a6SAndroid Build Coastguard Worker                raise RuntimeError(
770*523fa7a6SAndroid Build Coastguard Worker                    "Trying to flatten user outputs with exported output tree spec: \n"
771*523fa7a6SAndroid Build Coastguard Worker                    f"{gm.out_spec}\n"
772*523fa7a6SAndroid Build Coastguard Worker                    "but actually got outputs with tree spec of: \n"
773*523fa7a6SAndroid Build Coastguard Worker                    f"{received_spec}"
774*523fa7a6SAndroid Build Coastguard Worker                )
775*523fa7a6SAndroid Build Coastguard Worker        return res
776*523fa7a6SAndroid Build Coastguard Worker
777*523fa7a6SAndroid Build Coastguard Worker    gm, tree_out = flattened_dispatch_trace(
778*523fa7a6SAndroid Build Coastguard Worker        graph_with_interpreter,
779*523fa7a6SAndroid Build Coastguard Worker        in_args,
780*523fa7a6SAndroid Build Coastguard Worker        guards,
781*523fa7a6SAndroid Build Coastguard Worker        in_spec,
782*523fa7a6SAndroid Build Coastguard Worker        enable_functionalization=True,
783*523fa7a6SAndroid Build Coastguard Worker    )
784*523fa7a6SAndroid Build Coastguard Worker
785*523fa7a6SAndroid Build Coastguard Worker    gm.in_spec = in_spec
786*523fa7a6SAndroid Build Coastguard Worker    gm.out_spec = out_spec
787*523fa7a6SAndroid Build Coastguard Worker
788*523fa7a6SAndroid Build Coastguard Worker    return gm
789