1 #pragma once
2
3 #include <cstdint>
4 #include <iosfwd>
5
6 #include <c10/core/ScalarType.h>
7 #include <c10/util/Logging.h>
8 #include <torch/csrc/Export.h>
9
10 #include <torch/csrc/jit/tensorexpr/exceptions.h>
11
12 namespace torch::jit::tensorexpr {
13
// Short alias for the standard fixed-width 32-bit signed integer type.
using int32 = std::int32_t;

// Forward declaration so the stream inserter below can be declared before the
// class definition appears.
class Dtype;
// Stream insertion operator for Dtype. Only declared in this header.
TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype);

// Lets code in this namespace write ScalarType instead of c10::ScalarType.
using ScalarType = c10::ScalarType;
20
// Categories of element types. Every enumerator except kAllTypes occupies its
// own bit, so categories can be combined with bitwise OR (as
// kNonComplexOrQintTypes does). kAllTypes has the value 0.
enum ElementType {
  kAllTypes = 0x00,
  kIntegralTypes = 0x01,
  kFloatingPointTypes = 0x02,
  kBoolType = 0x04,
  kComplexTypes = 0x08,
  kQintTypes = 0x10,
  // Integral, floating-point and bool categories combined.
  kNonComplexOrQintTypes = kIntegralTypes | kFloatingPointTypes | kBoolType,
};
30
31 // Data types for scalar and vector elements.
32 class TORCH_API Dtype {
33 public:
Dtype(int8_t type)34 explicit Dtype(int8_t type)
35 : scalar_type_(static_cast<ScalarType>(type)), lanes_(1) {}
Dtype(ScalarType type)36 explicit Dtype(ScalarType type) : scalar_type_(type), lanes_(1) {}
Dtype(int8_t type,int64_t lanes)37 Dtype(int8_t type, int64_t lanes)
38 : scalar_type_(static_cast<ScalarType>(type)), lanes_(lanes) {}
Dtype(ScalarType type,int64_t lanes)39 Dtype(ScalarType type, int64_t lanes) : scalar_type_(type), lanes_(lanes) {}
Dtype(Dtype type,int64_t lanes)40 Dtype(Dtype type, int64_t lanes)
41 : scalar_type_(type.scalar_type_), lanes_(lanes) {
42 if (type.lanes() != 1) {
43 throw malformed_input("dtype lanes dont match");
44 }
45 }
lanes()46 int64_t lanes() const {
47 return lanes_;
48 }
scalar_type()49 ScalarType scalar_type() const {
50 return scalar_type_;
51 }
52 Dtype scalar_dtype() const;
53 bool operator==(const Dtype& other) const {
54 return scalar_type_ == other.scalar_type_ && lanes_ == other.lanes_;
55 }
56 bool operator!=(const Dtype& other) const {
57 return !(*this == other);
58 }
59 int byte_size() const;
60 std::string ToCppString() const;
61
is_integral()62 bool is_integral() const {
63 return c10::isIntegralType(scalar_type_, true);
64 }
is_floating_point()65 bool is_floating_point() const {
66 return c10::isFloatingType(scalar_type_);
67 }
is_signed()68 bool is_signed() const {
69 return c10::isSignedType(scalar_type_);
70 }
71
cloneWithScalarType(ScalarType nt)72 Dtype cloneWithScalarType(ScalarType nt) const {
73 return Dtype(nt, lanes_);
74 }
75
76 private:
77 friend TORCH_API std::ostream& operator<<(
78 std::ostream& stream,
79 const Dtype& dtype);
80 ScalarType scalar_type_;
81 int64_t lanes_; // the width of the element for a vector time
82 };
83
84 extern TORCH_API Dtype kHandle;
85
86 #define NNC_DTYPE_DECLARATION(ctype, name) extern TORCH_API Dtype k##name;
87
88 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_DTYPE_DECLARATION)
89 NNC_DTYPE_DECLARATION(c10::quint8, QUInt8);
90 NNC_DTYPE_DECLARATION(c10::qint8, QInt8);
91 #undef NNC_DTYPE_DECLARATION
92
93 template <typename T>
94 TORCH_API Dtype ToDtype();
95
96 #define NNC_TODTYPE_DECLARATION(ctype, name) \
97 template <> \
98 inline Dtype ToDtype<ctype>() { \
99 return k##name; \
100 }
101 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_TODTYPE_DECLARATION)
102 NNC_TODTYPE_DECLARATION(c10::quint8, QUInt8);
103 NNC_TODTYPE_DECLARATION(c10::qint8, QInt8);
104 #undef NNC_TODTYPE_DECLARATION
105
106 TORCH_API Dtype ToDtype(ScalarType type);
107
promoteTypes(Dtype a,Dtype b)108 inline Dtype promoteTypes(Dtype a, Dtype b) {
109 if (a.lanes() != b.lanes()) {
110 throw malformed_input("promoting types with different lanes");
111 }
112 return Dtype(
113 static_cast<ScalarType>(c10::promoteTypes(
114 static_cast<c10::ScalarType>(a.scalar_type()),
115 static_cast<c10::ScalarType>(b.scalar_type()))),
116 a.lanes());
117 }
118
119 inline Dtype BinaryOpDtype(
120 Dtype op1_dtype,
121 Dtype op2_dtype,
122 ScalarType ret_type = ScalarType::Undefined) {
123 if (op1_dtype == op2_dtype) {
124 if (ret_type == ScalarType::Undefined) {
125 return op1_dtype;
126 }
127
128 return ToDtype(ret_type);
129 }
130
131 if (op1_dtype.lanes() != op2_dtype.lanes()) {
132 throw malformed_input("lanes dont match");
133 }
134 int64_t lanes = op1_dtype.lanes();
135
136 Dtype resultType = promoteTypes(op1_dtype, op2_dtype);
137 if (resultType.scalar_type() == ScalarType::Undefined) {
138 throw malformed_input("scalar type doesn't match");
139 }
140
141 if (lanes == 1) {
142 // Use the fixed scalar Dtypes.
143 return ToDtype(resultType.scalar_type());
144 }
145
146 return resultType;
147 }
148
149 } // namespace torch::jit::tensorexpr
150
151 namespace std {
152
153 using torch::jit::tensorexpr::Dtype;
154 std::string to_string(const Dtype& dtype);
155 using torch::jit::tensorexpr::ScalarType;
156 std::string to_string(const ScalarType& dtype);
157
158 } // namespace std
159