xref: /aosp_15_r20/external/pytorch/c10/util/Float8_e5m2fnuz-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <c10/util/Float8_fnuz_cvt.h>
5 #include <cstring>
6 #include <limits>
7 
8 C10_CLANG_DIAGNOSTIC_PUSH()
9 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 #endif
12 
13 namespace c10 {
14 
15 /// Constructors
16 
Float8_e5m2fnuz(float value)17 inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
18     : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
19 
20 /// Implicit conversions
21 
22 inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
23   return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
24 }
25 
26 /// Special values helpers
27 
isnan()28 inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
29   return x == 0b10000000;
30 }
31 
isinf()32 inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
33   return false;
34 }
35 
36 /// Arithmetic
37 
38 inline C10_HOST_DEVICE Float8_e5m2fnuz
39 operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
40   return static_cast<float>(a) + static_cast<float>(b);
41 }
42 
43 inline C10_HOST_DEVICE Float8_e5m2fnuz
44 operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
45   return static_cast<float>(a) - static_cast<float>(b);
46 }
47 
48 inline C10_HOST_DEVICE Float8_e5m2fnuz
49 operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
50   return static_cast<float>(a) * static_cast<float>(b);
51 }
52 
53 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
54     const Float8_e5m2fnuz& a,
55     const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
56   return static_cast<float>(a) / static_cast<float>(b);
57 }
58 
59 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
60   return -static_cast<float>(a);
61 }
62 
63 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
64     Float8_e5m2fnuz& a,
65     const Float8_e5m2fnuz& b) {
66   a = a + b;
67   return a;
68 }
69 
70 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
71     Float8_e5m2fnuz& a,
72     const Float8_e5m2fnuz& b) {
73   a = a - b;
74   return a;
75 }
76 
77 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
78     Float8_e5m2fnuz& a,
79     const Float8_e5m2fnuz& b) {
80   a = a * b;
81   return a;
82 }
83 
84 inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
85     Float8_e5m2fnuz& a,
86     const Float8_e5m2fnuz& b) {
87   a = a / b;
88   return a;
89 }
90 
91 /// Arithmetic with floats
92 
93 inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
94   return static_cast<float>(a) + b;
95 }
96 inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
97   return static_cast<float>(a) - b;
98 }
99 inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
100   return static_cast<float>(a) * b;
101 }
102 inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
103     __ubsan_ignore_float_divide_by_zero__ {
104   return static_cast<float>(a) / b;
105 }
106 
107 inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
108   return a + static_cast<float>(b);
109 }
110 inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
111   return a - static_cast<float>(b);
112 }
113 inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
114   return a * static_cast<float>(b);
115 }
116 inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
117     __ubsan_ignore_float_divide_by_zero__ {
118   return a / static_cast<float>(b);
119 }
120 
121 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
122   return a += static_cast<float>(b);
123 }
124 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
125   return a -= static_cast<float>(b);
126 }
127 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
128   return a *= static_cast<float>(b);
129 }
130 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
131   return a /= static_cast<float>(b);
132 }
133 
134 /// Arithmetic with doubles
135 
136 inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
137   return static_cast<double>(a) + b;
138 }
139 inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
140   return static_cast<double>(a) - b;
141 }
142 inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
143   return static_cast<double>(a) * b;
144 }
145 inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
146     __ubsan_ignore_float_divide_by_zero__ {
147   return static_cast<double>(a) / b;
148 }
149 
150 inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
151   return a + static_cast<double>(b);
152 }
153 inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
154   return a - static_cast<double>(b);
155 }
156 inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
157   return a * static_cast<double>(b);
158 }
159 inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
160     __ubsan_ignore_float_divide_by_zero__ {
161   return a / static_cast<double>(b);
162 }
163 
164 /// Arithmetic with ints
165 
166 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
167   return a + static_cast<Float8_e5m2fnuz>(b);
168 }
169 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
170   return a - static_cast<Float8_e5m2fnuz>(b);
171 }
172 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
173   return a * static_cast<Float8_e5m2fnuz>(b);
174 }
175 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
176   return a / static_cast<Float8_e5m2fnuz>(b);
177 }
178 
179 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
180   return static_cast<Float8_e5m2fnuz>(a) + b;
181 }
182 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
183   return static_cast<Float8_e5m2fnuz>(a) - b;
184 }
185 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
186   return static_cast<Float8_e5m2fnuz>(a) * b;
187 }
188 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
189   return static_cast<Float8_e5m2fnuz>(a) / b;
190 }
191 
192 //// Arithmetic with int64_t
193 
194 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
195   return a + static_cast<Float8_e5m2fnuz>(b);
196 }
197 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
198   return a - static_cast<Float8_e5m2fnuz>(b);
199 }
200 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
201   return a * static_cast<Float8_e5m2fnuz>(b);
202 }
203 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
204   return a / static_cast<Float8_e5m2fnuz>(b);
205 }
206 
207 inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
208   return static_cast<Float8_e5m2fnuz>(a) + b;
209 }
210 inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
211   return static_cast<Float8_e5m2fnuz>(a) - b;
212 }
213 inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
214   return static_cast<Float8_e5m2fnuz>(a) * b;
215 }
216 inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
217   return static_cast<Float8_e5m2fnuz>(a) / b;
218 }
219 
220 /// NOTE: we do not define comparisons directly and instead rely on the implicit
221 /// conversion from c10::Float8_e5m2fnuz to float.
222 
223 } // namespace c10
224 
225 namespace std {
226 
227 template <>
228 class numeric_limits<c10::Float8_e5m2fnuz> {
229  public:
230   static constexpr bool is_signed = true;
231   static constexpr bool is_integer = false;
232   static constexpr bool is_specialized = true;
233   static constexpr bool is_exact = false;
234   static constexpr bool has_infinity = false;
235   static constexpr bool has_quiet_NaN = true;
236   static constexpr bool has_signaling_NaN = false;
237   static constexpr auto has_denorm = true;
238   static constexpr auto has_denorm_loss = true;
239   static constexpr auto round_style = numeric_limits<float>::round_style;
240   static constexpr bool is_iec559 = false;
241   static constexpr bool is_bounded = true;
242   static constexpr bool is_modulo = false;
243   static constexpr int digits = 3;
244   static constexpr int digits10 = 0;
245   static constexpr int max_digits10 = 2;
246   static constexpr int radix = 2;
247   static constexpr int min_exponent = -14;
248   static constexpr int min_exponent10 = -4;
249   static constexpr int max_exponent = 16;
250   static constexpr int max_exponent10 = 4;
251   static constexpr auto traps = numeric_limits<float>::traps;
252   static constexpr auto tinyness_before =
253       numeric_limits<float>::tinyness_before;
254 
min()255   static constexpr c10::Float8_e5m2fnuz min() {
256     return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits());
257   }
max()258   static constexpr c10::Float8_e5m2fnuz max() {
259     return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits());
260   }
lowest()261   static constexpr c10::Float8_e5m2fnuz lowest() {
262     return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits());
263   }
epsilon()264   static constexpr c10::Float8_e5m2fnuz epsilon() {
265     return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits());
266   }
round_error()267   static constexpr c10::Float8_e5m2fnuz round_error() {
268     return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits());
269   }
infinity()270   static constexpr c10::Float8_e5m2fnuz infinity() {
271     return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
272   }
273   // TODO(future): we are mapping neg_zero to both inf and NaN, this is
274   // surprising and we should figure out what to do about it.
quiet_NaN()275   static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
276     return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
277   }
denorm_min()278   static constexpr c10::Float8_e5m2fnuz denorm_min() {
279     return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
280   }
281 };
282 
283 } // namespace std
284 
285 C10_CLANG_DIAGNOSTIC_POP()
286