xref: /aosp_15_r20/external/executorch/exir/capture/_capture.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport copy
8*523fa7a6SAndroid Build Coastguard Workerimport warnings
9*523fa7a6SAndroid Build Coastguard Workerfrom collections import namedtuple
10*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import contextmanager
11*523fa7a6SAndroid Build Coastguard Workerfrom types import MethodType
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Callable, cast, List, Optional, Tuple
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.capture._config import CaptureConfig
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType, InternalError
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.program import ExirExportedProgram
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import (
20*523fa7a6SAndroid Build Coastguard Worker    _default_decomposition_table,
21*523fa7a6SAndroid Build Coastguard Worker    dispatch_trace,
22*523fa7a6SAndroid Build Coastguard Worker    dynamo_trace,
23*523fa7a6SAndroid Build Coastguard Worker    flatten_output,
24*523fa7a6SAndroid Build Coastguard Worker    Value,
25*523fa7a6SAndroid Build Coastguard Worker)
26*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.verifier import EXIRATenDialectVerifierBase
27*523fa7a6SAndroid Build Coastguard Workerfrom torch import _guards
28*523fa7a6SAndroid Build Coastguard Workerfrom torch._dispatch.python import enable_python_dispatcher
29*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
30*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
31*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export
32*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import (
33*523fa7a6SAndroid Build Coastguard Worker    ExportedProgram,
34*523fa7a6SAndroid Build Coastguard Worker    ExportGraphSignature,
35*523fa7a6SAndroid Build Coastguard Worker    InputKind,
36*523fa7a6SAndroid Build Coastguard Worker    InputSpec,
37*523fa7a6SAndroid Build Coastguard Worker    ModuleCallEntry,
38*523fa7a6SAndroid Build Coastguard Worker    ModuleCallSignature,
39*523fa7a6SAndroid Build Coastguard Worker    OutputKind,
40*523fa7a6SAndroid Build Coastguard Worker    OutputSpec,
41*523fa7a6SAndroid Build Coastguard Worker    TensorArgument,
42*523fa7a6SAndroid Build Coastguard Worker)
43*523fa7a6SAndroid Build Coastguard Workerfrom torch.func import functionalize
44*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx._compatibility import compatibility
45*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx
46*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import ShapeEnv
47*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
48*523fa7a6SAndroid Build Coastguard Worker
49*523fa7a6SAndroid Build Coastguard Worker
# Alias for an arbitrary value; it places no constraint on the type.
Val = Any


# Record with the fields ``method_name``, ``callable``, ``args`` and
# ``dynamic_shapes`` (in that positional order).
CompileSpec = namedtuple(
    "CompileSpec",
    "method_name callable args dynamic_shapes",
)


# Pair of pytree specs: one for a program's inputs, one for its outputs.
CallSpec = namedtuple("CallSpec", "in_spec out_spec")
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
    """
    This is a legacy API that should be avoided. Prefer to use capture() instead.

    Traces ``f`` with ``dispatch_trace`` and wraps the resulting graph module
    in an ``ExirExportedProgram``.

    Args:
        f: The callable handed to ``dispatch_trace``.
        args: Example inputs for the trace. They are also pytree-flattened and
            used to fill in placeholder ``meta["val"]`` entries that are
            missing after tracing.

    Returns:
        An ``ExirExportedProgram`` (constructed with ``False`` as its second
        argument) whose graph signature lists every placeholder as a user
        input and every entry of the output node as a user output.
    """
    warnings.warn(
        "This function is now deprecated, please use `torch.export and exir.to_edge` instead. "
        "See https://github.com/pytorch/functorch for more details.",
        DeprecationWarning,
        # Attribute the warning to the caller rather than to this module so it
        # points at the code that needs migrating (and is not hidden by the
        # default DeprecationWarning filter).
        stacklevel=2,
    )

    graph_module = dispatch_trace(f, args)
    flat_args = tuple(pytree.tree_flatten(args)[0])
    in_spec, out_spec = graph_module.in_spec, graph_module.out_spec

    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
    graph_module._apply(torch.Tensor.contiguous)

    user_inputs = [
        node.name for node in graph_module.graph.nodes if node.op == "placeholder"
    ]
    output_node = list(graph_module.graph.nodes)[-1]
    assert output_node.op == "output"
    user_outputs = [arg.name for arg in output_node.args[0]]

    # Backfill meta["val"] on call_function nodes that lack it by re-running
    # the node's target on the ``val`` metadata of its inputs.
    for n in graph_module.graph.nodes:
        if n.op == "call_function" and "val" not in n.meta:
            try:
                # Distinct names: do not rebind the ``args`` parameter here.
                node_args, node_kwargs = pytree.tree_map_only(
                    torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                )
                n.meta["val"] = n.target(*node_args, **node_kwargs)
            except Exception:
                # Deliberately best effort: a node whose value cannot be
                # recomputed is recorded with ``None``.
                n.meta["val"] = None

    ep = HackedUpExportedProgramDONOTUSE(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=InputKind.USER_INPUT, arg=TensorArgument(name=i), target=None
                )
                for i in user_inputs
            ],
            output_specs=[
                OutputSpec(
                    kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=o), target=None
                )
                for o in user_outputs
            ],
        ),
        call_spec=CallSpec(in_spec, out_spec),
        state_dict={},
        range_constraints={},
        module_call_graph=[
            ModuleCallEntry(
                fqn="",
                signature=ModuleCallSignature(
                    inputs=[],
                    outputs=[],
                    # pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
                    #  `Union[Tensor, Module]`.
                    in_spec=in_spec,
                    # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
                    #  `Union[Tensor, Module]`.
                    out_spec=out_spec,
                ),
            )
        ],
        example_inputs=None,
        verifier=EXIRATenDialectVerifierBase,
    )
    return ExirExportedProgram(ep, False)
