xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/passes/decomp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import contextlib
5from typing import Callable, Mapping, TYPE_CHECKING
6
7import torch
8import torch._ops
9from torch._dispatch import python as python_dispatch
10from torch._subclasses import fake_tensor
11from torch.fx.experimental import proxy_tensor
12from torch.onnx._internal.fx import _pass, diagnostics
13from torch.onnx._internal.fx.passes import _utils
14
15
16if TYPE_CHECKING:
17    import torch.fx
18
19
class Decompose(_pass.Transform):
    """Re-trace an FX graph module with ``make_fx`` while applying a decomposition table.

    The module is re-traced under the python dispatcher with
    ``decomposition_table`` passed to ``make_fx``. If this transform already has
    a fake mode (``self.fake_mode``), tracing runs inside it; otherwise
    ``make_fx`` is asked to create its own mode ("symbolic" when
    ``enable_dynamic_axes`` is set, "fake" otherwise).
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        decomposition_table: Mapping[torch._ops.OpOverload, Callable],
        enable_dynamic_axes: bool,
        allow_fake_constant: bool | None = False,
    ):
        """Store the decomposition configuration.

        Args:
            diagnostic_context: Forwarded to the base ``_pass.Transform``.
            module: The graph module to decompose; forwarded to the base class.
            decomposition_table: Mapping from operator overloads to their
                decomposition callables; passed to ``make_fx``.
            enable_dynamic_axes: When no existing fake mode is available, selects
                "symbolic" tracing (True) instead of "fake" tracing (False).
            allow_fake_constant: Coerced to ``bool`` and forwarded to ``make_fx``
                as ``_allow_fake_constant``.
        """
        super().__init__(diagnostic_context, module)
        self.decomposition_table = decomposition_table
        self.enable_dynamic_axes = enable_dynamic_axes
        self.allow_fake_constant = allow_fake_constant

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Trace ``self.module`` on ``args`` and return the decomposed graph module.

        Args:
            *args: Positional example inputs for ``self.module``. They are first
                passed through ``self._maybe_fakefy_args`` (a base-class helper;
                presumably converts them to fake tensors when a fake mode
                exists — confirm in ``_pass.Transform``).
            **kwargs: Not supported; must be empty.

        Returns:
            The ``torch.fx.GraphModule`` produced by ``make_fx``, with its
            placeholder names/targets renamed to match ``self.module``.

        Raises:
            TypeError: If any keyword arguments are passed.
        """
        # Explicit check rather than ``assert`` so the guard still fires under
        # ``python -O``; otherwise kwargs would be silently dropped from tracing.
        if kwargs:
            raise TypeError("kwargs is not supported in Decompose.")

        # To preserve stack trace info after `make_fx`.
        module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

        # Fake mode uses static sizes to trace the size of tensors, while symbolic
        # mode generates aten::sym_size to trace the size of tensors dynamically.
        #
        # e.g. fake mode:
        #  view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20])
        #
        # e.g. symbolic mode:
        #  sym_size = torch.ops.aten.sym_size(x, 0)
        #  sym_size_1 = torch.ops.aten.sym_size(x, 1)
        #  sym_size_2 = torch.ops.aten.sym_size(x, 2)
        #  sym_size_3 = torch.ops.aten.sym_size(x, 3)
        #  mul = sym_size_2 * sym_size_3;  sym_size_2 = sym_size_3 = None
        #  view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul])

        # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`.
        # TODO: May need revisit for user fake mode export + dynamic shape scenario.
        fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode
        maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args)

        # A dedicated variable for the context manager entered below, so that
        # `fake_mode` keeps its declared type (no rebinding to a nullcontext).
        tracing_context: contextlib.AbstractContextManager
        if fake_mode is not None:
            # Using existing fake mode as context, signal `make_fx` that it does not
            # need to create a new fake mode by passing tracing_mode as "real".
            tracing_context = fake_mode
            tracing_mode = "real"
        else:
            # Existing fake mode not found, signal `make_fx` to create one.
            tracing_context = contextlib.nullcontext()
            tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"

        # Apply decomposition table to the input graph.
        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), tracing_context:
            decomposed_module = proxy_tensor.make_fx(
                module,
                decomposition_table=self.decomposition_table,
                tracing_mode=tracing_mode,
                _allow_non_fake_inputs=True,
                _allow_fake_constant=bool(self.allow_fake_constant),
            )(*maybe_fake_args)

        # Rename placeholder targets to match the original module's signature since
        # we don't want to map forward(x, y, z) to forward(arg0, arg1, arg2).
        _utils.replace_placeholder_name_and_target(decomposed_module, self.module)

        return decomposed_module
83