#pragma once /// Defines the Float8_e4m3fnuz type (8-bit floating-point) including /// conversions to standard C types and basic arithmetic operations. Note that /// arithmetic operations are implemented by converting to floating point and /// performing the operation in float32. /// Binary configuration remains the same as Float8_e4m3fn: /// s eeee mmm /// 1 sign bit /// 4 exponent bits /// 3 mantissa bits /// The key differences versus Float8_e4m3fn are: /// bias = 8 /// no infinities or negative zero /// NaN only when sign bit is 1, rest all 0s /// /// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and /// the existing Float8_e4m3fn implementation. #include #include #include #include #if defined(__cplusplus) #include #elif !defined(__OPENCL_VERSION__) #include #include #endif #include #include namespace c10 { namespace detail { /* * Convert a 32-bit floating-point number in IEEE single-precision format to a * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. */ inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { /* * Binary representation of 256.0f, which is the first value not representable * (i.e. the first value which would overflow in to the sign bit, resulting in * a NaN) in fp8e4m3fnuz range: * 1 0000 000 - fp8e4m3fnuz * 0 10000111 00000000000000000000000 - fp32 */ constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; /* * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range * into denorm representation * magic number: ((127 - 8) + (23 - 3) + 1) */ constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; uint32_t f_bits = fp32_to_bits(f); uint32_t result = 0u; /* * Extract the sign of the input number into the high bit of the 32-bit word: * * +---+----------------------------------+ * | S |0000000 00000000 00000000 00000000| * +---+----------------------------------+ * Bits 31 0-31 */ const uint32_t sign = f_bits & UINT32_C(0x80000000); /* * Set sign bit to 0 */ f_bits ^= sign; if (f_bits >= fnuz_max) { // NaN -- sign bit set to 1, rest 0s. return 0x80; } if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the // number will become subnormal. f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); result = static_cast(f_bits - denorm_mask); if (result == 0) { // fnuz types don't have negative zero. return 0; } } else { // resulting mantissa is odd uint8_t mant_odd = (f_bits >> 20) & 1; // update exponent, rounding bias part 1 f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; // rounding bias part 2 f_bits += mant_odd; // take the bits! result = static_cast(f_bits >> 20); } result |= sign >> 24; return result; } } // namespace detail struct alignas(1) Float8_e4m3fnuz { uint8_t x; struct from_bits_t {}; C10_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } Float8_e4m3fnuz() = default; constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) : x(bits) {} inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; }; C10_API inline std::ostream& operator<<( std::ostream& out, const Float8_e4m3fnuz& value) { out << (float)value; return out; } } // namespace c10 #include // IWYU pragma: keep