xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import functools
5from typing import Any, Callable, Mapping, Sequence
6
7import torch
8import torch.fx
9import torch.onnx
10import torch.onnx._internal.fx.passes as passes
11from torch.onnx._internal import _exporter_legacy, io_adapter
12
13
14# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
15# data can flow through those functions. Python functions (e.g., `torch.arange`)
16# not defined by pybind11 in C++ do not go though Python dispatcher, so
17# they are not automatically patched by FX's Python dispatcher.
18# The list below means `torch.arange`, `torch.tensor`, and so on will be
19# patched.
20_TORCH_METHODS_TO_PATCH: tuple[str, ...] = (
21    "arange",
22    "tensor",
23    "finfo",
24    "full",
25    "empty",
26)
27
28
29class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
30    """Tracer to create ONNX-exporting friendly FX graph.
31
32    This tracer traces models into operators. That is,
33    the traced graph mostly contains call_function nodes and
34    has no call_module nodes. The call_module nodes
35    are problematic to the use of make_fx(...) in ONNX
36    exporter.
37    """
38
39    def is_leaf_module(
40        self, module: torch.nn.Module, module_qualified_name: str
41    ) -> bool:
42        # This returns False so that all sub-modules are considered as not leaves
43        # and therefore expanded into operators in
44        # torch.fx._symbolic_trace.Tracer.call_module.
45        return False
46
47    def to_bool(self, obj: torch.fx.Proxy) -> bool:
48        # FIXME: This is a hack to tracing through if-else Python blocks.
49        # It may generate incorrect ONNX graphs if the if-else block
50        return False
51
52
53def _wrap_for_symbolic_trace(target: Callable) -> tuple[Callable, Callable]:
54    """This function wraps ```target`` for symbolic tracing.
55
56    This function wraps ```target``` so that its wrapper produces
57    torch.fx.Proxy in symbolic computation. The returned values are
58    the wrapper and then the original function. Per `_TORCH_METHODS_TO_PATCH`,
59    this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
60    """
61
62    @functools.wraps(target)
63    def wrapper(*args, **kwargs):
64        proxy = None
65
66        def check_has_proxy(v):
67            if isinstance(v, torch.fx.Proxy):
68                nonlocal proxy
69                proxy = v
70
71        torch.fx.node.map_aggregate(args, check_has_proxy)
72        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
73
74        if proxy is not None:
75            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
76        else:
77            return target(*args, **kwargs)
78
79    return wrapper, target
80
81
82def _module_expansion_symbolic_trace(
83    root: torch.nn.Module | Callable[..., Any],
84    concrete_args: dict[str, Any] | None = None,
85) -> torch.fx.GraphModule:
86    """Trace a callable into FX graph.
87
88    When "root" is torch.nn.Module, calls to its submodule (type: torch.nn.Module) will be
89    expanded into operators (e.g., torch.matmul, torch.add, +, and -) to simplify graph
90    structure.
91    """
92    # For functions doesn't support symbolic tracing, create wrappers
93    # which produce symbolic results during tracing.
94    patched_torch_methods = {
95        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
96        for target_name in _TORCH_METHODS_TO_PATCH
97    }
98
99    # Set the symbolic-tracing friendly functions so that `tracer.trace` below
100    # can work.
101    for name, (wrapper, _) in patched_torch_methods.items():
102        setattr(torch, name, wrapper)
103
104    try:
105        # Set up a tracer.
106        tracer = ModuleExpansionTracer()
107        # Trace the model.
108        graph = tracer.trace(root, concrete_args)
109        name = (
110            root.__class__.__name__
111            if isinstance(root, torch.nn.Module)
112            else root.__name__
113        )
114        return torch.fx.GraphModule(tracer.root, graph, name)
115    finally:
116        # Revert the patches for symbolic tracing.
117        for name, (_, wrapped) in patched_torch_methods.items():
118            # wrapped is the original version of `torch.name`.
119            setattr(torch, name, wrapped)
120
121
122# TODO: Migrate to `DynamoExporter` after fake model tracing is supported.
123# Proposal at https://github.com/pytorch/pytorch/issues/95900.
124class FXSymbolicTracer(_exporter_legacy.FXGraphExtractor):
125    """Generates a FX GraphModule using torch.fx.symbolic_trace API
126    Args:
127        concrete_args: Inputs to be partially specialized
128            It can be used to remove control flow or data structures.
129            For example::
130                def f(a, b):
131                    if b == True:
132                        return a
133                    else:
134                        return a*2
135            FX can typically not trace through this due to the presence of control
136            flow. However, we can use `concrete_args` to specialize on the value of
137            `b` to trace through this::
138                f = fx.symbolic_trace(f, concrete_args={'b': False})
139                assert f(3, False)  == 6
140            Note that although you can still pass in different values of `b`, they will be ignored.
141            It can also be used to eliminate data-structure handling from
142            our function. This will use pytrees to flatten your input. To avoid
143            overspecializing, pass in `fx.PH` for values that shouldn't be
144            specialized. For example::
145                def f(x):
146                    out = 0
147                    for v in x.values():
148                        out += v
149                    return out
150
151
152                f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
153                assert f({"a": 1, "b": 2, "c": 4}) == 7
154    """
155
156    def __init__(self, concrete_args: dict[str, Any] | None = None):
157        super().__init__()
158        # TODO: plumb ``concrete_args`` to symbolic_trace call at ``generate_fx``
159        self.concrete_args = concrete_args
160
161    def _trace_into_fx_graph_via_fx_symbolic_trace(
162        self, model, model_args, model_kwargs
163    ) -> torch.fx.GraphModule:
164        # Bind model args and kwargs with model signature to retrieve default values
165        # of unprovided arguments. These are then used to construct ``concrete_args``.
166        bind_input_step = io_adapter.BindInputStep(
167            torch.onnx.utils.model_signature(model)
168        )
169        self.input_adapter.append_step(bind_input_step)
170        _, named_args = bind_input_step.apply(model_args, model_kwargs, model=model)
171
172        # Create inputs to call symbolic trace (torch.fx.symbolic_trace)
173        # Example content of concrete_args:
174        #  concrete_args["x"] = torch.fx._symbolic_trace.PH
175        #  concrete_args["b"] = 1
176        # where "x" and "b" are argument names in "signature".
177        concrete_args = {}
178        for param_name, param_value in named_args.items():
179            if isinstance(param_value, torch.Tensor):
180                # param_value can be, e.g., a real tensor or a fake tensor.
181                # param_value is treated as substitutable tensor symbol (aka placeholder).
182                concrete_args[param_name] = torch.fx._symbolic_trace.PH
183            else:
184                concrete_args[param_name] = param_value
185
186        # Merge kwargs back into args since that is the format FX graph expects.
187        merge_kwargs_step = io_adapter.MergeKwargsIntoArgsInputStep()
188        self.input_adapter.append_step(merge_kwargs_step)
189        return _module_expansion_symbolic_trace(model, concrete_args=concrete_args)
190
191    def generate_fx(
192        self,
193        options: _exporter_legacy.ResolvedExportOptions,
194        model: torch.nn.Module | Callable,
195        model_args: Sequence[Any],
196        model_kwargs: Mapping[str, Any],
197    ) -> torch.fx.GraphModule:
198        diagnostic_context = options.diagnostic_context
199        graph_module = self._trace_into_fx_graph_via_fx_symbolic_trace(
200            model, model_args, model_kwargs
201        )
202
203        # Make sure all placeholder nodes are executed before get_attr nodes.
204        # Otherwise, inputs can interleave with initializers in the final ModeoProto.graph.input.
205        # Basically, we want
206        #  ModeoProto.graph.input =
207        #   [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
208        # and we don't want
209        #  ModeoProto.graph.input =
210        #   [input_0, weight_0, input_1, weight_1, ..., input_n, weight_0, weight_1, ..., weight_m]
211        graph_module = passes.MovePlaceholderToFront(
212            diagnostic_context, graph_module
213        ).run()
214        # To save memory, move get_attr to input so that the generated model doesn't
215        # have weigh tensors. "replaced_attrs" are a tuple of replaced weight tensors.
216        replace_get_attr_with_placeholder_pass = passes.ReplaceGetAttrWithPlaceholder(
217            diagnostic_context, graph_module
218        )
219        graph_module = replace_get_attr_with_placeholder_pass.run()
220        replaced_attrs = replace_get_attr_with_placeholder_pass.replaced_attrs
221        append_extra_input_step = io_adapter.LiftParametersAndBuffersIntoArgsInputStep(
222            replaced_attrs
223        )
224        self.input_adapter.append_step(append_extra_input_step)
225        # Move all newly created placeholder nodes to the front of the graph.
226        graph_module = passes.MovePlaceholderToFront(
227            diagnostic_context, graph_module
228        ).run()
229        # Finalize the graph editing.
230        graph_module.recompile()
231
232        updated_model_args = self.input_adapter.apply(
233            *model_args, model=model, **model_kwargs
234        )
235
236        return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
237
238    def pre_export_passes(
239        self,
240        options: _exporter_legacy.ResolvedExportOptions,
241        original_model: torch.nn.Module | Callable,
242        fx_module: torch.fx.GraphModule,
243        fx_module_args: Sequence[Any],
244    ):
245        return _exporter_legacy.common_pre_export_passes(
246            options, original_model, fx_module, fx_module_args
247        )
248