# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import operator
import traceback
from contextlib import nullcontext
from typing import (
    Any,
    Callable,
    Dict,
    List,
    MutableMapping,
    Optional,
    Protocol,
    runtime_checkable,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import torch
from executorch.exir import memory

from executorch.exir.delegate import executorch_call_delegate, is_lowered_module

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree
from torch.utils._pytree import PyTree

Fn = Callable[..., Any]  # pyre-ignore
Argument = Any  # pyre-ignore
Value = Any  # pyre-ignore
NodeMetadataValue = Any  # pyre-ignore
K = TypeVar("K")
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


_TORCH_SYM_OPS: Set[Any] = {  # pyre-ignore
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}


PROTECTED_KEYS: Set[str] = {
    "val",
    "stack_trace",
    "nn_module_stack",
    "debug_handle",
    "tensor_meta",
}


def _unstack_pytree(xs) -> List[PyTree]:  # pyre-ignore
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(x, torch.Tensor) for x in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(x.shape[0] == flat_xs[0].shape[0] for x in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[x.shape for x in flat_xs]}"
        )

    ctx = (
        FunctionalTensorMode
        if any(isinstance(x, FunctionalTensor) for x in flat_xs)
        else nullcontext
    )
    with ctx():
        a = zip(*flat_xs)

    pytrees = []
    for tup in a:
        pytrees.append(pytree.tree_unflatten(tup, inspec))
    return pytrees
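

# A small illustration (sketch only; the dict below is a hypothetical input,
# not something this module constructs) of what _unstack_pytree computes:
#
#     xs = {"a": torch.zeros(3, 2), "b": torch.ones(3, 4)}
#     _unstack_pytree(xs)
#     # -> 3 pytrees, one per slice of the leading dimension:
#     #    [{"a": <tensor of shape (2,)>, "b": <tensor of shape (4,)>}, ...]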


class NodeMetadata:
    def __init__(self, data: Dict[str, Any]) -> None:
        self.data: Dict[str, Any] = data.copy()

    def __getitem__(self, key: str) -> NodeMetadataValue:
        return self.data[key]

    def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue:
        if key in PROTECTED_KEYS:
            raise RuntimeError(f"Could not override node key: {key}")
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self.data

    def copy(self) -> "NodeMetadata":
        return NodeMetadata(self.data.copy())


class ProxyValue:
    # pyre-ignore
    def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]):
        # pyre-ignore
        self.data = data
        self.proxy_or_node = proxy

    @property
    def node(self) -> torch.fx.Node:
        if isinstance(self.proxy_or_node, torch.fx.Node):
            return self.proxy_or_node
        assert isinstance(self.proxy_or_node, torch.fx.Proxy)
        return self.proxy_or_node.node

    @property
    def proxy(self) -> torch.fx.Proxy:
        if not isinstance(self.proxy_or_node, torch.fx.Proxy):
            raise RuntimeError(
                f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
            )
        return self.proxy_or_node

    def to_tensor(self) -> torch.Tensor:
        assert isinstance(self.data, torch.Tensor)
        return self.data

    def is_tensor(self) -> bool:
        return isinstance(self.data, torch.Tensor)

    # pyre-ignore
    def __iter__(self):
        yield from self.data

    def __bool__(self) -> bool:
        return bool(self.data)


class ExportPassBaseError(RuntimeError):
    pass


class _ExportPassBase(PassBase):
    """
    Interpreter-based pass class to help users maintain the IR spec while writing
    transformations.
    """
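
    # A typical subclass overrides one of the callbacks below (placeholder,
    # call_operator, call_getitem, output, ...) and delegates everything else
    # to super(). A minimal sketch, assuming `exir_ops` is the edge-dialect op
    # namespace (e.g. `from executorch.exir.dialects._ops import ops as exir_ops`);
    # the add-to-sub rewrite is purely illustrative:
    #
    #     class ReplaceAddWithSub(ExportPass):
    #         def call_operator(self, op, args, kwargs, meta):
    #             if op == exir_ops.edge.aten.add.Tensor:
    #                 return super().call_operator(
    #                     exir_ops.edge.aten.sub.Tensor, args, kwargs, meta
    #                 )
    #             return super().call_operator(op, args, kwargs, meta)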

    @staticmethod
    def _create_dummy_node_metadata() -> NodeMetadata:
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})

    class ExportTracer(PythonKeyTracer):
        def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:  # pyre-fixme[14,15]
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(  # noqa: C901
            self,
            node: torch.fx.Node,
            value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[
                FakeTensor,
                torch.SymInt,
                torch.SymFloat,
                torch.SymBool,
                int,
                float,
                bool,
                str,
                None,
            ]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(
                    x,
                    (
                        torch.SymInt,
                        torch.SymFloat,
                        torch.SymBool,
                        int,
                        float,
                        bool,
                        str,
                    ),
                ):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)

    class ExportInterpreter(fx.Interpreter):
        def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None:
            super().__init__(gm)
            self.callback = callback
            self.node: torch.fx.Node = next(iter(gm.graph.nodes))

        def placeholder(  # pyre-fixme[14]
            self,
            target: str,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            arg = super().placeholder(target, args, kwargs)
            return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))

        def output(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            return self.callback.output(args[0], NodeMetadata(self.node.meta)).data

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)

            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif getattr(target, "__module__", None) in {
                "_operator",
                "builtins",
                "math",
            }:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif target in _TORCH_SYM_OPS:
                assert callable(target)
                return self.callback.call_sym(target, args, meta)
            elif isinstance(
                target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
            ):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
            elif target == torch.ops.higher_order.map_impl:
                f, mapped_args, operands = args  # type: ignore[assignment]
                return self.callback.call_map(f, mapped_args, operands, meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
            elif isinstance(target, torch._ops.HigherOrderOperator):
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )
            else:
                raise ExportPassBaseError(f"Unsupported target type: {target}")

        def get_attr(  # pyre-fixme[14]
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)

        def call_module(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> None:
            raise ExportPassBaseError("call_module is not supported.")

        def call_method(  # pyre-fixme[14]
            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")

        def run_node(self, n: torch.fx.Node) -> Argument:
            self.node = n
            self.callback.node_debug_str = n.format_node()
            return super().run_node(n)

    def __init__(self) -> None:
        self.interpreter = torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        self.tracer = self.ExportTracer(self, CodeGen())  # pyre-ignore
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self._initialized = True
        self.node_debug_str: Optional[str] = None

    def _fx(
        self,
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(
            kind, target, args_proxy, kwargs_proxy, name=name
        )
        res_proxy.node.meta.update(meta.data)
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)

    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
        # TODO(angelayi): Update this with what we decide to do for metadata in
        # the exported graph module
        if (args := graph_module.meta.get("args", None)) is not None:
            return list(args)

        def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
            if "val" in node.meta:
                fake = node.meta["val"]
                if hasattr(fake, "constant") and fake.constant is not None:
                    return fake.constant
                return fake
            elif tensor_meta := node.meta.get("tensor_meta"):
                assert self.fake_tensor_mode is not None
                return FakeTensor(
                    self.fake_tensor_mode,
                    torch.empty(
                        tensor_meta.shape,
                        dtype=tensor_meta.dtype,
                        device="meta",
                        requires_grad=tensor_meta.requires_grad,
                        memory_format=tensor_meta.memory_format,
                    ),
                    torch.device("cpu"),
                )
            elif len(node.users) == 0:
                return None
            raise ExportPassBaseError(
                f"Cannot construct an input for graph module: {graph_module}.",
            )

        return [
            extract_input(node)
            for node in graph_module.graph.nodes
            if node.op == "placeholder"
        ]

    def on_attr(self, attr: ProxyValue) -> None:
        pass

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
        arg_proxy.node.meta = meta.data
        arg_proxy.node.meta["val"] = arg
        return ProxyValue(arg, arg_proxy)

    def call_operator(
        self,
        op,  # pyre-ignore
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", op, args, kwargs, meta)

    def call_sym(
        self,
        target: Fn,
        args: Tuple[Argument, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        return self._fx("call_function", target, args, {}, meta)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        true_branch = self.call_submodule(true_fn, tuple(inputs))
        false_branch = self.call_submodule(false_fn, tuple(inputs))
        assert true_branch is not None
        assert false_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.cond,
            (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
            {},
            meta,
        )

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
        f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
        assert f_branch is not None
        return self._fx(
            "call_function",
            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, mapped_args, operands),
            {},
            meta,
        )

    def call_getitem(
        self, value: ProxyValue, key: int, meta: NodeMetadata
    ) -> ProxyValue:
        return self._fx("call_function", operator.getitem, (value, key), {}, meta)

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        return self._fx("output", "output", (results,), {}, meta)

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        prev_tracer, self.tracer = self.tracer, self.ExportTracer(
            self, graph_module.graph._codegen
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
        with fx_traceback.preserve_node_meta():
            interpreter.run(*inputs_data)

        new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)

        self.tracer = prev_tracer
        self.interpreter = prev_interpreter
        return PassResult(
            new_graph_module,
            True,
        )

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        if not getattr(self, "_initialized", False):
            raise ExportPassBaseError(
                "ExportPass is not initialized with __init__().",
            )

        inputs = self.inputs(graph_module)

        fake_tensor_mode = None
        for i in inputs:
            if isinstance(i, FakeTensor):
                assert (
                    fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
                ), "Multiple fake tensor mode detected."
                fake_tensor_mode = i.fake_mode
        if fake_tensor_mode is None:
            self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
            fake_tensor_mode = nullcontext()  # type: ignore[assignment]
            dispatcher_mode = nullcontext()  # type: ignore[assignment]
        else:
            fake_tensor_mode.allow_non_fake_inputs = True
            self.tracer.fake_tensor_mode = fake_tensor_mode
            dispatcher_mode = enable_python_dispatcher()  # type: ignore[assignment]
        self.fake_tensor_mode = self.tracer.fake_tensor_mode

        with fake_tensor_mode, dispatcher_mode:  # type: ignore[assignment, union-attr]
            result = self.call_submodule(graph_module, tuple(inputs))

        return result


class ExportPass(_ExportPassBase):
    class ExportTracer(_ExportPassBase.ExportTracer):
        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                if a not in self.submodules:
                    prefix = "lowered_module" if is_lowered_module(a) else "submodule"
                    name_submodule = f"{prefix}_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            return super().create_arg(a)

    class ExportInterpreter(_ExportPassBase.ExportInterpreter):
        """
        Interpreter that dispatches each node to the corresponding ExportPass
        callback (e.g. call_operator, call_delegate, call_getitem).
        """

        def __init__(self, callback: "ExportPass", gm: fx.GraphModule) -> None:
            super().__init__(callback, gm)

        def call_function(
            self,
            target: torch.fx.node.Target,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
            meta = NodeMetadata(self.node.meta)
            if target == operator.getitem:
                value, key = args
                return self.callback.call_getitem(value, key, meta)
            elif isinstance(target, EdgeOpOverload):
                return self.callback.call_operator(
                    target,
                    args,
                    kwargs,
                    meta,
                )

            # TODO according to zhengxu ExportPassBase should not be aware of
            # memory.alloc. Check this comment:
            # https://www.internalfb.com/diff/D42758019?dst_version_fbid=5906016402813292&transaction_fbid=1104713900200176
            elif target == memory.alloc:
                return self.callback._fx(
                    "call_function",
                    target,
                    args,
                    kwargs,
                    meta,
                )

            elif target == executorch_call_delegate:
                lowered_module = args[0]
                args = args[1:]
                return self.callback.call_delegate(  # pyre-ignore
                    lowered_module,
                    args,
                    kwargs,
                    NodeMetadata(self.node.meta),
                )

            return super().call_function(target, args, kwargs)

    def call_delegate(
        self,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[ProxyValue, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        args = (lowered_module,) + args
        return self._fx(
            "call_function",
            executorch_call_delegate,
            args,
            kwargs,
            meta,
        )

    def call_submodule(
        self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
    ) -> PassResult:
        res = super().call_submodule(graph_module, inputs)

        def preserve_original_ph_meta_val(
            gm: torch.fx.GraphModule, new_gm: torch.fx.GraphModule
        ) -> None:
            def get_phs(gm: torch.fx.GraphModule) -> List[torch.fx.Node]:
                return [node for node in gm.graph.nodes if node.op == "placeholder"]

            def migrate_meta_val(
                orig_phs: List[torch.fx.Node], new_phs: List[torch.fx.Node]
            ) -> None:
                if len(orig_phs) != len(new_phs):
                    raise ExportError(
                        ExportErrorType.NOT_SUPPORTED,
                        "ExportPassBase doesn't support changing the placeholders",
                    )
                for ph, new_ph in zip(orig_phs, new_phs):
                    if isinstance(new_ph.meta["val"], torch.Tensor):
                        if (
                            not isinstance(ph.meta["val"], torch.Tensor)
                            or new_ph.meta["val"].size() != ph.meta["val"].size()
                        ):
                            raise ExportError(
                                ExportErrorType.NOT_SUPPORTED,
                                "ExportPassBase doesn't support changing the placeholders",
                            )
                    new_ph.meta["val"] = ph.meta["val"]

            migrate_meta_val(get_phs(gm), get_phs(new_gm))

        # After one pass, new_graph_module's placeholders will always hold fake tensors in
        # meta['val'], but sometimes we want to preserve the original meta['val'] of placeholders.
        #
        # For example, custom flows and certain passes assume no fake_tensor_mode is activated,
        # and they don't quite work with fake_tensor_mode, but we don't bother to fix them.
        # So we just reset the meta of placeholders to their original values. This is safe because:
        # 1. For models captured with pt2_mode, the meta['val'] of placeholders are fake tensors
        # already, so preserving them in the new graph module won't hurt.
        # 2. For models captured with dispatch_trace, the meta['val'] field
        # Note that this is only safe when passes don't modify the inputs.
        preserve_original_ph_meta_val(graph_module, res.graph_module)

        return res


@runtime_checkable
class ArgSchema(Protocol):
    name: str
    kwarg_only: bool
    type: Any  # pyre-ignore


def map_args(
    op: torch._ops.OpOverload,
    fn: Fn,
    args: Argument,
    kwargs: Dict[str, Argument],
) -> Tuple[Argument, Dict[str, Argument]]:
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    def update(key: K, args: MutableMapping[K, PyTree], schema: ArgSchema) -> None:
        args[key] = fn(args[key], schema)

    for i, schema in enumerate(op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)  # pyre-ignore

    return tuple(args), kwargs
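

# End-to-end usage sketch (comments only; nothing here runs on import).
# `graph_module` is assumed to be a torch.fx.GraphModule produced by an export
# flow, and LogEveryOp is a hypothetical pass, not part of this module:
#
#     class LogEveryOp(ExportPass):
#         def call_operator(self, op, args, kwargs, meta):
#             print(f"visiting {op}")
#             return super().call_operator(op, args, kwargs, meta)
#
#     # PassBase.__call__ runs the pass and returns a PassResult.
#     pass_result = LogEveryOp()(graph_module)
#     new_graph_module = pass_result.graph_module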