#pragma once

// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.

#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <ostream>

#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
#include <hip/hip_bf16.h>
#endif

#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif

namespace c10 {

31*da0073e9SAndroid Build Coastguard Worker namespace detail { f32_from_bits(uint16_t src)32*da0073e9SAndroid Build Coastguard Workerinline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { 33*da0073e9SAndroid Build Coastguard Worker float res = 0; 34*da0073e9SAndroid Build Coastguard Worker uint32_t tmp = src; 35*da0073e9SAndroid Build Coastguard Worker tmp <<= 16; 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker #if defined(USE_ROCM) 38*da0073e9SAndroid Build Coastguard Worker float* tempRes; 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker // We should be using memcpy in order to respect the strict aliasing rule 41*da0073e9SAndroid Build Coastguard Worker // but it fails in the HIP environment. 42*da0073e9SAndroid Build Coastguard Worker tempRes = reinterpret_cast<float*>(&tmp); 43*da0073e9SAndroid Build Coastguard Worker res = *tempRes; 44*da0073e9SAndroid Build Coastguard Worker #else 45*da0073e9SAndroid Build Coastguard Worker std::memcpy(&res, &tmp, sizeof(tmp)); 46*da0073e9SAndroid Build Coastguard Worker #endif 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker return res; 49*da0073e9SAndroid Build Coastguard Worker } 50*da0073e9SAndroid Build Coastguard Worker bits_from_f32(float src)51*da0073e9SAndroid Build Coastguard Workerinline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { 52*da0073e9SAndroid Build Coastguard Worker uint32_t res = 0; 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker #if defined(USE_ROCM) 55*da0073e9SAndroid Build Coastguard Worker // We should be using memcpy in order to respect the strict aliasing rule 56*da0073e9SAndroid Build Coastguard Worker // but it fails in the HIP environment. 
57*da0073e9SAndroid Build Coastguard Worker uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src); 58*da0073e9SAndroid Build Coastguard Worker res = *tempRes; 59*da0073e9SAndroid Build Coastguard Worker #else 60*da0073e9SAndroid Build Coastguard Worker std::memcpy(&res, &src, sizeof(res)); 61*da0073e9SAndroid Build Coastguard Worker #endif 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker return res >> 16; 64*da0073e9SAndroid Build Coastguard Worker } 65*da0073e9SAndroid Build Coastguard Worker round_to_nearest_even(float src)66*da0073e9SAndroid Build Coastguard Workerinline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { 67*da0073e9SAndroid Build Coastguard Worker #if defined(USE_ROCM) 68*da0073e9SAndroid Build Coastguard Worker if (src != src) { 69*da0073e9SAndroid Build Coastguard Worker #elif defined(_MSC_VER) 70*da0073e9SAndroid Build Coastguard Worker if (isnan(src)) { 71*da0073e9SAndroid Build Coastguard Worker #else 72*da0073e9SAndroid Build Coastguard Worker if (std::isnan(src)) { 73*da0073e9SAndroid Build Coastguard Worker #endif 74*da0073e9SAndroid Build Coastguard Worker return UINT16_C(0x7FC0); 75*da0073e9SAndroid Build Coastguard Worker } else { 76*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 77*da0073e9SAndroid Build Coastguard Worker union { 78*da0073e9SAndroid Build Coastguard Worker uint32_t U32; // NOLINT(facebook-hte-BadMemberName) 79*da0073e9SAndroid Build Coastguard Worker float F32; // NOLINT(facebook-hte-BadMemberName) 80*da0073e9SAndroid Build Coastguard Worker }; 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker F32 = src; 83*da0073e9SAndroid Build Coastguard Worker uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); 84*da0073e9SAndroid Build Coastguard Worker return static_cast<uint16_t>((U32 + rounding_bias) >> 16); 85*da0073e9SAndroid Build Coastguard Worker } 86*da0073e9SAndroid Build 
Coastguard Worker } 87*da0073e9SAndroid Build Coastguard Worker } // namespace detail 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker struct alignas(2) BFloat16 { 90*da0073e9SAndroid Build Coastguard Worker uint16_t x; 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker // HIP wants __host__ __device__ tag, CUDA does not 93*da0073e9SAndroid Build Coastguard Worker #if defined(USE_ROCM) 94*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE BFloat16() = default; 95*da0073e9SAndroid Build Coastguard Worker #else 96*da0073e9SAndroid Build Coastguard Worker BFloat16() = default; 97*da0073e9SAndroid Build Coastguard Worker #endif 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker struct from_bits_t {}; 100*da0073e9SAndroid Build Coastguard Worker static constexpr C10_HOST_DEVICE from_bits_t from_bits() { 101*da0073e9SAndroid Build Coastguard Worker return from_bits_t(); 102*da0073e9SAndroid Build Coastguard Worker } 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) 105*da0073e9SAndroid Build Coastguard Worker : x(bits) {} 106*da0073e9SAndroid Build Coastguard Worker /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); 107*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE operator float() const; 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) && !defined(USE_ROCM) 110*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); 111*da0073e9SAndroid Build Coastguard Worker explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; 112*da0073e9SAndroid Build Coastguard Worker #endif 113*da0073e9SAndroid Build Coastguard Worker #if defined(__HIPCC__) && defined(USE_ROCM) 114*da0073e9SAndroid Build Coastguard Worker 
inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value); 115*da0073e9SAndroid Build Coastguard Worker explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const; 116*da0073e9SAndroid Build Coastguard Worker #endif 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) 119*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); 120*da0073e9SAndroid Build Coastguard Worker explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; 121*da0073e9SAndroid Build Coastguard Worker #endif 122*da0073e9SAndroid Build Coastguard Worker }; 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker C10_API inline std::ostream& operator<<( 125*da0073e9SAndroid Build Coastguard Worker std::ostream& out, 126*da0073e9SAndroid Build Coastguard Worker const BFloat16& value) { 127*da0073e9SAndroid Build Coastguard Worker out << (float)value; 128*da0073e9SAndroid Build Coastguard Worker return out; 129*da0073e9SAndroid Build Coastguard Worker } 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker } // namespace c10 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep 134