# mypy: allow-untyped-defs
from __future__ import annotations


__all__ = [
    # Modules
    "symbolic_helper",
    "utils",
    "errors",
    # All opsets
    "symbolic_caffe2",
    "symbolic_opset7",
    "symbolic_opset8",
    "symbolic_opset9",
    "symbolic_opset10",
    "symbolic_opset11",
    "symbolic_opset12",
    "symbolic_opset13",
    "symbolic_opset14",
    "symbolic_opset15",
    "symbolic_opset16",
    "symbolic_opset17",
    "symbolic_opset18",
    "symbolic_opset19",
    "symbolic_opset20",
    # Enums
    "ExportTypes",
    "OperatorExportTypes",
    "TrainingMode",
    "TensorProtoDataType",
    "JitScalarType",
    # Public functions
    "export",
    "export_to_pretty_string",
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
    "disable_log",
    "enable_log",
    # Base error
    "OnnxExporterError",
    # Dynamo Exporter
    "DiagnosticOptions",
    "ExportOptions",
    "ONNXProgram",
    "ONNXRuntimeOptions",
    "OnnxRegistry",
    "dynamo_export",
    "enable_fake_mode",
    # DORT / torch.compile
    "is_onnxrt_backend_supported",
]

from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING

import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode

from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
    is_onnxrt_backend_supported,
    OrtBackend as _OrtBackend,
    OrtBackendOptions as _OrtBackendOptions,
    OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import OnnxExporterError
from .utils import (
    _optimize_graph,
    _run_symbolic_function,
    _run_symbolic_method,
    export_to_pretty_string,
    is_in_onnx_export,
    register_custom_op_symbolic,
    select_model_mode_for_export,
    unregister_custom_op_symbolic,
)


from . import (  # usort: skip. Keep the order instead of sorting lexicographically
    errors,
    symbolic_caffe2,
    symbolic_helper,
    symbolic_opset7,
    symbolic_opset8,
    symbolic_opset9,
    symbolic_opset10,
    symbolic_opset11,
    symbolic_opset12,
    symbolic_opset13,
    symbolic_opset14,
    symbolic_opset15,
    symbolic_opset16,
    symbolic_opset17,
    symbolic_opset18,
    symbolic_opset19,
    symbolic_opset20,
    utils,
)


from ._internal._exporter_legacy import (  # usort: skip. needs to be last to avoid circular import
    DiagnosticOptions,
    ExportOptions,
    ONNXProgram,
    ONNXRuntimeOptions,
    OnnxRegistry,
    enable_fake_mode,
)


if TYPE_CHECKING:
    import os

# Set namespace for exposed private names
DiagnosticOptions.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
OnnxRegistry.__module__ = "torch.onnx"
_OrtBackend.__module__ = "torch.onnx"
_OrtBackendOptions.__module__ = "torch.onnx"
_OrtExecutionProvider.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
is_onnxrt_backend_supported.__module__ = "torch.onnx"

producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION


def export(
    model: torch.nn.Module
    | torch.export.ExportedProgram
    | torch.jit.ScriptModule
    | torch.jit.ScriptFunction,
    args: tuple[Any, ...] = (),
    f: str | os.PathLike | None = None,
    *,
    kwargs: dict[str, Any] | None = None,
    export_params: bool = True,
    verbose: bool | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
    opset_version: int | None = None,
    dynamic_axes: Mapping[str, Mapping[int, str]]
    | Mapping[str, Sequence[int]]
    | None = None,
    keep_initializers_as_inputs: bool = False,
    dynamo: bool = False,
    # Dynamo only options
    external_data: bool = True,
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
    report: bool = False,
    verify: bool = False,
    profile: bool = False,
    dump_exported_program: bool = False,
    artifacts_dir: str | os.PathLike = ".",
    fallback: bool = False,
    # Deprecated options
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
    do_constant_folding: bool = True,
    custom_opsets: Mapping[str, int] | None = None,
    export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
    autograd_inlining: bool = True,
    **_: Any,  # ignored options
) -> Any | None:
    r"""Exports a model into ONNX format.

    Args:
        model: The model to be exported.
        args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
            exported model; any Tensor arguments will become inputs of the exported model,
            in the order they occur in the tuple.
        f: Path to the output ONNX model file. E.g. "model.onnx".
        kwargs: Optional example keyword inputs.
        export_params: If false, parameters (weights) will not be exported.
        verbose: Whether to enable verbose logging.
        input_names: Names to assign to the input nodes of the graph, in order.
        output_names: Names to assign to the output nodes of the graph, in order.
        opset_version: The version of the
            `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
            to target. Must be >= 7.
        dynamic_axes:

            By default the exported model will have the shapes of all input and output tensors
            set to exactly match those given in ``args``. To specify axes of tensors as
            dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:

            * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
                ``output_names``.
            * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
                list, each element is an axis index.

            For example::

                class SumModule(torch.nn.Module):
                    def forward(self, x):
                        return torch.sum(x, dim=1)


                torch.onnx.export(
                    SumModule(),
                    (torch.ones(2, 2),),
                    "onnx.pb",
                    input_names=["x"],
                    output_names=["sum"],
                )

            Produces::

                input {
                  name: "x"
                  ...
                      shape {
                        dim {
                          dim_value: 2  # axis 0
                        }
                        dim {
                          dim_value: 2  # axis 1
                ...
                output {
                  name: "sum"
                  ...
                      shape {
                        dim {
                          dim_value: 2  # axis 0
                ...

            While::

                torch.onnx.export(
                    SumModule(),
                    (torch.ones(2, 2),),
                    "onnx.pb",
                    input_names=["x"],
                    output_names=["sum"],
                    dynamic_axes={
                        # dict value: manually named axes
                        "x": {0: "my_custom_axis_name"},
                        # list value: automatic names
                        "sum": [0],
                    },
                )

            Produces::

                input {
                  name: "x"
                  ...
                      shape {
                        dim {
                          dim_param: "my_custom_axis_name"  # axis 0
                        }
                        dim {
                          dim_value: 2  # axis 1
                ...
                output {
                  name: "sum"
                  ...
                      shape {
                        dim {
                          dim_param: "sum_dynamic_axes_1"  # axis 0
                ...

        keep_initializers_as_inputs: If True, all the
            initializers (typically corresponding to model weights) in the
            exported graph will also be added as inputs to the graph. If False,
            then initializers are not added as inputs to the graph, and only
            the user inputs are added as inputs.

            Set this to True if you intend to supply model weights at runtime.
            Set it to False if the weights are static to allow for better optimizations
            (e.g. constant folding) by backends/runtimes.

        dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
        external_data: Whether to save the model weights as an external data file.
            This is required for models with large weights that exceed the ONNX file size limit (2GB).
            When False, the weights are saved in the ONNX file with the model architecture.
        dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to
            :func:`torch.export.export` for more details. This is only used (and preferred)
            when ``dynamo`` is True. Only one of ``dynamic_axes`` or ``dynamic_shapes``
            should be set at a time.
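
            For example, a minimal sketch that marks the batch dimension of a single
            tensor input as dynamic (the module, its argument name ``x``, and the
            file name are illustrative)::

                class AddOne(torch.nn.Module):
                    def forward(self, x):
                        return x + 1


                batch = torch.export.Dim("batch")
                torch.onnx.export(
                    AddOne(),
                    (torch.randn(2, 3),),
                    "add_one.onnx",
                    dynamic_shapes={"x": {0: batch}},
                    dynamo=True,
                )
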
        report: Whether to generate a markdown report for the export process.
        verify: Whether to verify the exported model using ONNX Runtime.
        profile: Whether to profile the export process.
        dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
            This is useful for debugging the exporter.
        artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
            exported program.
        fallback: Whether to fall back to the TorchScript exporter if the dynamo exporter fails.

        training: Deprecated option. Instead, set the training mode of the model before exporting.
        operator_export_type: Deprecated option. Only ONNX is supported.
        do_constant_folding: Deprecated option. The exported graph is always optimized.
        custom_opsets: Deprecated.
            A dictionary:

            * KEY (str): opset domain name
            * VALUE (int): opset version

            If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
            the opset version is set to 1. Only custom opset domain names and versions should be
            indicated through this argument.
        export_modules_as_functions: Deprecated option.

            Flag to enable exporting all ``nn.Module`` forward calls as local functions in ONNX,
            or a set indicating the particular types of modules to export as local functions.
            This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
            ``opset_version`` < 15 implies IR version < 8, which means no local function support.
            Module variables will be exported as function attributes. There are two categories of function
            attributes.

            1. Annotated attributes: class variables that have type annotations via
            `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
            declarations will be exported as attributes.
            Annotated attributes are not used inside the subgraph of the ONNX local function because
            they are not created by PyTorch JIT tracing, but they may be used by consumers
            to determine whether or not to replace the function with a particular fused kernel.

            2. Inferred attributes: variables that are used by operators inside the module. Attribute names
            will have the prefix "inferred::". This is to differentiate them from predefined attributes retrieved from
            Python module annotations. Inferred attributes are used inside the subgraph of the ONNX local function.

            * ``False`` (default): export ``nn.Module`` forward calls as fine-grained nodes.
            * ``True``: export all ``nn.Module`` forward calls as local function nodes.
            * Set of types of ``nn.Module``: export ``nn.Module`` forward calls as local function nodes
                only if the type of the ``nn.Module`` is found in the set.
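
            For example, a minimal sketch that exports every ``torch.nn.GELU`` submodule
            as an ONNX local function (the surrounding model and file name are illustrative)::

                model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.GELU())
                torch.onnx.export(
                    model,
                    (torch.randn(1, 4),),
                    "model_with_functions.onnx",
                    opset_version=15,
                    export_modules_as_functions={torch.nn.GELU},
                )
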
        autograd_inlining: Deprecated.
            Flag used to control whether to inline autograd functions.
            Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
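
    Example:

        A minimal end-to-end sketch (the module and the output file name are
        illustrative)::

            class MLP(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.fc = torch.nn.Linear(3, 4)

                def forward(self, x):
                    return self.fc(x)


            torch.onnx.export(
                MLP(),
                (torch.randn(2, 3),),
                "mlp.onnx",
                input_names=["x"],
                output_names=["y"],
            )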
339    """
340    if dynamo is True or isinstance(model, torch.export.ExportedProgram):
341        from torch.onnx._internal import exporter
342
343        if isinstance(args, torch.Tensor):
344            args = (args,)
345        return exporter.export_compat(
346            model,
347            args,
348            f,
349            kwargs=kwargs,
350            export_params=export_params,
351            verbose=verbose,
352            input_names=input_names,
353            output_names=output_names,
354            opset_version=opset_version,
355            dynamic_axes=dynamic_axes,
356            keep_initializers_as_inputs=keep_initializers_as_inputs,
357            external_data=external_data,
358            dynamic_shapes=dynamic_shapes,
359            report=report,
360            verify=verify,
361            profile=profile,
362            dump_exported_program=dump_exported_program,
363            artifacts_dir=artifacts_dir,
364            fallback=fallback,
365        )
366    else:
367        from torch.onnx.utils import export
368
369        if dynamic_shapes:
370            raise ValueError(
371                "The exporter only supports dynamic shapes "
372                "through parameter dynamic_axes when dynamo=False."
373            )
374
375        export(
376            model,
377            args,
378            f,  # type: ignore[arg-type]
379            kwargs=kwargs,
380            export_params=export_params,
381            verbose=verbose is True,
382            input_names=input_names,
383            output_names=output_names,
384            opset_version=opset_version,
385            dynamic_axes=dynamic_axes,
386            keep_initializers_as_inputs=keep_initializers_as_inputs,
387            training=training,
388            operator_export_type=operator_export_type,
389            do_constant_folding=do_constant_folding,
390            custom_opsets=custom_opsets,
391            export_modules_as_functions=export_modules_as_functions,
392            autograd_inlining=autograd_inlining,
393        )
394        return None
395
396
397def dynamo_export(
398    model: torch.nn.Module | Callable | torch.export.ExportedProgram,  # type: ignore[name-defined]
399    /,
400    *model_args,
401    export_options: ExportOptions | None = None,
402    **model_kwargs,
403) -> ONNXProgram | Any:
404    """Export a torch.nn.Module to an ONNX graph.
405
406    Args:
407        model: The PyTorch model to be exported to ONNX.
408        model_args: Positional inputs to ``model``.
409        model_kwargs: Keyword inputs to ``model``.
410        export_options: Options to influence the export to ONNX.
411
412    Returns:
413        An in-memory representation of the exported ONNX model.
414
415    **Example 1 - Simplest export**
416    ::
417
418        class MyModel(torch.nn.Module):
419            def __init__(self) -> None:
420                super().__init__()
421                self.linear = torch.nn.Linear(2, 2)
422
423            def forward(self, x, bias=None):
424                out = self.linear(x)
425                out = out + bias
426                return out
427
428
429        model = MyModel()
430        kwargs = {"bias": 3.0}
431        args = (torch.randn(2, 2, 2),)
432        onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
433            "my_simple_model.onnx"
434        )
435
436    **Example 2 - Exporting with dynamic shapes**
437    ::
438
439        # The previous model can be exported with dynamic shapes
440        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
441        onnx_program = torch.onnx.dynamo_export(
442            model, *args, **kwargs, export_options=export_options
443        )
444        onnx_program.save("my_dynamic_model.onnx")
445    """
446
447    # NOTE: The new exporter is experimental and is not enabled by default.
448    import warnings
449
450    from torch.onnx import _flags
451    from torch.onnx._internal import exporter
452    from torch.utils import _pytree
453
454    if isinstance(model, torch.export.ExportedProgram):
455        return exporter.export_compat(
456            model,  # type: ignore[arg-type]
457            model_args,
458            f=None,
459            kwargs=model_kwargs,
460            opset_version=18,
461            external_data=True,
462            export_params=True,
463            fallback=True,
464        )
465    elif _flags.USE_EXPERIMENTAL_LOGIC:
466        if export_options is not None:
467            warnings.warn(
468                "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
469                "For a more comprehensive set of export options, including advanced features, please consider using "
470                "`torch.onnx.export(..., dynamo=True)`. ",
471                category=FutureWarning,
472            )
473
474        if export_options is not None and export_options.dynamic_shapes:
475            # Make all shapes dynamic
476            def _to_dynamic_shapes_mapper():
477                arg_order = 0
478
479                def _to_dynamic_shape(x):
480                    nonlocal arg_order
481                    if isinstance(x, torch.Tensor):
482                        rank = len(x.shape)
483                        dynamic_shape = {}
484                        for i in range(rank):
485                            dynamic_shape[i] = torch.export.Dim(
486                                f"arg_{arg_order}_dim_{i}"
487                            )
488                        arg_order += 1
489                        return dynamic_shape
490                    else:
491                        return None
492
493                return _to_dynamic_shape
494
495            # model_args could be nested
496            dynamic_shapes = _pytree.tree_map(
497                _to_dynamic_shapes_mapper(),
498                model_args,
499            )
500        else:
501            dynamic_shapes = None
502
503        return exporter.export_compat(
504            model,  # type: ignore[arg-type]
505            model_args,
506            f=None,
507            kwargs=model_kwargs,
508            dynamic_shapes=dynamic_shapes,
509            opset_version=18,
510            external_data=True,
511            export_params=True,
512            fallback=True,
513        )
514    else:
515        from torch.onnx._internal._exporter_legacy import dynamo_export
516
517        return dynamo_export(
518            model, *model_args, export_options=export_options, **model_kwargs
519        )
520
521
522# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
523
524# Returns True iff ONNX logging is turned on.
525is_onnx_log_enabled = _C._jit_is_onnx_log_enabled
526
527
528def enable_log() -> None:
529    r"""Enables ONNX logging."""
530    _C._jit_set_onnx_log_enabled(True)
531
532
533def disable_log() -> None:
534    r"""Disables ONNX logging."""
535    _C._jit_set_onnx_log_enabled(False)


"""Sets the output stream for ONNX logging.

Args:
    stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
        as ``stream_name``.
"""
set_log_stream = _C._jit_set_onnx_log_output_stream


"""A simple logging facility for the ONNX exporter.

Args:
    args: Arguments are converted to strings, concatenated with a newline
        character appended at the end, and flushed to the output stream.
"""
log = _C._jit_onnx_log
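
# Example usage of the logging helpers above (a sketch based on the docstrings;
# the stream name and message are illustrative):
#
#     torch.onnx.enable_log()
#     torch.onnx.set_log_stream("stderr")
#     torch.onnx.log("Exporting model with opset ", 17)
#     torch.onnx.disable_log()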