xref: /aosp_15_r20/external/pytorch/c10/util/BFloat16-math.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Half.h>
5 
6 C10_CLANG_DIAGNOSTIC_PUSH()
7 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
8 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
9 #endif
10 
11 namespace std {
12 
13 template <typename T>
14 struct is_reduced_floating_point
15     : std::integral_constant<
16           bool,
17           std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>> {};
18 
19 template <typename T>
20 constexpr bool is_reduced_floating_point_v =
21     is_reduced_floating_point<T>::value;
22 
23 template <
24     typename T,
25     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
acos(T a)26 inline T acos(T a) {
27   return std::acos(float(a));
28 }
29 template <
30     typename T,
31     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
asin(T a)32 inline T asin(T a) {
33   return std::asin(float(a));
34 }
35 template <
36     typename T,
37     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
atan(T a)38 inline T atan(T a) {
39   return std::atan(float(a));
40 }
41 template <
42     typename T,
43     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
atanh(T a)44 inline T atanh(T a) {
45   return std::atanh(float(a));
46 }
47 template <
48     typename T,
49     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
erf(T a)50 inline T erf(T a) {
51   return std::erf(float(a));
52 }
53 template <
54     typename T,
55     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
erfc(T a)56 inline T erfc(T a) {
57   return std::erfc(float(a));
58 }
59 template <
60     typename T,
61     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
exp(T a)62 inline T exp(T a) {
63   return std::exp(float(a));
64 }
65 template <
66     typename T,
67     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
expm1(T a)68 inline T expm1(T a) {
69   return std::expm1(float(a));
70 }
71 template <
72     typename T,
73     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
isfinite(T a)74 inline bool isfinite(T a) {
75   return std::isfinite(float(a));
76 }
77 template <
78     typename T,
79     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log(T a)80 inline T log(T a) {
81   return std::log(float(a));
82 }
83 template <
84     typename T,
85     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log10(T a)86 inline T log10(T a) {
87   return std::log10(float(a));
88 }
89 template <
90     typename T,
91     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log1p(T a)92 inline T log1p(T a) {
93   return std::log1p(float(a));
94 }
95 template <
96     typename T,
97     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
log2(T a)98 inline T log2(T a) {
99   return std::log2(float(a));
100 }
101 template <
102     typename T,
103     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
ceil(T a)104 inline T ceil(T a) {
105   return std::ceil(float(a));
106 }
107 template <
108     typename T,
109     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
cos(T a)110 inline T cos(T a) {
111   return std::cos(float(a));
112 }
113 template <
114     typename T,
115     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
floor(T a)116 inline T floor(T a) {
117   return std::floor(float(a));
118 }
119 template <
120     typename T,
121     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
nearbyint(T a)122 inline T nearbyint(T a) {
123   return std::nearbyint(float(a));
124 }
125 template <
126     typename T,
127     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
sin(T a)128 inline T sin(T a) {
129   return std::sin(float(a));
130 }
131 template <
132     typename T,
133     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
tan(T a)134 inline T tan(T a) {
135   return std::tan(float(a));
136 }
137 template <
138     typename T,
139     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
sinh(T a)140 inline T sinh(T a) {
141   return std::sinh(float(a));
142 }
143 template <
144     typename T,
145     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
cosh(T a)146 inline T cosh(T a) {
147   return std::cosh(float(a));
148 }
149 template <
150     typename T,
151     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
tanh(T a)152 inline T tanh(T a) {
153   return std::tanh(float(a));
154 }
155 template <
156     typename T,
157     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
trunc(T a)158 inline T trunc(T a) {
159   return std::trunc(float(a));
160 }
161 template <
162     typename T,
163     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
lgamma(T a)164 inline T lgamma(T a) {
165   return std::lgamma(float(a));
166 }
167 template <
168     typename T,
169     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
sqrt(T a)170 inline T sqrt(T a) {
171   return std::sqrt(float(a));
172 }
173 template <
174     typename T,
175     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
rsqrt(T a)176 inline T rsqrt(T a) {
177   return 1.0 / std::sqrt(float(a));
178 }
179 template <
180     typename T,
181     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
abs(T a)182 inline T abs(T a) {
183   return std::abs(float(a));
184 }
185 #if defined(_MSC_VER) && defined(__CUDACC__)
186 template <
187     typename T,
188     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
pow(T a,double b)189 inline T pow(T a, double b) {
190   return std::pow(float(a), float(b));
191 }
192 #else
193 template <
194     typename T,
195     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
pow(T a,double b)196 inline T pow(T a, double b) {
197   return std::pow(float(a), b);
198 }
199 #endif
200 template <
201     typename T,
202     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
pow(T a,T b)203 inline T pow(T a, T b) {
204   return std::pow(float(a), float(b));
205 }
206 template <
207     typename T,
208     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
fmod(T a,T b)209 inline T fmod(T a, T b) {
210   return std::fmod(float(a), float(b));
211 }
212 
213 /*
  The following function is inspired by the implementation in `musl`
215   Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
216   ----------------------------------------------------------------------
217   Copyright © 2005-2020 Rich Felker, et al.
218 
219   Permission is hereby granted, free of charge, to any person obtaining
220   a copy of this software and associated documentation files (the
221   "Software"), to deal in the Software without restriction, including
222   without limitation the rights to use, copy, modify, merge, publish,
223   distribute, sublicense, and/or sell copies of the Software, and to
224   permit persons to whom the Software is furnished to do so, subject to
225   the following conditions:
226 
227   The above copyright notice and this permission notice shall be
228   included in all copies or substantial portions of the Software.
229 
230   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
231   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
232   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
233   IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
234   CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
235   TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
236   SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
237   ----------------------------------------------------------------------
238  */
239 template <
240     typename T,
241     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
nextafter(T from,T to)242 C10_HOST_DEVICE inline T nextafter(T from, T to) {
243   // Reference:
244   // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
245   using int_repr_t = uint16_t;
246   constexpr uint8_t bits = 16;
247   union {
248     T f;
249     int_repr_t i;
250   } ufrom = {from}, uto = {to};
251 
252   // get a mask to get the sign bit i.e. MSB
253   int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
254 
255   // short-circuit: if either is NaN, return NaN
256   if (from != from || to != to) {
257     return from + to;
258   }
259 
260   // short-circuit: if they are exactly the same.
261   if (ufrom.i == uto.i) {
262     return from;
263   }
264 
265   // mask the sign-bit to zero i.e. positive
266   // equivalent to abs(x)
267   int_repr_t abs_from = ufrom.i & ~sign_mask;
268   int_repr_t abs_to = uto.i & ~sign_mask;
269   if (abs_from == 0) {
270     // if both are zero but with different sign,
271     // preserve the sign of `to`.
272     if (abs_to == 0) {
273       return to;
274     }
275     // smallest subnormal with sign of `to`.
276     ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
277     return ufrom.f;
278   }
279 
280   // if abs(from) > abs(to) or sign(from) != sign(to)
281   if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
282     ufrom.i--;
283   } else {
284     ufrom.i++;
285   }
286 
287   return ufrom.f;
288 }
289 
290 } // namespace std
291 
292 C10_CLANG_DIAGNOSTIC_POP()
293