"""
Where should I add a new type? `types_base.py` vs `types.py`

`types_base.py` defines the data model classes of the torchgen typing system, as well
as some basic types such as int32_t.

`types.py` defines the ATen Tensor type and some c10 types, along with the signatures
that use these types.

The difference between the two files is that `types_base.py` should be
implementation-agnostic: it must not contain any type definition that is tied to a
specific C++ library (e.g., ATen), so that it can easily be reused if we want to
generate code for another C++ library.

Add new types to `types.py` if they are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""

from __future__ import annotations

from dataclasses import dataclass

from torchgen.api.types.types_base import (
    BaseCppType,
    BaseCType,
    boolT,
    byteT,
    charT,
    CType,
    doubleT,
    floatT,
    int32T,
    longT,
    shortT,
)
from torchgen.model import BaseTy, ScalarType


# Rendered C++ type strings that this module groups together as "tensor-list-like".
TENSOR_LIST_LIKE_CTYPES = [
    "at::TensorList",
    "const c10::List<::std::optional<at::Tensor>> &",
    "const at::ITensorListRef &",
]


# ATen/c10 BaseCppType constants, each built as BaseCppType(<namespace>, <name>).

# Element (dtype) types.
halfT = BaseCppType("at", "Half")
bfloat16T = BaseCppType("at", "BFloat16")
# The complex types stuff their template argument into the name, which is an
# abuse of the (namespace, name) split.
complexHalfT = BaseCppType("c10", "complex<c10::Half>")
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
float8_e5m2T = BaseCppType("at", "Float8_e5m2")
float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")

# Tensor and tensor-list types.
tensorT = BaseCppType("at", "Tensor")
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
tensorListT = BaseCppType("at", "TensorList")
iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")

# Scalar and symbolic-integer types.
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
SymIntT = BaseCppType("c10", "SymInt")

# Size / array-reference types.
intArrayRefT = BaseCppType("at", "IntArrayRef")
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
dimVectorT = BaseCppType("at", "DimVector")

# Named-dimension types.
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")

# dtype / layout / device / memory-format and related option types.
scalarTypeT = BaseCppType("at", "ScalarType")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
deviceIndexT = BaseCppType("at", "DeviceIndex")
memoryFormatT = BaseCppType("at", "MemoryFormat")
qschemeT = BaseCppType("at", "QScheme")
tensorOptionsT = BaseCppType("at", "TensorOptions")
tensorGeometryT = BaseCppType("at", "TensorGeometry")

# Other ATen/c10 runtime types.
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
storageT = BaseCppType("at", "Storage")
streamT = BaseCppType("at", "Stream")

# Helper type living in the generated-autograd namespace rather than in ATen.
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")

# Types representing template parameters. Technically, we probably shouldn't
# represent them this way in codegen, but it was pretty convenient.
# Template parameters are not namespace-qualified, hence the empty namespace.
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")

# C++ element type used for each ScalarType (dtype).
ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
    ScalarType.Byte: byteT,
    ScalarType.Char: charT,
    ScalarType.Short: shortT,
    ScalarType.Int: int32T,
    ScalarType.Long: longT,
    ScalarType.Half: halfT,
    ScalarType.Float: floatT,
    ScalarType.Double: doubleT,
    ScalarType.ComplexHalf: complexHalfT,
    ScalarType.ComplexFloat: complexFloatT,
    ScalarType.ComplexDouble: complexDoubleT,
    ScalarType.Bool: boolT,
    ScalarType.BFloat16: bfloat16T,
    ScalarType.Float8_e5m2: float8_e5m2T,
    ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
    ScalarType.Float8_e4m3fn: float8_e4m3fnT,
    ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
}

# C++ type used for each torchgen BaseTy.
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
    BaseTy.int: longT,
    BaseTy.float: doubleT,
    BaseTy.bool: boolT,
    BaseTy.str: stringT,
    BaseTy.Generator: generatorT,
    BaseTy.ScalarType: scalarTypeT,
    BaseTy.Tensor: tensorT,
    BaseTy.Dimname: dimnameT,
    BaseTy.DimVector: dimVectorT,
    BaseTy.Layout: layoutT,
    BaseTy.Device: deviceT,
    BaseTy.DeviceIndex: deviceIndexT,
    BaseTy.Scalar: scalarT,
    BaseTy.MemoryFormat: memoryFormatT,
    BaseTy.QScheme: qschemeT,
    BaseTy.Storage: storageT,
    BaseTy.Stream: streamT,
    BaseTy.SymInt: SymIntT,
}

# CTypes encode C++ type structure as needed for translation.


@dataclass(frozen=True)
class OptionalCType(CType):
    """CType rendered as ``::std::optional<elem>``."""

    # The wrapped element type.
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``::std::optional<...>`` around the element's C++ spelling."""
        # Do not pass `strip_ref` recursively: it must not change the element's
        # spelling, and the rendered optional is never itself a reference, so the
        # flag has no effect on this wrapper.
        return f"::std::optional<{self.elem.cpp_type()}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return the spelling used for registration declarations."""
        return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"

    def remove_const_ref(self) -> CType:
        """Return a new OptionalCType whose element has had const/ref removed."""
        return OptionalCType(self.elem.remove_const_ref())


@dataclass(frozen=True)
class ListCType(CType):
    """CType rendered as ``c10::List<elem>``."""

    # The list's element type.
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``c10::List<...>`` around the element's C++ spelling."""
        # Do not pass `strip_ref` recursively (see OptionalCType.cpp_type).
        return f"c10::List<{self.elem.cpp_type()}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return the spelling used for registration declarations."""
        return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"

    def remove_const_ref(self) -> CType:
        """Return a new ListCType whose element has had const/ref removed."""
        return ListCType(self.elem.remove_const_ref())


@dataclass(frozen=True)
class ArrayRefCType(CType):
    """CType rendered as ``at::ArrayRef<elem>``."""

    # The array-ref's element type.
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``at::ArrayRef<...>`` around the element's C++ spelling."""
        # Do not pass `strip_ref` recursively (see OptionalCType.cpp_type).
        return f"at::ArrayRef<{self.elem.cpp_type()}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return the spelling used for registration declarations.

        NOTE(review): unlike cpp_type(), this omits the ``at::`` qualifier.
        Presumably that matches what consumers of the registration declarations
        expect; confirm before changing.
        """
        return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"

    def remove_const_ref(self) -> CType:
        """Return a new ArrayRefCType whose element has had const/ref removed."""
        return ArrayRefCType(self.elem.remove_const_ref())


@dataclass(frozen=True)
class VectorizedCType(CType):
    """CType rendered as ``at::vec::Vectorized<elem>``."""

    # This template is explicitly specialized, so the only valid
    # elems are those we have specializations for (e.g., float, double, ...)
    # scalar_t is also a common argument here (when we are codegen in
    # a templated context)
    elem: BaseCType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``at::vec::Vectorized<...>`` around the element's C++ spelling."""
        # As with the other wrappers, `strip_ref` is not forwarded to `elem`.
        return f"at::vec::Vectorized<{self.elem.cpp_type()}>"

    def cpp_type_registration_declarations(self) -> str:
        """Not supported for vectorized types; always raises NotImplementedError."""
        raise NotImplementedError

    def remove_const_ref(self) -> CType:
        """Return ``self`` unchanged.

        NOTE(review): presumably this is because a BaseCType element carries no
        const/ref qualifiers to strip. Confirm against ``types_base.BaseCType``.
        """
        return self