136*523fa7a6SAndroid Build Coastguard Worker
137*523fa7a6SAndroid Build Coastguard Worker
@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper method to make it easier to cleanly torch.export() a method on a
    module that is not `forward`.

    While the context is active, ``obj.forward`` is ``new_method`` bound to
    ``obj``. On exit (including when the body raises) the instance is put back
    into the state it had before: an instance-level ``forward`` is restored,
    and if there was none the temporary override is removed so that the
    class-level ``forward`` is visible again.

    Args:
        obj: The object whose ``forward`` attribute is temporarily replaced.
        new_method: A descriptor (e.g. a function or bound method) that is
            bound to ``obj`` via ``__get__`` and installed as ``forward``.

    TODO(suo): upstream this to torch.export.wrapper.
    """
    # Record only what lives on the instance itself. Saving ``obj.forward``
    # and assigning it back would leave a bound method in the instance
    # ``__dict__`` that shadows the class attribute and references ``obj``.
    instance_attrs = getattr(obj, "__dict__", {})
    had_own_forward = "forward" in instance_attrs
    own_forward = instance_attrs.get("forward")

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)

    try:
        yield
    finally:
        if had_own_forward:
            # Restore the instance-level attribute exactly as it was.
            obj.forward = own_forward
        else:
            # ``forward`` came from the class: just drop the override.
            getattr(obj, "__dict__", {}).pop("forward", None)
156*523fa7a6SAndroid Build Coastguard Worker
157*523fa7a6SAndroid Build Coastguard Worker
class WrapperModule(torch.nn.Module):
    """Module shell around a plain callable.

    The callable is stored as the instance's ``forward`` attribute, so
    invoking ``forward`` on the module invokes the callable.
    """

    def __init__(self, f: Callable[..., Any]) -> None:
        super().__init__()
        # Instance-level assignment; no subclass override of ``forward``.
        self.forward = f
162*523fa7a6SAndroid Build Coastguard Worker
163*523fa7a6SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def capture(  # noqa: C901
    f: Callable[..., Any],
    args: Tuple[Value, ...],
    config: Optional[CaptureConfig] = None,
    dynamic_shapes: Optional[List[Any]] = None,
) -> ExirExportedProgram:
    """Deprecated: trace ``f`` on example ``args`` into an ``ExirExportedProgram``.

    Depending on ``config`` one of these tracing paths is taken:

    * ``enable_aot``: ``torch.export.export`` followed by
      ``run_decompositions`` and ``ReplaceViewOpsWithViewCopyOpsPass``. Unless
      ``config._unlift`` is set, the result is returned right after that.
    * otherwise: ``dynamo_trace`` in ``"symbolic"`` mode when
      ``enable_dynamic_shape`` is set, else in ``"fake"`` mode. If
      ``enable_functionalization`` is set, the traced graph is then re-traced
      through ``functionalize`` + ``make_fx``.

    Args:
        f: The callable to capture. Under ``enable_aot`` a bound method of a
            ``torch.nn.Module`` is exported by temporarily patching it in as
            the module's ``forward``.
        args: Example positional inputs; must be a tuple.
        config: Capture options. A default ``CaptureConfig()`` is used when
            this is falsy.
        dynamic_shapes: Forwarded to ``torch.export.export`` and to the
            symbolic ``dynamo_trace`` call; not used for ``"fake"`` tracing.

    Returns:
        An ``ExirExportedProgram`` constructed with ``False`` as its second
        argument.

    Raises:
        ExportError: If ``args`` is not a tuple, if ``_unlift`` is set without
            ``enable_aot``, or if ``enable_aot`` is combined with
            ``enable_dynamic_shape`` or used without
            ``enable_functionalization``.
        InternalError: If ``config.pt2_mode`` is not set.
    """
    warnings.warn(
        "This function is now deprecated, please use `torch.export and exir.to_edge` instead. ",
        DeprecationWarning,
        # Attribute the warning to the caller rather than to this module so it
        # points at the code that needs migrating (and is not hidden by the
        # default DeprecationWarning filter).
        stacklevel=2,
    )
    if not isinstance(args, tuple):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"Expect `args` to be a tuple, got type: {type(args)}.",
        )

    config = config or CaptureConfig()
    out_spec = None
    # TODO (zhxchen17) Always functionalize in a second pass no matter which path is taken.
    # Flatten now: ``args`` may be rebound to a deep copy further down.
    flat_args = tuple(pytree.tree_flatten(args)[0])
    if not config.enable_aot and config._unlift:
        raise ExportError(
            ExportErrorType.NOT_SUPPORTED,
            "_unlift config doesn't do anything without enable_aot enabled. Please do not set it",
        )
    if not config.pt2_mode:
        raise InternalError("pt2=False path is officially deprecated")

    if config.enable_aot:
        if config.enable_dynamic_shape:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "Under enable_aot, enable_dynamic_shapes flag doesn't do anything. Please do not set it",
            )
        if not config.enable_functionalization:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "Functionalization is required for enable_aot.",
            )

        # If trying to capture a method and the bound class instance is a
        # Module, then export the module while patching in that method.
        if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module):
            with patch_forward(f.__self__, f):
                ep = export(
                    cast(torch.nn.Module, f.__self__),
                    args,
                    dynamic_shapes=dynamic_shapes,
                )
        else:
            mod = f if isinstance(f, torch.nn.Module) else WrapperModule(f)
            ep = export(mod, args, dynamic_shapes=dynamic_shapes)

        ep = ep.run_decompositions(_default_decomposition_table())
        ep = _transform(ep, ReplaceViewOpsWithViewCopyOpsPass())
        if not config._unlift:
            return ExirExportedProgram(ep, False)
        graph_module = cast(torch.fx.GraphModule, ep.module())

    else:
        # Symbolic tracing honours ``dynamic_shapes``; fake tracing ignores it.
        use_symbolic = config.enable_dynamic_shape
        graph_module, _ = dynamo_trace(
            f,
            args,
            aten_graph=True,
            tracing_mode="symbolic" if use_symbolic else "fake",
            dynamo_config=config._dynamo_config,
            dynamic_shapes=dynamic_shapes if use_symbolic else None,
            _use_old_decomp_table=config._use_old_decomp_table,
        )

    if out_spec is None:
        if isinstance(graph_module.graph._codegen, torch.fx.graph._PyTreeCodeGen):
            out_spec = graph_module.graph._codegen.pytree_info.out_spec
        elif hasattr(graph_module, "_out_spec"):
            out_spec = graph_module._out_spec
        else:
            # Last resort: run ``f`` for real and read the spec off its output.
            out_spec = pytree.tree_flatten(f(*args))[1]

    # NOTE (tmanlaibaatar)
    # torchdynamo.export adds extra kwarg into the graph module
    # which is then lost while we are calling make_fx. This is because
    # make_fx doesn't handle kwargs. Originally we used to use torchdynamo
    # input spec, but due to some limitations in pytree implementation, it doesn't
    # recognize the make_fx graph with torchdynamo input spec. We workaround it
    # by getting the input spec directly from user argument.
    in_spec = pytree.tree_flatten((args, {}))[1]

    if config.enable_functionalization and not config.enable_aot:
        args = copy.deepcopy(args)

        def graph_with_interpreter(*interp_args):
            # Reads ``graph_module`` at call time, i.e. the pre-functionalized
            # module, because it only runs while make_fx is tracing below.
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(graph_module).run(*interp_args)

        functionalized_callable = functionalize(
            graph_with_interpreter,
            remove="mutations_and_views",
        )
        assert isinstance(functionalized_callable, Callable)

        if config.enable_dynamic_shape:
            fake_tensor_mode = FakeTensorMode(
                allow_fallback_kernels=False,
                allow_non_fake_inputs=True,
                shape_env=ShapeEnv(),
            )

            # Collect the fake tensors dynamo recorded on the placeholders.
            inps: List[torch.Tensor] = []
            for node in graph_module.graph.nodes:
                if node.op == "placeholder" and "val" in node.meta:
                    example_fake_tensor = node.meta["val"]
                    assert isinstance(example_fake_tensor, FakeTensor)
                    inps.append(example_fake_tensor)

            # Prefer the fake mode those tensors already belong to, if any.
            if detected_fake_mode := _guards.detect_fake_mode(inps):
                fake_tensor_mode = detected_fake_mode

            count = 0

            def convert_to_fake(x):
                # Hands out the collected fake tensors in order, one per
                # tensor leaf encountered in ``args``.
                nonlocal count
                val = inps[count]
                count += 1
                return val

            fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)

            with enable_python_dispatcher(), fake_tensor_mode:
                graph_module = make_fx(
                    functionalized_callable,
                    tracing_mode="real",
                    _allow_non_fake_inputs=True,
                )(*fake_args)
        else:
            # To avoid breaking folks, use the deprecated "real" tracing
            # mode if we're not using pt2.
            tracing_mode = "fake" if config.pt2_mode else "real"
            graph_module = make_fx(
                functionalized_callable,
                tracing_mode=tracing_mode,
                _allow_non_fake_inputs=True,
            )(*args)

    flatten_output(graph_module)

    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
    graph_module._apply(torch.Tensor.contiguous)

    user_inputs = [
        InputSpec(
            kind=InputKind.USER_INPUT, arg=TensorArgument(name=node.name), target=None
        )
        for node in graph_module.graph.nodes
        if node.op == "placeholder"
    ]
    output_node = list(graph_module.graph.nodes)[-1]
    assert output_node.op == "output"
    user_outputs = [
        OutputSpec(
            kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=arg.name), target=None
        )
        for arg in output_node.args[0]
    ]

    graph_module.graph.eliminate_dead_code()
    ep = ExportedProgram(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=ExportGraphSignature(user_inputs, user_outputs),
        state_dict={},
        range_constraints={},
        module_call_graph=[
            ModuleCallEntry(
                fqn="",
                signature=ModuleCallSignature(
                    inputs=[],
                    outputs=[],
                    in_spec=in_spec,
                    # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
                    #  `Union[None, TreeSpec, Tensor, Module]`.
                    out_spec=out_spec,
                ),
            )
        ],
        example_inputs=None,
        verifiers=[EXIRATenDialectVerifierBase],
    )
    return ExirExportedProgram(ep, False)
367*523fa7a6SAndroid Build Coastguard Worker
368*523fa7a6SAndroid Build Coastguard Worker
369*523fa7a6SAndroid Build Coastguard Worker# This is to bootstrap the missing meta["val"] when 1. ph consists of scalar
370*523fa7a6SAndroid Build Coastguard Worker# 2. meta["val"] is not properly set in dispatch_trace.
371*523fa7a6SAndroid Build Coastguard Workerdef _instantiate_missing_placeholder_val_with_real_inputs(gm, args):
372*523fa7a6SAndroid Build Coastguard Worker    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
373*523fa7a6SAndroid Build Coastguard Worker    if len(phs) != len(args):
374*523fa7a6SAndroid Build Coastguard Worker        raise ExportError(
375*523fa7a6SAndroid Build Coastguard Worker            ExportErrorType.NOT_SUPPORTED,
376*523fa7a6SAndroid Build Coastguard Worker            "Expect number of placeholders to be the same as user inputs.",
377*523fa7a6SAndroid Build Coastguard Worker        )
378*523fa7a6SAndroid Build Coastguard Worker    for node, arg in zip(phs, args):
379*523fa7a6SAndroid Build Coastguard Worker        if "val" not in node.meta or node.meta["val"] is None:
380*523fa7a6SAndroid Build Coastguard Worker            node.meta["val"] = arg
381