# mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types."""

from __future__ import annotations

from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    runtime_checkable,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import numpy

import onnx

import torch
from torch._subclasses import fake_tensor


if TYPE_CHECKING:
    import onnx.defs.OpSchema.AttrType  # type: ignore[import]  # noqa: TCH004


# Enable both TorchScriptTensor and torch.Tensor to be tested
# for dtype in OpSchemaWrapper.
@runtime_checkable
class TensorLike(Protocol):
    @property
    def dtype(self) -> torch.dtype | None: ...

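# For example, since TensorLike is runtime_checkable and only requires a
# `dtype` attribute, a plain torch.Tensor also satisfies it:
#   isinstance(torch.tensor(1.0), TensorLike)  # True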

def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool:
    # NOTE: This is needed because TorchScriptTensor is not supported by torch.is_complex()
    return tensor_dtype in _COMPLEX_TO_FLOAT


def from_complex_to_float(dtype: torch.dtype) -> torch.dtype:
    return _COMPLEX_TO_FLOAT[dtype]

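# For example, per the _COMPLEX_TO_FLOAT table defined below:
#   is_torch_complex_dtype(torch.complex64)  # True
#   is_torch_complex_dtype(torch.float32)    # False
#   from_complex_to_float(torch.complex64)   # torch.float32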

def from_sym_value_to_torch_dtype(sym_value: SYM_VALUE_TYPE) -> torch.dtype:
    return _SYM_TYPE_TO_TORCH_DTYPE[type(sym_value)]

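# For example, per the _SYM_TYPE_TO_TORCH_DTYPE table defined below, a
# torch.SymInt value maps to torch.int64, a torch.SymFloat to torch.float32,
# and a torch.SymBool to torch.bool.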

def is_optional_onnx_dtype_str(onnx_type_str: str) -> bool:
    return onnx_type_str in _OPTIONAL_ONNX_DTYPE_STR


def from_torch_dtype_to_onnx_dtype_str(dtype: torch.dtype | type) -> set[str]:
    return _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS[dtype]

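# For example, per the tables defined below:
#   from_torch_dtype_to_onnx_dtype_str(torch.float32)      # {"tensor(float)"}
#   is_optional_onnx_dtype_str("optional(tensor(float))")  # True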

def from_python_type_to_onnx_attribute_type(
    dtype: type, is_sequence: bool = False
) -> onnx.defs.OpSchema.AttrType | None:
    import onnx.defs  # type: ignore[import]

    _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE = {
        float: onnx.defs.OpSchema.AttrType.FLOAT,
        int: onnx.defs.OpSchema.AttrType.INT,
        str: onnx.defs.OpSchema.AttrType.STRING,
        bool: onnx.defs.OpSchema.AttrType.INT,
    }

    _SEQUENCE_TYPE_TO_ONNX_ATTRIBUTE_TYPE = {
        float: onnx.defs.OpSchema.AttrType.FLOATS,
        int: onnx.defs.OpSchema.AttrType.INTS,
        str: onnx.defs.OpSchema.AttrType.STRINGS,
        bool: onnx.defs.OpSchema.AttrType.INTS,
    }

    if is_sequence:
        return _SEQUENCE_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)
    return _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)

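# For example, Python bool attributes are encoded with the ONNX INT/INTS
# attribute types, so:
#   from_python_type_to_onnx_attribute_type(bool)                     # AttrType.INT
#   from_python_type_to_onnx_attribute_type(float, is_sequence=True)  # AttrType.FLOATS
#   from_python_type_to_onnx_attribute_type(bytes)                    # None (unmapped)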

def from_python_type_to_onnx_tensor_element_type(type: type):
    """
    Converts a Python type to the corresponding ONNX tensor element type.
    For example, `from_python_type_to_onnx_tensor_element_type(float)` returns
    `onnx.TensorProto.FLOAT`.

    Args:
      type (type): The Python type to convert.

    Returns:
      The corresponding ONNX tensor element type, or None if the Python type
      has no corresponding ONNX tensor element type.

    """
    _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = {
        float: onnx.TensorProto.FLOAT,  # type: ignore[attr-defined]
        int: onnx.TensorProto.INT64,  # type: ignore[attr-defined]
        bool: onnx.TensorProto.BOOL,  # type: ignore[attr-defined]
    }
    return _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE.get(type)

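# For example:
#   from_python_type_to_onnx_tensor_element_type(int)  # onnx.TensorProto.INT64
#   from_python_type_to_onnx_tensor_element_type(str)  # None (not in the mapping)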

def is_torch_symbolic_type(value: Any) -> bool:
    return isinstance(value, (torch.SymBool, torch.SymInt, torch.SymFloat))

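# For example, plain Python numbers are not symbolic, so is_torch_symbolic_type(3)
# is False, while a torch.SymInt produced by dynamic shapes returns True.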

def from_torch_dtype_to_abbr(dtype: torch.dtype | None) -> str:
    if dtype is None:
        return ""
    return _TORCH_DTYPE_TO_ABBREVIATION.get(dtype, "")


def from_scalar_type_to_torch_dtype(scalar_type: type) -> torch.dtype | None:
    return _SCALAR_TYPE_TO_TORCH_DTYPE.get(scalar_type)

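# For example, per the tables defined below:
#   from_torch_dtype_to_abbr(torch.float32)         # "f32"
#   from_torch_dtype_to_abbr(None)                  # ""
#   from_scalar_type_to_torch_dtype(int)            # torch.int64
#   from_scalar_type_to_torch_dtype(torch.SymBool)  # torch.bool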

# NOTE: This is a mapping from torch dtype to a set of compatible ONNX types.
# It's used in the dispatcher to find the best-matching overload for the input dtypes.
_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: dict[torch.dtype | type, set[str]] = {
    torch.bfloat16: {"tensor(bfloat16)"},
    torch.bool: {"tensor(bool)"},
    torch.float64: {"tensor(double)"},
    torch.float32: {"tensor(float)"},
    torch.float16: {"tensor(float16)"},
    torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"},
    torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"},
    torch.float8_e5m2: {"tensor(float8_e5m2)"},
    torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"},
    torch.int16: {"tensor(int16)"},
    torch.int32: {"tensor(int32)"},
    torch.int64: {"tensor(int64)"},
    torch.int8: {"tensor(int8)"},
    torch.uint8: {"tensor(uint8)"},
    str: {"tensor(string)"},
    int: {"tensor(int16)", "tensor(int32)", "tensor(int64)"},
    float: {"tensor(float16)", "tensor(float)", "tensor(double)"},
    bool: {"tensor(int32)", "tensor(int64)", "tensor(bool)"},
    complex: {"tensor(float)", "tensor(double)"},
    torch.complex32: {"tensor(float16)"},
    torch.complex64: {"tensor(float)"},
    torch.complex128: {"tensor(double)"},
}

_OPTIONAL_ONNX_DTYPE_STR: set[str] = {
    f"optional({value})"
    for value_set in _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS.values()
    for value in value_set
}
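# The comprehension above yields entries such as "optional(tensor(float))"
# and "optional(tensor(int64))".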

_PYTHON_TYPE_TO_TORCH_DTYPE = {
    bool: torch.bool,
    int: torch.int64,
    float: torch.float32,
    complex: torch.complex64,
}

_COMPLEX_TO_FLOAT: dict[torch.dtype, torch.dtype] = {
    torch.complex32: torch.float16,
    torch.complex64: torch.float32,
    torch.complex128: torch.float64,  # NOTE: ORT doesn't support torch.float64
}

_SYM_TYPE_TO_TORCH_DTYPE = {
    torch.SymInt: torch.int64,
    torch.SymFloat: torch.float32,
    torch.SymBool: torch.bool,
}

_SCALAR_TYPE_TO_TORCH_DTYPE: dict[type, torch.dtype] = {
    **_PYTHON_TYPE_TO_TORCH_DTYPE,
    **_SYM_TYPE_TO_TORCH_DTYPE,  # type: ignore[dict-item]
}

_TORCH_DTYPE_TO_ABBREVIATION = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.float8_e4m3fn: "e4m3fn",
    torch.float8_e4m3fnuz: "e4m3fnuz",
    torch.float8_e5m2: "f8e5m2",
    torch.float8_e5m2fnuz: "e5m2fnuz",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

_TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: numpy.float16,
    torch.float32: numpy.float32,
    torch.float64: numpy.float64,
    torch.uint8: numpy.uint8,
    torch.int8: numpy.int8,
    torch.int16: numpy.int16,
    torch.int32: numpy.int32,
    torch.int64: numpy.longlong,
    torch.bool: numpy.bool_,
}
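# For example, _TORCH_DTYPE_TO_NUMPY_DTYPE[torch.float32] is numpy.float32; note
# that torch.int64 is mapped to numpy.longlong rather than numpy.int64.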

_ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE = {
    onnx.TensorProto.FLOAT: torch.float32,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT16: torch.float16,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,  # type: ignore[attr-defined]
    onnx.TensorProto.DOUBLE: torch.float64,  # type: ignore[attr-defined]
    onnx.TensorProto.BOOL: torch.bool,  # type: ignore[attr-defined]
    onnx.TensorProto.UINT8: torch.uint8,  # type: ignore[attr-defined]
    onnx.TensorProto.INT8: torch.int8,  # type: ignore[attr-defined]
    onnx.TensorProto.INT16: torch.int16,  # type: ignore[attr-defined]
    onnx.TensorProto.INT32: torch.int32,  # type: ignore[attr-defined]
    onnx.TensorProto.INT64: torch.int64,  # type: ignore[attr-defined]
}

_TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = {
    value: key for key, value in _ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE.items()
}
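# The two tables above are inverses of each other, e.g.
#   _ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE[onnx.TensorProto.FLOAT]  # torch.float32
#   _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[torch.float32]           # onnx.TensorProto.FLOAT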

SYM_VALUE_TYPE = Union[torch.SymInt, torch.SymFloat, torch.SymBool]
META_VALUE_TYPE = Union[fake_tensor.FakeTensor, SYM_VALUE_TYPE, int, float, bool]
# NOTE: The definitions below are from torch/fx/node.py.
BaseArgumentTypes = Union[
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.device,
    torch.memory_format,
    torch.layout,
    torch._ops.OpOverload,
    torch.SymInt,
    torch.SymFloat,
    torch.SymBool,
]
Argument = Optional[
    Union[
        Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
        List[Any],  # actually Argument
        Dict[str, Any],  # actually Argument
        slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
        range,
        "torch.fx.Node",
        BaseArgumentTypes,
    ]
]
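# For example, values that satisfy Argument include None, 3, 1.5, torch.float32,
# a torch.fx.Node, and nested containers of these such as (node, 0) or {"dim": 1}.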
257