xref: /aosp_15_r20/external/pytorch/torch/onnx/_experimental.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Experimental classes and functions used by ONNX export."""
2
3import dataclasses
4from typing import Mapping, Optional, Sequence, Set, Type, Union
5
6import torch
7import torch._C._onnx as _C_onnx
8
9
@dataclasses.dataclass
class ExportOptions:
    """Container for the arguments used by :func:`torch.onnx.export`.

    Every field has a default, so ``ExportOptions()`` is a valid instance and
    callers only need to override the values they care about.

    The declaration order below is the positional order of the generated
    ``__init__``; keep it stable when editing.
    """

    # TODO(justinchuby): Deprecate and remove this class.

    # Plain boolean switches; the defaults are shown on each line.
    export_params: bool = True
    verbose: bool = False

    # Defaults to the EVAL member of the TrainingMode enum.
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL

    # Optional name lists; None means no names were supplied.
    input_names: Optional[Sequence[str]] = None
    output_names: Optional[Sequence[str]] = None

    # Defaults to the ONNX member of the OperatorExportTypes enum.
    operator_export_type: _C_onnx.OperatorExportTypes = (
        _C_onnx.OperatorExportTypes.ONNX
    )

    # NOTE(review): None presumably lets the exporter choose the opset -
    # confirm against torch.onnx.export.
    opset_version: Optional[int] = None

    do_constant_folding: bool = True

    # Keyed by string; each value is either an int -> str mapping or a
    # sequence of ints.
    # NOTE(review): keys are presumably input/output names and the ints axis
    # indices - confirm against torch.onnx.export.
    dynamic_axes: Optional[
        Mapping[str, Union[Mapping[int, str], Sequence[int]]]
    ] = None

    # Tri-state: True, False, or None when the caller made no choice.
    keep_initializers_as_inputs: Optional[bool] = None

    # String -> int mapping.
    # NOTE(review): presumably opset domain -> version - confirm.
    custom_opsets: Optional[Mapping[str, int]] = None

    # Either a single bool or a set of torch.nn.Module subclasses.
    export_modules_as_functions: Union[
        bool, Set[Type[torch.nn.Module]]
    ] = False
28