# mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types."""

from __future__ import annotations

from typing import (
    Any,
    Dict,
    List,
    Optional,
    Protocol,
    runtime_checkable,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import numpy

import onnx

import torch
from torch._subclasses import fake_tensor


if TYPE_CHECKING:
    import onnx.defs.OpSchema.AttrType  # type: ignore[import]  # noqa: TCH004


# Enable both TorchScriptTensor and torch.Tensor to be tested
# for dtype in OpSchemaWrapper: anything exposing a ``dtype`` property
# satisfies this protocol at runtime (``isinstance`` checks structure only).
@runtime_checkable
class TensorLike(Protocol):
    @property
    def dtype(self) -> torch.dtype | None: ...


def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool:
    """Return True if ``tensor_dtype`` is one of torch's complex dtypes.

    NOTE: This table lookup is used because TorchScriptTensor is not
    supported by ``torch.is_complex()``.
    """
    return tensor_dtype in _COMPLEX_TO_FLOAT


def from_complex_to_float(dtype: torch.dtype) -> torch.dtype:
    """Return the real floating-point dtype that matches a complex ``dtype``.

    For example, ``torch.complex64`` maps to ``torch.float32``.

    Raises:
        KeyError: If ``dtype`` is not a complex dtype known to this module.
    """
    return _COMPLEX_TO_FLOAT[dtype]


def from_sym_value_to_torch_dtype(sym_value: SYM_VALUE_TYPE) -> torch.dtype:
    """Return the torch dtype used to represent a symbolic scalar value.

    The lookup is keyed on the exact type of ``sym_value`` (``SymInt``,
    ``SymFloat`` or ``SymBool``); subclasses are not matched.

    Raises:
        KeyError: If ``type(sym_value)`` is not one of the symbolic types.
    """
    return _SYM_TYPE_TO_TORCH_DTYPE[type(sym_value)]


def is_optional_onnx_dtype_str(onnx_type_str: str) -> bool:
    """Return True if ``onnx_type_str`` is an ``optional(tensor(...))`` type
    string built from one of the supported tensor types."""
    return onnx_type_str in _OPTIONAL_ONNX_DTYPE_STR


def from_torch_dtype_to_onnx_dtype_str(dtype: torch.dtype | type) -> set[str]:
    """Return the ONNX type strings compatible with a torch dtype or Python type.

    A new set is returned on every call so that callers can freely modify the
    result without corrupting the shared module-level mapping.

    Raises:
        KeyError: If ``dtype`` has no entry in the compatibility table.
    """
    return set(_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS[dtype])


def from_python_type_to_onnx_attribute_type(
    dtype: type, is_sequence: bool = False
) -> onnx.defs.OpSchema.AttrType | None:
    """Map a Python scalar type to the matching ONNX attribute type.

    Args:
        dtype: One of ``float``, ``int``, ``str`` or ``bool``.
        is_sequence: If True, return the plural (list) attribute type,
            e.g. ``INTS`` instead of ``INT``.

    Returns:
        The ``onnx.defs.OpSchema.AttrType`` member, or None if ``dtype`` is
        not supported. ``bool`` maps to the integer attribute types.
    """
    import onnx.defs  # type: ignore[import]

    _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE = {
        float: onnx.defs.OpSchema.AttrType.FLOAT,
        int: onnx.defs.OpSchema.AttrType.INT,
        str: onnx.defs.OpSchema.AttrType.STRING,
        bool: onnx.defs.OpSchema.AttrType.INT,
    }

    _SEQUENCE_TYPE_TO_ONNX_ATTRIBUTE_TYPE = {
        float: onnx.defs.OpSchema.AttrType.FLOATS,
        int: onnx.defs.OpSchema.AttrType.INTS,
        str: onnx.defs.OpSchema.AttrType.STRINGS,
        bool: onnx.defs.OpSchema.AttrType.INTS,
    }

    if is_sequence:
        return _SEQUENCE_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)
    return _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)


def from_python_type_to_onnx_tensor_element_type(type: type) -> int | None:
    """
    Converts a Python type to the corresponding ONNX tensor element type.
    For example, `from_python_type_to_onnx_tensor_element_type(float)` returns
    `onnx.TensorProto.FLOAT`.

    Args:
        type (type): The Python type to convert. (The parameter name shadows
            the builtin but is kept for keyword-argument compatibility.)

    Returns:
        int | None: The corresponding ONNX tensor element type, or None if
        the Python type is not one of ``float``, ``int`` or ``bool``.

    """
    return _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE.get(type)


def is_torch_symbolic_type(value: Any) -> bool:
    """Return True if ``value`` is a ``SymBool``, ``SymInt`` or ``SymFloat``."""
    return isinstance(value, (torch.SymBool, torch.SymInt, torch.SymFloat))


def from_torch_dtype_to_abbr(dtype: torch.dtype | None) -> str:
    """Return a short abbreviation (e.g. ``"f32"``) for ``dtype``.

    Returns an empty string for None or for dtypes without an abbreviation.
    """
    if dtype is None:
        return ""
    return _TORCH_DTYPE_TO_ABBREVIATION.get(dtype, "")


def from_scalar_type_to_torch_dtype(scalar_type: type) -> torch.dtype | None:
    """Return the torch dtype for a Python or symbolic scalar type, or None."""
    return _SCALAR_TYPE_TO_TORCH_DTYPE.get(scalar_type)


