1 #pragma once
2
3 #include <c10/macros/Macros.h>
4 #include <c10/util/Float8_fnuz_cvt.h>
5 #include <cstring>
6 #include <limits>
7
8 C10_CLANG_DIAGNOSTIC_PUSH()
9 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 #endif
12
13 namespace c10 {
14
15 /// Constructors
16
Float8_e5m2fnuz(float value)17 inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
18 : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
19
20 /// Implicit conversions
21
22 inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
23 return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
24 }
25
26 /// Special values helpers
27
isnan()28 inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
29 return x == 0b10000000;
30 }
31
isinf()32 inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
33 return false;
34 }
35
36 /// Arithmetic
37
38 inline C10_HOST_DEVICE Float8_e5m2fnuz
39 operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
40 return static_cast<float>(a) + static_cast<float>(b);
41 }
42
43 inline C10_HOST_DEVICE Float8_e5m2fnuz
44 operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
45 return static_cast<float>(a) - static_cast<float>(b);
46 }
47
48 inline C10_HOST_DEVICE Float8_e5m2fnuz
49 operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
50 return static_cast<float>(a) * static_cast<float>(b);
51 }
52
53 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
54 const Float8_e5m2fnuz& a,
55 const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
56 return static_cast<float>(a) / static_cast<float>(b);
57 }
58
59 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
60 return -static_cast<float>(a);
61 }
62
63 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
64 Float8_e5m2fnuz& a,
65 const Float8_e5m2fnuz& b) {
66 a = a + b;
67 return a;
68 }
69
70 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
71 Float8_e5m2fnuz& a,
72 const Float8_e5m2fnuz& b) {
73 a = a - b;
74 return a;
75 }
76
77 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
78 Float8_e5m2fnuz& a,
79 const Float8_e5m2fnuz& b) {
80 a = a * b;
81 return a;
82 }
83
84 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
85 Float8_e5m2fnuz& a,
86 const Float8_e5m2fnuz& b) {
87 a = a / b;
88 return a;
89 }
90
91 /// Arithmetic with floats
92
93 inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
94 return static_cast<float>(a) + b;
95 }
96 inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
97 return static_cast<float>(a) - b;
98 }
99 inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
100 return static_cast<float>(a) * b;
101 }
102 inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
103 __ubsan_ignore_float_divide_by_zero__ {
104 return static_cast<float>(a) / b;
105 }
106
107 inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
108 return a + static_cast<float>(b);
109 }
110 inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
111 return a - static_cast<float>(b);
112 }
113 inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
114 return a * static_cast<float>(b);
115 }
116 inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
117 __ubsan_ignore_float_divide_by_zero__ {
118 return a / static_cast<float>(b);
119 }
120
121 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
122 return a += static_cast<float>(b);
123 }
124 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
125 return a -= static_cast<float>(b);
126 }
127 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
128 return a *= static_cast<float>(b);
129 }
130 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
131 return a /= static_cast<float>(b);
132 }
133
134 /// Arithmetic with doubles
135
136 inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
137 return static_cast<double>(a) + b;
138 }
139 inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
140 return static_cast<double>(a) - b;
141 }
142 inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
143 return static_cast<double>(a) * b;
144 }
145 inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
146 __ubsan_ignore_float_divide_by_zero__ {
147 return static_cast<double>(a) / b;
148 }
149
150 inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
151 return a + static_cast<double>(b);
152 }
153 inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
154 return a - static_cast<double>(b);
155 }
156 inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
157 return a * static_cast<double>(b);
158 }
159 inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
160 __ubsan_ignore_float_divide_by_zero__ {
161 return a / static_cast<double>(b);
162 }
163
164 /// Arithmetic with ints
165
166 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
167 return a + static_cast<Float8_e5m2fnuz>(b);
168 }
169 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
170 return a - static_cast<Float8_e5m2fnuz>(b);
171 }
172 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
173 return a * static_cast<Float8_e5m2fnuz>(b);
174 }
175 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
176 return a / static_cast<Float8_e5m2fnuz>(b);
177 }
178
179 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
180 return static_cast<Float8_e5m2fnuz>(a) + b;
181 }
182 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
183 return static_cast<Float8_e5m2fnuz>(a) - b;
184 }
185 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
186 return static_cast<Float8_e5m2fnuz>(a) * b;
187 }
188 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
189 return static_cast<Float8_e5m2fnuz>(a) / b;
190 }
191
/// Arithmetic with int64_t
193
194 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
195 return a + static_cast<Float8_e5m2fnuz>(b);
196 }
197 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
198 return a - static_cast<Float8_e5m2fnuz>(b);
199 }
200 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
201 return a * static_cast<Float8_e5m2fnuz>(b);
202 }
203 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
204 return a / static_cast<Float8_e5m2fnuz>(b);
205 }
206
207 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
208 return static_cast<Float8_e5m2fnuz>(a) + b;
209 }
210 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
211 return static_cast<Float8_e5m2fnuz>(a) - b;
212 }
213 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
214 return static_cast<Float8_e5m2fnuz>(a) * b;
215 }
216 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
217 return static_cast<Float8_e5m2fnuz>(a) / b;
218 }
219
220 /// NOTE: we do not define comparisons directly and instead rely on the implicit
221 /// conversion from c10::Float8_e5m2fnuz to float.
222
223 } // namespace c10
224
225 namespace std {
226
227 template <>
228 class numeric_limits<c10::Float8_e5m2fnuz> {
229 public:
230 static constexpr bool is_signed = true;
231 static constexpr bool is_integer = false;
232 static constexpr bool is_specialized = true;
233 static constexpr bool is_exact = false;
234 static constexpr bool has_infinity = false;
235 static constexpr bool has_quiet_NaN = true;
236 static constexpr bool has_signaling_NaN = false;
237 static constexpr auto has_denorm = true;
238 static constexpr auto has_denorm_loss = true;
239 static constexpr auto round_style = numeric_limits<float>::round_style;
240 static constexpr bool is_iec559 = false;
241 static constexpr bool is_bounded = true;
242 static constexpr bool is_modulo = false;
243 static constexpr int digits = 3;
244 static constexpr int digits10 = 0;
245 static constexpr int max_digits10 = 2;
246 static constexpr int radix = 2;
247 static constexpr int min_exponent = -14;
248 static constexpr int min_exponent10 = -4;
249 static constexpr int max_exponent = 16;
250 static constexpr int max_exponent10 = 4;
251 static constexpr auto traps = numeric_limits<float>::traps;
252 static constexpr auto tinyness_before =
253 numeric_limits<float>::tinyness_before;
254
min()255 static constexpr c10::Float8_e5m2fnuz min() {
256 return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits());
257 }
max()258 static constexpr c10::Float8_e5m2fnuz max() {
259 return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits());
260 }
lowest()261 static constexpr c10::Float8_e5m2fnuz lowest() {
262 return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits());
263 }
epsilon()264 static constexpr c10::Float8_e5m2fnuz epsilon() {
265 return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits());
266 }
round_error()267 static constexpr c10::Float8_e5m2fnuz round_error() {
268 return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits());
269 }
infinity()270 static constexpr c10::Float8_e5m2fnuz infinity() {
271 return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
272 }
273 // TODO(future): we are mapping neg_zero to both inf and NaN, this is
274 // surprising and we should figure out what to do about it.
quiet_NaN()275 static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
276 return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
277 }
denorm_min()278 static constexpr c10::Float8_e5m2fnuz denorm_min() {
279 return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
280 }
281 };
282
283 } // namespace std
284
285 C10_CLANG_DIAGNOSTIC_POP()
286