xref: /aosp_15_r20/external/pytorch/c10/util/BFloat16.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
7*da0073e9SAndroid Build Coastguard Worker #include <cmath>
8*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
9*da0073e9SAndroid Build Coastguard Worker #include <cstring>
10*da0073e9SAndroid Build Coastguard Worker #include <iosfwd>
11*da0073e9SAndroid Build Coastguard Worker #include <ostream>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) && !defined(USE_ROCM)
14*da0073e9SAndroid Build Coastguard Worker #include <cuda_bf16.h>
15*da0073e9SAndroid Build Coastguard Worker #endif
16*da0073e9SAndroid Build Coastguard Worker #if defined(__HIPCC__) && defined(USE_ROCM)
17*da0073e9SAndroid Build Coastguard Worker #include <hip/hip_bf16.h>
18*da0073e9SAndroid Build Coastguard Worker #endif
19*da0073e9SAndroid Build Coastguard Worker 
20*da0073e9SAndroid Build Coastguard Worker #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
21*da0073e9SAndroid Build Coastguard Worker #if defined(CL_SYCL_LANGUAGE_VERSION)
22*da0073e9SAndroid Build Coastguard Worker #include <CL/sycl.hpp> // for SYCL 1.2.1
23*da0073e9SAndroid Build Coastguard Worker #else
24*da0073e9SAndroid Build Coastguard Worker #include <sycl/sycl.hpp> // for SYCL 2020
25*da0073e9SAndroid Build Coastguard Worker #endif
26*da0073e9SAndroid Build Coastguard Worker #include <ext/oneapi/bfloat16.hpp>
27*da0073e9SAndroid Build Coastguard Worker #endif
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker namespace c10 {
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker namespace detail {
/// Expands a 16-bit pattern into a float.
///
/// The 16 bits of `src` are placed in the upper half of a 32-bit word whose
/// lower half is zero, and the object representation of that word is returned
/// as a float.
///
/// @param src 16-bit pattern to widen.
/// @return float whose upper 16 bits equal `src` and whose lower 16 bits are 0.
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
  // The bit copy below moves exactly sizeof(uint32_t) bytes into a float.
  static_assert(
      sizeof(float) == sizeof(uint32_t),
      "f32_from_bits requires a 32-bit float");

  float res = 0;
  uint32_t tmp = src;
  tmp <<= 16;

#if defined(USE_ROCM)
  float* tempRes;

  // We should be using memcpy in order to respect the strict aliasing rule
  // but it fails in the HIP environment.
  tempRes = reinterpret_cast<float*>(&tmp);
  res = *tempRes;
#else
  std::memcpy(&res, &tmp, sizeof(tmp));
#endif

  return res;
}
50*da0073e9SAndroid Build Coastguard Worker 
/// Returns the upper 16 bits of a float's 32-bit object representation.
///
/// The lower 16 bits are discarded (truncation); no rounding is applied here.
///
/// @param src float whose bit pattern is read.
/// @return the high half of the 32-bit pattern of `src`.
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
  // The bit copy below moves exactly sizeof(uint32_t) bytes out of a float.
  static_assert(
      sizeof(float) == sizeof(uint32_t),
      "bits_from_f32 requires a 32-bit float");

  uint32_t res = 0;

#if defined(USE_ROCM)
  // We should be using memcpy in order to respect the strict aliasing rule
  // but it fails in the HIP environment.
  uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
  res = *tempRes;
#else
  std::memcpy(&res, &src, sizeof(res));
#endif

  // After the shift only 16 significant bits remain, so the narrowing is
  // lossless; the cast just makes it explicit.
  return static_cast<uint16_t>(res >> 16);
}
65*da0073e9SAndroid Build Coastguard Worker 
/// Converts a float to a 16-bit pattern made of its upper 16 bits, rounding
/// the discarded lower 16 bits to nearest with ties going to the even result.
///
/// NaN inputs always yield the fixed pattern 0x7FC0.
///
/// @param src value to convert.
/// @return rounded upper 16 bits of `src`, or 0x7FC0 when `src` is NaN.
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
  if (src != src) {
#elif defined(_MSC_VER)
  if (isnan(src)) {
#else
  if (std::isnan(src)) {
#endif
    return UINT16_C(0x7FC0);
  } else {
    // Fetch the raw 32-bit pattern with the same idiom as the sibling
    // helpers instead of reading an inactive union member.
    uint32_t bits = 0;
#if defined(USE_ROCM)
    // We should be using memcpy in order to respect the strict aliasing rule
    // but it fails in the HIP environment.
    uint32_t* tempBits = reinterpret_cast<uint32_t*>(&src);
    bits = *tempBits;
#else
    std::memcpy(&bits, &src, sizeof(bits));
#endif

    // Adding 0x7FFF plus the lowest kept bit carries into the kept half when
    // the dropped half is above one half, or exactly one half with an odd
    // kept half; that is round-to-nearest, ties-to-even.
    const uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
    return static_cast<uint16_t>((bits + rounding_bias) >> 16);
  }
}
87*da0073e9SAndroid Build Coastguard Worker } // namespace detail
88*da0073e9SAndroid Build Coastguard Worker 
/// 16-bit value type holding a raw bit pattern in `x`, with conversions to
/// and from float and, when the matching toolchain is active, to and from the
/// vendor bfloat16 types.
///
/// Only the bit-pattern constructor is defined in this struct; the other
/// conversions are declared here and defined out of line (presumably in
/// BFloat16-inl.h, which this header includes at its end).
struct alignas(2) BFloat16 {
  // Raw 16-bit storage.
  uint16_t x;

  // HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM)
  C10_HOST_DEVICE BFloat16() = default;
#else
  BFloat16() = default;
#endif

  // Tag type selecting the constructor that takes a raw bit pattern rather
  // than a numeric value.
  struct from_bits_t {};
  static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
    return from_bits_t();
  }

  // Stores `bits` verbatim in `x`; no numeric conversion happens.
  constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
      : x(bits) {}
  // Numeric conversions from and to float (defined out of line).
  /* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
  inline C10_HOST_DEVICE operator float() const;

  // Interop with the CUDA bfloat16 type; the conversion out is explicit.
#if defined(__CUDACC__) && !defined(USE_ROCM)
  inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
  explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
  // Interop with the HIP bfloat16 type; the conversion out is explicit.
#if defined(__HIPCC__) && defined(USE_ROCM)
  inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value);
  explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const;
#endif

  // Interop with the SYCL oneAPI bfloat16 type; the conversion out is
  // explicit.
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
  explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
123*da0073e9SAndroid Build Coastguard Worker 
124*da0073e9SAndroid Build Coastguard Worker C10_API inline std::ostream& operator<<(
125*da0073e9SAndroid Build Coastguard Worker     std::ostream& out,
126*da0073e9SAndroid Build Coastguard Worker     const BFloat16& value) {
127*da0073e9SAndroid Build Coastguard Worker   out << (float)value;
128*da0073e9SAndroid Build Coastguard Worker   return out;
129*da0073e9SAndroid Build Coastguard Worker }
130*da0073e9SAndroid Build Coastguard Worker 
131*da0073e9SAndroid Build Coastguard Worker } // namespace c10
132*da0073e9SAndroid Build Coastguard Worker 
133*da0073e9SAndroid Build Coastguard Worker #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
134