xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/exporter/_compat.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Compatibility functions for the torch.onnx.export API."""
2
3# mypy: allow-untyped-defs
4# mypy: disable-error-code=attr-defined
5from __future__ import annotations
6
7import inspect
8import logging
9from typing import Any, Mapping, Sequence, TYPE_CHECKING
10
11import torch
12from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
13from torch.onnx._internal.exporter import _core, _onnx_program
14
15
16if TYPE_CHECKING:
17    import os
18
19logger = logging.getLogger(__name__)
20
21
22def _signature(model) -> inspect.Signature:
23    should_be_callable = getattr(model, "forward", model)
24    if callable(should_be_callable):
25        return inspect.signature(should_be_callable)
26    raise ValueError("model has no forward method and is not callable")
27
28
29def _from_dynamic_axes_to_dynamic_shapes(
30    model,
31    *,
32    dynamic_axes=None,
33    output_names: set[str],
34    input_names: Sequence[str] | None = None,
35) -> dict[str, Any] | None:
36    """
37
38    dynamic_axes examples:
39    (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
40    (2) dynamic_axes = {"x": [0], "y": [1]}
41
42    these will be converted to dynamic_shapes respectively:
43    (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}}
44    (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}}  # auto-generated dim names
45
46    """
47    # https://github.com/pytorch/pytorch/pull/128371
48    # 1. The function does not need to provide dynamic_shapes to torch.export.export
49    if dynamic_axes is None:
50        return None
51
52    if input_names is None:
53        input_names = []
54
55    sig = _signature(model)
56    if len(input_names) > len(sig.parameters):
57        raise ValueError(
58            f"Number of input names ({len(input_names)}) should not be greater than "
59            f"the number of model inputs ({len(sig.parameters)})"
60        )
61    input_names_to_model_inputs = {}
62    for idx, param_name in enumerate(sig.parameters):
63        if idx < len(input_names):
64            input_names_to_model_inputs[input_names[idx]] = param_name
65        else:
66            input_names_to_model_inputs[param_name] = param_name
67
68    # NOTE: torch.export.export does not support input names assignment,
69    # so we need to map input names to model inputs to create dynamic_shapes
70    # for the exported program
71    dynamic_shapes_to_exported_program = {}
72    for input_name, axes in dynamic_axes.items():
73        if input_name in output_names:
74            # User specified an output name as a dynamic axis, so we skip it
75            continue
76        # input_name can be either from input_names or from the model inputs
77        if input_name not in input_names_to_model_inputs:
78            raise ValueError(
79                f"dynamic axis: {input_name} is not found in the input names: {input_names}"
80            )
81        model_input_name = input_names_to_model_inputs[input_name]
82        if isinstance(axes, dict):
83            dynamic_shapes_to_exported_program[model_input_name] = {
84                k: torch.export.Dim(v) for k, v in axes.items()
85            }
86        elif isinstance(axes, list):
87            dynamic_shapes_to_exported_program[model_input_name] = {
88                k: torch.export.Dim(f"{model_input_name}_dim_{k}") for k in axes
89            }
90        else:
91            raise TypeError(
92                f"dynamic_axes value must be either a dict or a list, but got {type(axes)}"
93            )
94    # torch.export.export needs static dim to present in dynamic_shapes
95    # for all input tensors, so we need to add them with None
96    for input_name in sig.parameters:
97        if input_name not in dynamic_shapes_to_exported_program:
98            dynamic_shapes_to_exported_program[input_name] = None  # type: ignore[assignment]
99
100    return dynamic_shapes_to_exported_program
101
102
103def _get_torch_export_args(
104    args: tuple[Any, ...],
105    kwargs: dict[str, Any] | None,
106) -> tuple[tuple[Any, ...], dict[str, Any] | None]:
107    """Obtain the arguments for torch.onnx.export from the model and the input arguments."""
108    if not kwargs and args and isinstance(args[-1], dict):
109        kwargs = args[-1]
110        args = args[:-1]
111    return args, kwargs
112
113
def export_compat(
    model: torch.nn.Module
    | torch.export.ExportedProgram
    | torch.jit.ScriptModule
    | torch.jit.ScriptFunction,
    args: tuple[Any, ...],
    f: str | os.PathLike | None = None,
    *,
    kwargs: dict[str, Any] | None = None,
    export_params: bool = True,
    verbose: bool | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    opset_version: int | None = None,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
    keep_initializers_as_inputs: bool = False,
    external_data: bool = True,
    report: bool = False,
    verify: bool = False,
    profile: bool = False,
    dump_exported_program: bool = False,
    artifacts_dir: str | os.PathLike = ".",
    fallback: bool = False,
    **_,
) -> _onnx_program.ONNXProgram:
    """Export ``model`` via ``_core.export`` using torch.onnx.export-style arguments.

    The exported model is then converted to ``opset_version``, optimized and,
    when ``f`` is given, saved to ``f``.

    Args:
        model: The model to export.
        args: Positional example inputs. Unless ``model`` is an
            ``ExportedProgram``, a trailing dict is treated as ``kwargs`` when
            ``kwargs`` is not provided.
        f: Destination to save the model to. Nothing is saved when ``None``.
            Required when the legacy fallback is taken.
        kwargs: Keyword example inputs.
        export_params: Passed to the legacy exporter as ``export_params`` and
            to ``ONNXProgram.save`` as ``include_initializers``.
        verbose: Forwarded to ``_core.export``. The fallback notice is printed
            unless this is exactly ``False``.
        input_names: Names for the model inputs.
        output_names: Names for the model outputs.
        opset_version: Opset the resulting model is converted to. Defaults to
            18 when ``None``.
        dynamic_axes: Legacy dynamic axes description. Converted into
            ``dynamic_shapes`` when ``dynamic_shapes`` is ``None`` and
            ``model`` is not an ``ExportedProgram``; also forwarded as-is to
            the legacy exporter on fallback.
        dynamic_shapes: Dynamic shapes forwarded to ``_core.export``.
        keep_initializers_as_inputs: Forwarded to the legacy exporter and to
            ``ONNXProgram.save``.
        external_data: Forwarded to ``ONNXProgram.save``.
        report: Forwarded to ``_core.export``.
        verify: Forwarded to ``_core.export``.
        profile: Forwarded to ``_core.export``.
        dump_exported_program: Forwarded to ``_core.export``.
        artifacts_dir: Forwarded to ``_core.export``.
        fallback: When true, a failure in ``_core.export`` triggers an export
            through ``torch.onnx.utils.export`` instead of propagating.
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        The resulting ``ONNXProgram``.

    Raises:
        TypeError: If the fallback path is taken while ``f`` is ``None``.
        Exception: Whatever ``_core.export`` raised, when ``fallback`` is false.
    """
    # TODO(justinchuby): Change the hardcoded opset version for it to be flexible
    target_opset = 18 if opset_version is None else opset_version

    if not isinstance(model, torch.export.ExportedProgram):
        args, kwargs = _get_torch_export_args(args, kwargs)
        if dynamic_shapes is None and dynamic_axes is not None:
            dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes(
                model,
                dynamic_axes=dynamic_axes,
                input_names=input_names,
                output_names=set(output_names or ()),
            )
    else:
        # The model is already an exported program, so the args, kwargs, and
        # dynamic_shapes are not used.
        dynamic_shapes = dynamic_shapes or {}

    try:
        onnx_program = _core.export(
            model,
            args,
            kwargs,
            registry=None,
            dynamic_shapes=dynamic_shapes,
            input_names=input_names,
            output_names=output_names,
            profile=profile,
            report=report,
            verify=verify,
            dump_exported_program=dump_exported_program,
            artifacts_dir=artifacts_dir,
            verbose=verbose,
        )
    except Exception as export_error:
        if not fallback:
            raise
        if verbose is not False:
            print(
                "[torch.onnx] Falling back to legacy torch.onnx.export due "
                f"to the following error: {export_error}",
            )
        if f is None:
            # The legacy exporter writes to ``f`` and the result is loaded back from it.
            raise TypeError(
                "f must be provided when fallback is enabled"
            ) from export_error
        torch.onnx.utils.export(
            model,  # type: ignore[arg-type]
            args,
            f,  # type: ignore[arg-type]
            kwargs=kwargs,
            export_params=export_params,
            input_names=input_names,
            output_names=output_names,
            opset_version=17,  # TODO(justinchuby): Hard coded to 17 for now
            dynamic_axes=dynamic_axes,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
        )
        onnx_program = _onnx_program.ONNXProgram(ir.load(f), None)

    # Convert the opset version, then optimize.
    version_converted = onnxscript_apis.convert_version(
        onnx_program.model, target_opset
    )
    onnx_program.model = version_converted
    optimized = onnxscript_apis.optimize(onnx_program.model)
    onnx_program.model = optimized

    if f is None:
        return onnx_program

    onnx_program.save(
        f,
        include_initializers=export_params,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        external_data=external_data,
    )
    return onnx_program
217