# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
from typing import Callable, Mapping, TYPE_CHECKING

import torch
import torch._ops
from torch._dispatch import python as python_dispatch
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils


if TYPE_CHECKING:
    import torch.fx


class Decompose(_pass.Transform):
    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        decomposition_table: Mapping[torch._ops.OpOverload, Callable],
        enable_dynamic_axes: bool,
        allow_fake_constant: bool | None = False,
    ):
        super().__init__(diagnostic_context, module)
        self.decomposition_table = decomposition_table
        self.enable_dynamic_axes = enable_dynamic_axes
        self.allow_fake_constant = allow_fake_constant

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        assert not kwargs, "kwargs is not supported in Decompose."

        # Wrap the graph module to preserve stack trace info after `make_fx`.
        module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

        # Fake mode uses static sizes to trace tensor shapes, while symbolic mode
        # generates aten::sym_size nodes to trace tensor shapes dynamically.
        #
        # e.g. fake mode:
        #   view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20])
        #
        # e.g. symbolic mode:
        #   sym_size = torch.ops.aten.sym_size(x, 0)
        #   sym_size_1 = torch.ops.aten.sym_size(x, 1)
        #   sym_size_2 = torch.ops.aten.sym_size(x, 2)
        #   sym_size_3 = torch.ops.aten.sym_size(x, 3)
        #   mul = sym_size_2 * sym_size_3; sym_size_2 = sym_size_3 = None
        #   view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul])

        # Mimic `torch._dynamo.export(aten_graph=True)` behavior when invoking `make_fx`.
        # TODO: May need to revisit the user fake mode export + dynamic shape scenario.
        fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode
        maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args)
        if fake_mode is not None:
            # Use the existing fake mode as context and signal `make_fx` that it does
            # not need to create a new fake mode by passing tracing_mode="real".
            tracing_mode = "real"
        else:
            # No existing fake mode was found; signal `make_fx` to create one.
            fake_mode = contextlib.nullcontext()  # type: ignore[assignment]
            tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"

        # Apply the decomposition table to the input graph.
        assert fake_mode is not None  # for mypy
        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
            decomposed_module = proxy_tensor.make_fx(
                module,
                decomposition_table=self.decomposition_table,
                tracing_mode=tracing_mode,
                _allow_non_fake_inputs=True,
                _allow_fake_constant=bool(self.allow_fake_constant),
            )(*maybe_fake_args)

        # Rename placeholder targets to match the original module's signature, since
        # we don't want to map forward(x, y, z) to forward(arg0, arg1, arg2).
        _utils.replace_placeholder_name_and_target(decomposed_module, self.module)

        return decomposed_module
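

# Usage sketch (comment only, so nothing runs at import time). This is an
# illustration under stated assumptions, not exporter-internal wiring:
# `diagnostic_context`, `graph_module`, and `example_inputs` are hypothetical
# placeholders supplied by the caller, the decomposition table is taken from
# `torch._decomp`, and the pass is assumed to be driven through the public
# `run()` entry point that `_pass.Transform` wraps around `_run()`.
#
#   import torch._decomp
#
#   decompose = Decompose(
#       diagnostic_context,
#       graph_module,
#       decomposition_table=torch._decomp.core_aten_decompositions(),
#       enable_dynamic_axes=False,  # True requests tracing_mode="symbolic",
#                                   # but only when no user fake mode is active.
#       allow_fake_constant=False,
#   )
#   decomposed_module = decompose.run(*example_inputs)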