# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import warnings
from collections import namedtuple
from contextlib import contextmanager
from types import MethodType
from typing import Any, Callable, cast, List, Optional, Tuple

import torch
from executorch.exir.capture._config import CaptureConfig
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.program import ExirExportedProgram
from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE
from executorch.exir.tracer import (
    _default_decomposition_table,
    dispatch_trace,
    dynamo_trace,
    flatten_output,
    Value,
)
from executorch.exir.verification.verifier import EXIRATenDialectVerifierBase
from torch import _guards
from torch._dispatch.python import enable_python_dispatcher
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import export
from torch.export.exported_program import (
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    ModuleCallEntry,
    ModuleCallSignature,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.func import functionalize
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils import _pytree as pytree


Val = Any


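# CompileSpec bundles a method name, the callable implementing it, example
# args, and an optional dynamic-shape spec. CallSpec records the pytree
# in_spec/out_spec used to flatten inputs and unflatten outputs.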
CompileSpec = namedtuple(
    "CompileSpec", ["method_name", "callable", "args", "dynamic_shapes"]
)


CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])


@compatibility(is_backward_compatible=False)
def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
    """
    This is a legacy API that should be avoided. Prefer to use capture() instead.
    """
    warnings.warn(
        "This function is now deprecated, please use `torch.export` and `exir.to_edge` instead. "
        "See https://github.com/pytorch/functorch for more details.",
        DeprecationWarning,
        stacklevel=1,
    )

    graph_module = dispatch_trace(f, args)
    flat_args = tuple(pytree.tree_flatten(args)[0])
    in_spec, out_spec = graph_module.in_spec, graph_module.out_spec

    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
    graph_module._apply(torch.Tensor.contiguous)

    user_inputs = [
        node.name for node in graph_module.graph.nodes if node.op == "placeholder"
    ]
    output_node = list(graph_module.graph.nodes)[-1]
    assert output_node.op == "output"
    user_outputs = [arg.name for arg in output_node.args[0]]

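    # Backfill missing meta["val"] on call_function nodes by re-evaluating the
    # op on the meta["val"] of its inputs; fall back to None if that fails.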
    for n in graph_module.graph.nodes:
        if n.op == "call_function" and "val" not in n.meta:
            try:
                args, kwargs = pytree.tree_map_only(
                    torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                )
                n.meta["val"] = n.target(*args, **kwargs)
            except Exception:
                n.meta["val"] = None

    ep = HackedUpExportedProgramDONOTUSE(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=InputKind.USER_INPUT, arg=TensorArgument(name=i), target=None
                )
                for i in user_inputs
            ],
            output_specs=[
                OutputSpec(
                    kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=o), target=None
                )
                for o in user_outputs
            ],
        ),
        call_spec=CallSpec(in_spec, out_spec),
        state_dict={},
        range_constraints={},
        module_call_graph=[
            ModuleCallEntry(
                fqn="",
                signature=ModuleCallSignature(
                    inputs=[],
                    outputs=[],
                    # pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
                    #  `Union[Tensor, Module]`.
                    in_spec=in_spec,
                    # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
                    #  `Union[Tensor, Module]`.
                    out_spec=out_spec,
                ),
            )
        ],
        example_inputs=None,
        verifier=EXIRATenDialectVerifierBase,
    )
    return ExirExportedProgram(ep, False)


@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper context manager that makes it easier to cleanly torch.export() a
    method on a module other than `forward`.

    TODO(suo): upstream this to torch.export.wrapper.
    """
    # Save the original method
    original_method = obj.forward

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)

    try:
        yield
    finally:
        # Restore the original method
        obj.forward = original_method

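# Illustrative usage (the method name is hypothetical): temporarily expose a
# non-`forward` method as `forward` so torch.export can trace it, e.g.
#
#   with patch_forward(module, module.encode):
#       ep = export(module, example_args)
#
# capture() below does exactly this for bound Module methods.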

class WrapperModule(torch.nn.Module):
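    """Wraps a plain callable as an nn.Module so that torch.export, which
    expects a Module, can trace it; used by capture() when `f` is neither a
    Module nor a bound Module method."""
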
    def __init__(self, f):
        super().__init__()
        self.forward = f


@compatibility(is_backward_compatible=False)
def capture(  # noqa: C901
    f: Callable[..., Any],
    args: Tuple[Value, ...],
    config: Optional[CaptureConfig] = None,
    dynamic_shapes: Optional[List[Any]] = None,
) -> ExirExportedProgram:
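    """Trace `f` with the example `args` according to `config` and return an
    ExirExportedProgram.

    Deprecated: prefer `torch.export` followed by `exir.to_edge`.

    Args:
        f: a callable, an nn.Module, or a bound method of an nn.Module.
        args: a tuple of example inputs used for tracing.
        config: capture options; defaults to CaptureConfig().
        dynamic_shapes: optional dynamic-shape specification forwarded to the
            underlying exporter/tracer.
    """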
    warnings.warn(
        "This function is now deprecated, please use `torch.export` and `exir.to_edge` instead.",
        DeprecationWarning,
        stacklevel=1,
    )
    if not isinstance(args, tuple):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"Expected `args` to be a tuple, got type: {type(args)}.",
        )

    config = config or CaptureConfig()
    out_spec = None
    # TODO (zhxchen17) Always functionalize in a second pass no matter which path is taken.
    flat_args = tuple(pytree.tree_flatten(args)[0])
    if not config.enable_aot:
        if config._unlift:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "The _unlift config does nothing unless enable_aot is enabled. Please do not set it.",
            )
    if config.pt2_mode:
        if config.enable_aot:
            if config.enable_dynamic_shape:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "Under enable_aot, the enable_dynamic_shape flag does nothing. Please do not set it.",
                )
            if not config.enable_functionalization:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "Functionalization is required for enable_aot.",
                )

            # If trying to capture a method and the bound class instance is a
            # Module, then export the module while patching in that method.
            if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module):
                with patch_forward(f.__self__, f):
                    ep = export(
                        cast(torch.nn.Module, f.__self__),
                        args,
                        dynamic_shapes=dynamic_shapes,
                    )
            else:
                mod = f if isinstance(f, torch.nn.Module) else WrapperModule(f)
                ep = export(mod, args, dynamic_shapes=dynamic_shapes)

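            # Lower the exported program to the core ATen decompositions and
            # replace aliasing view ops with their functional view_copy
            # variants before handing it to EXIR.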
            ep = ep.run_decompositions(_default_decomposition_table())
            ep = _transform(ep, ReplaceViewOpsWithViewCopyOpsPass())
            if not config._unlift:
                return ExirExportedProgram(ep, False)
            graph_module = cast(torch.fx.GraphModule, ep.module())

        elif config.enable_dynamic_shape:
            graph_module, _ = dynamo_trace(
                f,
                args,
                aten_graph=True,
                tracing_mode="symbolic",
                dynamo_config=config._dynamo_config,
                dynamic_shapes=dynamic_shapes,
                _use_old_decomp_table=config._use_old_decomp_table,
            )

        else:
            graph_module, _ = dynamo_trace(
                f,
                args,
                aten_graph=True,
                tracing_mode="fake",
                dynamo_config=config._dynamo_config,
                dynamic_shapes=None,
                _use_old_decomp_table=config._use_old_decomp_table,
            )

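        # Recover the output TreeSpec from the traced graph module: prefer the
        # pytree codegen info attached by dynamo, then a stored _out_spec
        # attribute, and finally fall back to flattening an eager call of f.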
        if out_spec is None:
            if isinstance(graph_module.graph._codegen, torch.fx.graph._PyTreeCodeGen):
                out_spec = graph_module.graph._codegen.pytree_info.out_spec
            elif hasattr(graph_module, "_out_spec"):
                out_spec = graph_module._out_spec
            else:
                out_spec = pytree.tree_flatten(f(*args))[1]

        # NOTE (tmanlaibaatar)
        # torchdynamo.export adds an extra kwarg into the graph module, which is
        # then lost when we call make_fx, because make_fx doesn't handle kwargs.
        # We originally used the torchdynamo input spec, but due to limitations
        # in the pytree implementation it doesn't recognize the make_fx graph
        # with the torchdynamo input spec, so we work around it by building the
        # input spec directly from the user arguments.
        in_spec = pytree.tree_flatten((args, {}))[1]

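        # Functionalize the captured graph: re-run it through an fx.Interpreter
        # (preserving node metadata) so functionalize() can remove mutations and
        # aliasing views, then retrace the functionalized callable with make_fx.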
        if config.enable_functionalization and not config.enable_aot:
            args = copy.deepcopy(args)

            def graph_with_interpreter(*args):
                with torch.fx.traceback.preserve_node_meta():
                    return torch.fx.Interpreter(graph_module).run(*args)

            functionalized_callable = functionalize(
                graph_with_interpreter,
                remove="mutations_and_views",
            )
            assert isinstance(functionalized_callable, Callable)

            if config.enable_dynamic_shape:
                fake_tensor_mode = FakeTensorMode(
                    allow_fallback_kernels=False,
                    allow_non_fake_inputs=True,
                    shape_env=ShapeEnv(),
                )

                inps: List[torch.Tensor] = []
                for node in graph_module.graph.nodes:
                    if node.op == "placeholder" and "val" in node.meta:
                        example_fake_tensor = node.meta["val"]
                        assert isinstance(example_fake_tensor, FakeTensor)
                        inps.append(example_fake_tensor)

                if detected_fake_mode := _guards.detect_fake_mode(inps):
                    fake_tensor_mode = detected_fake_mode

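                # Positionally substitute each real tensor input with the
                # FakeTensor recorded on the corresponding placeholder, so that
                # make_fx retraces with the same fake values (and ShapeEnv)
                # that dynamo produced.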
                count = 0

                def convert_to_fake(x):
                    nonlocal count
                    val = inps[count]
                    count += 1
                    return val

                fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)

                with enable_python_dispatcher(), fake_tensor_mode:
                    graph_module = make_fx(
                        functionalized_callable,
                        tracing_mode="real",
                        _allow_non_fake_inputs=True,
                    )(*fake_args)
            else:
                # To avoid breaking folks, use the deprecated "real" tracing
                # mode if we're not using pt2.
                tracing_mode = "fake" if config.pt2_mode else "real"
                graph_module = make_fx(
                    functionalized_callable,
                    tracing_mode=tracing_mode,
                    _allow_non_fake_inputs=True,
                )(*args)

        flatten_output(graph_module)

    else:
        raise InternalError("pt2=False path is officially deprecated")

    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
    graph_module._apply(torch.Tensor.contiguous)

    user_inputs = [
        InputSpec(
            kind=InputKind.USER_INPUT, arg=TensorArgument(name=node.name), target=None
        )
        for node in graph_module.graph.nodes
        if node.op == "placeholder"
    ]
    output_node = list(graph_module.graph.nodes)[-1]
    assert output_node.op == "output"
    user_outputs = [
        OutputSpec(
            kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=arg.name), target=None
        )
        for arg in output_node.args[0]
    ]

    graph_module.graph.eliminate_dead_code()
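    # Package the traced graph as an ExportedProgram with a flat
    # user-input/user-output signature, no lifted parameters or buffers
    # (empty state_dict), and the EXIR ATen dialect verifier.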
    ep = ExportedProgram(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=ExportGraphSignature(user_inputs, user_outputs),
        state_dict={},
        range_constraints={},
        module_call_graph=[
            ModuleCallEntry(
                fqn="",
                signature=ModuleCallSignature(
                    inputs=[],
                    outputs=[],
                    in_spec=in_spec,
                    # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
                    #  `Union[None, TreeSpec, Tensor, Module]`.
                    out_spec=out_spec,
                ),
            )
        ],
        example_inputs=None,
        verifiers=[EXIRATenDialectVerifierBase],
    )
    return ExirExportedProgram(ep, False)


# Bootstrap a missing meta["val"] on placeholders when (1) the placeholder
# holds a scalar, or (2) meta["val"] was not properly set by dispatch_trace.
def _instantiate_missing_placeholder_val_with_real_inputs(gm, args):
    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
    if len(phs) != len(args):
        raise ExportError(
            ExportErrorType.NOT_SUPPORTED,
            "Expected the number of placeholders to match the number of user inputs.",
        )
    for node, arg in zip(phs, args):
        if "val" not in node.meta or node.meta["val"] is None:
            node.meta["val"] = arg