# NOTE: this is a mapping from torch dtype to a set of compatible onnx types
# It's used in dispatcher to find the best match overload for the input dtypes
_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: dict[torch.dtype | type, set[str]] = {
    torch.bfloat16: {"tensor(bfloat16)"},
    torch.bool: {"tensor(bool)"},
    torch.float64: {"tensor(double)"},
    torch.float32: {"tensor(float)"},
    torch.float16: {"tensor(float16)"},
    torch.float8_e4m3fn: {"tensor(float8_e4m3fn)"},
    torch.float8_e4m3fnuz: {"tensor(float8_e4m3fnuz)"},
    torch.float8_e5m2: {"tensor(float8_e5m2)"},
    torch.float8_e5m2fnuz: {"tensor(float8_e5m2fnuz)"},
    torch.int16: {"tensor(int16)"},
    torch.int32: {"tensor(int32)"},
    torch.int64: {"tensor(int64)"},
    torch.int8: {"tensor(int8)"},
    torch.uint8: {"tensor(uint8)"},
    str: {"tensor(string)"},
    int: {"tensor(int16)", "tensor(int32)", "tensor(int64)"},
    float: {"tensor(float16)", "tensor(float)", "tensor(double)"},
    bool: {"tensor(int32)", "tensor(int64)", "tensor(bool)"},
    complex: {"tensor(float)", "tensor(double)"},
    torch.complex32: {"tensor(float16)"},
    torch.complex64: {"tensor(float)"},
    torch.complex128: {"tensor(double)"},
}

# Every compatible tensor type string above, wrapped as "optional(...)".
_OPTIONAL_ONNX_DTYPE_STR: set[str] = {
    f"optional({value})"
    for value_set in _TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS.values()
    for value in value_set
}

_PYTHON_TYPE_TO_TORCH_DTYPE = {
    bool: torch.bool,
    int: torch.int64,
    float: torch.float32,
    complex: torch.complex64,
}

# Python scalar type -> ONNX TensorProto element type; hoisted to module level
# so the table is built once instead of on every lookup.
_PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = {
    float: onnx.TensorProto.FLOAT,  # type: ignore[attr-defined]
    int: onnx.TensorProto.INT64,  # type: ignore[attr-defined]
    bool: onnx.TensorProto.BOOL,  # type: ignore[attr-defined]
}

_COMPLEX_TO_FLOAT: dict[torch.dtype, torch.dtype] = {
    torch.complex32: torch.float16,
    torch.complex64: torch.float32,
    torch.complex128: torch.float64,  # NOTE: ORT doesn't support torch.float64
}

_SYM_TYPE_TO_TORCH_DTYPE = {
    torch.SymInt: torch.int64,
    torch.SymFloat: torch.float32,
    torch.SymBool: torch.bool,
}

_SCALAR_TYPE_TO_TORCH_DTYPE: dict[type, torch.dtype] = {
    **_PYTHON_TYPE_TO_TORCH_DTYPE,
    **_SYM_TYPE_TO_TORCH_DTYPE,  # type: ignore[dict-item]
}

_TORCH_DTYPE_TO_ABBREVIATION = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.float8_e4m3fn: "e4m3fn",
    torch.float8_e4m3fnuz: "e4m3fnuz",
    # NOTE(review): "f8e5m2" breaks the pattern of the other float8 entries
    # ("e4m3fn", "e5m2fnuz"); kept as-is since generated names may depend on it.
    torch.float8_e5m2: "f8e5m2",
    torch.float8_e5m2fnuz: "e5m2fnuz",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

_TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: numpy.float16,
    torch.float32: numpy.float32,
    torch.float64: numpy.float64,
    torch.uint8: numpy.uint8,
    torch.int8: numpy.int8,
    torch.int16: numpy.int16,
    torch.int32: numpy.int32,
    torch.int64: numpy.longlong,
    torch.bool: numpy.bool_,
}

_ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE = {
    onnx.TensorProto.FLOAT: torch.float32,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT16: torch.float16,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn,  # type: ignore[attr-defined]
    onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,  # type: ignore[attr-defined]
    onnx.TensorProto.DOUBLE: torch.float64,  # type: ignore[attr-defined]
    onnx.TensorProto.BOOL: torch.bool,  # type: ignore[attr-defined]
    onnx.TensorProto.UINT8: torch.uint8,  # type: ignore[attr-defined]
    onnx.TensorProto.INT8: torch.int8,  # type: ignore[attr-defined]
    onnx.TensorProto.INT16: torch.int16,  # type: ignore[attr-defined]
    onnx.TensorProto.INT32: torch.int32,  # type: ignore[attr-defined]
    onnx.TensorProto.INT64: torch.int64,  # type: ignore[attr-defined]
}

# Inverse of the table above (the mapping is one-to-one, so no keys collide).
_TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = {
    value: key for key, value in _ONNX_TENSOR_ELEMENT_TYPE_TO_TORCH_DTYPE.items()
}

SYM_VALUE_TYPE = Union[torch.SymInt, torch.SymFloat, torch.SymBool]
META_VALUE_TYPE = Union[fake_tensor.FakeTensor, SYM_VALUE_TYPE, int, float, bool]
# NOTE: The definitions below are copied from torch/fx/node.py
BaseArgumentTypes = Union[
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.device,
    torch.memory_format,
    torch.layout,
    torch._ops.OpOverload,
    torch.SymInt,
    torch.SymFloat,
    torch.SymBool,
]
Argument = Optional[
    Union[
        Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
        List[Any],  # actually Argument
        Dict[str, Any],  # actually Argument
        slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
        range,
        "torch.fx.Node",
        BaseArgumentTypes,
    ]
]