xref: /aosp_15_r20/external/pytorch/torchgen/api/types/types.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2Where should I add a new type? `types_base.py` vs `types.py`
3
`types_base.py` defines data model classes for the torchgen typing system, as well as some base types such as int32_t.
5
6`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
7
The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
11
12Add new types to `types.py` if these types are ATen/c10 related.
13Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
14"""
15
16from __future__ import annotations
17
18from dataclasses import dataclass
19
20from torchgen.api.types.types_base import (
21    BaseCppType,
22    BaseCType,
23    boolT,
24    byteT,
25    charT,
26    CType,
27    doubleT,
28    floatT,
29    int32T,
30    longT,
31    shortT,
32)
33from torchgen.model import BaseTy, ScalarType
34
35
# C++ spellings (as plain strings) of the ATen argument types that hold a
# list of tensors: a TensorList, a c10::List of optional tensors, and an
# ITensorListRef.
# NOTE(review): presumably compared against rendered cpp_type() strings by
# callers elsewhere; confirm before changing any spelling here.
TENSOR_LIST_LIKE_CTYPES = [
    "at::TensorList",
    "const c10::List<::std::optional<at::Tensor>> &",
    "const at::ITensorListRef &",
]
41
42
# ATen/c10-specific base C++ types. Each BaseCppType (imported from
# types_base) is built from two strings: a qualifier such as "at", "c10" or
# "torch::autograd::generated", followed by a type name.
# NOTE(review): the qualifier is presumably the C++ namespace the name is
# rendered in; confirm against BaseCppType in types_base.py.
halfT = BaseCppType("at", "Half")
complexHalfT = BaseCppType(
    "c10", "complex<c10::Half>"
)  # stuffing template param here is an abuse
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
# Reduced-precision floating-point element types.
bfloat16T = BaseCppType("at", "BFloat16")
float8_e5m2T = BaseCppType("at", "Float8_e5m2")
float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
# Schema `str` arguments are represented as c10::string_view (see
# BaseTypeToCppMapping below), not as an owning string type.
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
# Tensor and tensor-list types.
tensorT = BaseCppType("at", "Tensor")
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
tensorListT = BaseCppType("at", "TensorList")
iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
deviceIndexT = BaseCppType("at", "DeviceIndex")
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
memoryFormatT = BaseCppType("at", "MemoryFormat")
qschemeT = BaseCppType("at", "QScheme")
storageT = BaseCppType("at", "Storage")
streamT = BaseCppType("at", "Stream")
intArrayRefT = BaseCppType("at", "IntArrayRef")
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
tensorOptionsT = BaseCppType("at", "TensorOptions")
# Unlike the types above, this one lives in the autograd codegen namespace.
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
# Symbolic-integer types are qualified with "c10" rather than "at".
SymIntT = BaseCppType("c10", "SymInt")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
82
# Types representing template parameters.  Technically, we probably shouldn't
# represent them this way in codegen, but it was pretty convenient.
# The qualifier is the empty string, presumably so the names render bare
# (e.g. `scalar_t`) inside templated generated code -- confirm against
# BaseCppType in types_base.py.
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")
87
# Maps each torchgen ScalarType to the C++ type used for its values.
# Every ATen-specific element type defined above must appear here. In
# particular, BFloat16 was previously missing even though bfloat16T is
# defined, so a lookup for ScalarType.BFloat16 raised KeyError.
ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
    ScalarType.Byte: byteT,
    ScalarType.Char: charT,
    ScalarType.Short: shortT,
    ScalarType.Int: int32T,
    ScalarType.Long: longT,
    ScalarType.Half: halfT,
    ScalarType.Float: floatT,
    ScalarType.Double: doubleT,
    ScalarType.ComplexHalf: complexHalfT,
    ScalarType.ComplexFloat: complexFloatT,
    ScalarType.ComplexDouble: complexDoubleT,
    ScalarType.Bool: boolT,
    ScalarType.BFloat16: bfloat16T,
    ScalarType.Float8_e5m2: float8_e5m2T,
    ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
    ScalarType.Float8_e4m3fn: float8_e4m3fnT,
    ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
}
106
# Maps each torchgen BaseTy (the base type of a schema argument) to its C++
# counterpart. Note the widths: schema `int` maps to longT (not int32T) and
# schema `float` maps to doubleT (not floatT); `str` maps to a string_view.
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
    BaseTy.int: longT,
    BaseTy.float: doubleT,
    BaseTy.bool: boolT,
    BaseTy.str: stringT,
    BaseTy.Generator: generatorT,
    BaseTy.ScalarType: scalarTypeT,
    BaseTy.Tensor: tensorT,
    BaseTy.Dimname: dimnameT,
    BaseTy.DimVector: dimVectorT,
    BaseTy.Layout: layoutT,
    BaseTy.Device: deviceT,
    BaseTy.DeviceIndex: deviceIndexT,
    BaseTy.Scalar: scalarT,
    BaseTy.MemoryFormat: memoryFormatT,
    BaseTy.QScheme: qschemeT,
    BaseTy.Storage: storageT,
    BaseTy.Stream: streamT,
    BaseTy.SymInt: SymIntT,
}
127
128# CTypes encode C++ type structure as needed for translation.
129
130
@dataclass(frozen=True)
class OptionalCType(CType):
    """A C++ ``::std::optional<T>`` whose ``T`` is another CType."""

    # The wrapped type T.
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``::std::optional<T>``.

        ``strip_ref`` is ignored at this level and deliberately NOT forwarded:
        the inner type is always rendered by calling ``elem.cpp_type()`` with
        no arguments.
        """
        wrapped = self.elem.cpp_type()
        return f"::std::optional<{wrapped}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return ``::std::optional<T>`` using the inner type's registration spelling."""
        wrapped = self.elem.cpp_type_registration_declarations()
        return f"::std::optional<{wrapped}>"

    def remove_const_ref(self) -> CType:
        """Return a new OptionalCType whose inner type has had const/ref removed."""
        stripped = self.elem.remove_const_ref()
        return OptionalCType(stripped)
144
145
@dataclass(frozen=True)
class ListCType(CType):
    """A C++ ``c10::List<T>`` whose element type ``T`` is another CType."""

    # The list's element type T.
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``c10::List<T>``.

        ``strip_ref`` is ignored at this level and deliberately NOT forwarded:
        the element type is always rendered by calling ``elem.cpp_type()``
        with no arguments.
        """
        item = self.elem.cpp_type()
        return f"c10::List<{item}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return ``c10::List<T>`` using the element's registration spelling."""
        item = self.elem.cpp_type_registration_declarations()
        return f"c10::List<{item}>"

    def remove_const_ref(self) -> CType:
        """Return a new ListCType whose element type has had const/ref removed."""
        bare_item = self.elem.remove_const_ref()
        return ListCType(bare_item)
159
160
@dataclass(frozen=True)
class ArrayRefCType(CType):
    """A C++ ``at::ArrayRef<T>`` whose element type ``T`` is another CType."""

    # The element type T viewed by the ArrayRef.
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``at::ArrayRef<T>``.

        ``strip_ref`` is ignored at this level and deliberately NOT forwarded:
        the element type is always rendered by calling ``elem.cpp_type()``
        with no arguments.
        """
        item = self.elem.cpp_type()
        return f"at::ArrayRef<{item}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return ``ArrayRef<T>`` for registration declarations.

        Unlike ``cpp_type``, this spelling has no ``at::`` prefix.
        """
        item = self.elem.cpp_type_registration_declarations()
        return f"ArrayRef<{item}>"

    def remove_const_ref(self) -> CType:
        """Return a new ArrayRefCType whose element type has had const/ref removed."""
        bare_item = self.elem.remove_const_ref()
        return ArrayRefCType(bare_item)
174
175
@dataclass(frozen=True)
class VectorizedCType(CType):
    """A C++ ``at::vec::Vectorized<T>`` over a base scalar type ``T``."""

    # at::vec::Vectorized is explicitly specialized, so `elem` must be one of
    # the types a specialization exists for (e.g. float, double, ...).
    # `scalar_t` is also a common choice here when generating code in a
    # templated context.
    elem: BaseCType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``at::vec::Vectorized<T>``; ``strip_ref`` is ignored."""
        lane_type = self.elem.cpp_type()
        return f"at::vec::Vectorized<{lane_type}>"

    def cpp_type_registration_declarations(self) -> str:
        """Not implemented for vectorized types; always raises NotImplementedError."""
        raise NotImplementedError

    def remove_const_ref(self) -> CType:
        """Return this object unchanged.

        NOTE(review): presumably a bare BaseCType element carries no const/ref
        qualifiers to strip; confirm against BaseCType in types_base.py.
        """
        return self
192