/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
    \file
    \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/

#pragma once

#include <cutlass/arch/arch.h>
#include <cutlass/array.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>

namespace cutlass {

// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
//
// The primary template is intentionally empty: only the explicit (T, S, N) specializations below are usable, so an
// unsupported combination fails at compile time.
template<typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter {
};

// Specialization: 4 biased uint8 elements (packed in one 32-bit register) -> 4 half_t.
// Even source elements live in the low bytes and odd elements in the high bytes of the register.
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        // View the 4 output halves as two packed half2 registers, and the 4 input bytes as one u32.
        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // prmt selectors: each output half lane is built as {0x64, source_byte}. The byte pattern
        // 0x64XX is the fp16 number 1024 + XX, so after the permute each lane holds (1024 + u8).
        // mask_for_elt_01 picks source bytes 0 and 2 (the even elements, de-interleaving them);
        // mask_for_elt_23 picks source bytes 1 and 3 (the odd elements).
        static constexpr uint32_t mask_for_elt_01 = 0x5250;
        static constexpr uint32_t mask_for_elt_23 = 0x5351;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

        // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
        // 0x6480 is fp16 1152 = 1024 (the construction offset) + 128 (the uint8 bias), so one sub.f16x2
        // removes both at once, yielding the original signed value in [-128, 127].
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: N biased uint8 elements -> N half_t, by delegating to the 4-wide converter above.
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        // Reinterpret the N-wide arrays as N/4 contiguous 4-wide chunks and convert chunk by chunk.
        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: 4 biased uint8 elements -> 4 bfloat16_t. Requires Ampere (SM80) or newer.
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
    using result_type = Array<bfloat16_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // 0x4B000000 is fp32 2^23; placing a byte x in the low mantissa bits yields fp32 (2^23 + x).
        static constexpr uint32_t fp32_base = 0x4B000000;
        float fp32_intermediates[4];

        // Construct FP32s, bfloat does not have enough mantissa for IADD trick
        // Selectors 0x765R keep the three high bytes of fp32_base and insert source byte R in the low
        // byte. Byte order 0,2,1,3 de-interleaves even/odd elements into output order.
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
        fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
        fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

        // Subtract out fp32_base + 128 to make the unsigned integer signed.
        // 8388736 = 2^23 + 128: removes the construction offset and the uint8 bias in one subtract.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 4; ++ii) {
            fp32_intermediates[ii] -= 8388736.f;
        }

        // Truncate the fp32 representation and pack up as bfloat16s.
        // Selector 0x7632 takes the two high bytes of each fp32 word, i.e. its bf16 truncation,
        // and packs two bf16 values per 32-bit output word.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 2; ++ii) {
            bf16_result_ptr[ii] =
                __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
        result.clear();  // Suppress compiler warning
        arch::device_breakpoint();
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: N biased uint8 elements -> N bfloat16_t, by delegating to the 4-wide converter above.
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        // Reinterpret the N-wide arrays as N/4 contiguous 4-wide chunks and convert chunk by chunk.
        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: 8 biased uint4 elements (packed in one 32-bit register) -> 8 half_t.
// Nibble layout in the register (low to high): elt_0, elt_2, elt_4, elt_6, elt_1, elt_3, elt_5, elt_7,
// i.e. even elements in the low half-bytes and odd elements in the high half-bytes of each byte.
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
    using result_type = Array<half_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        // immLut encodes the lop3 boolean function (a & b) | c, where a = input, b = mask, c = magic number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
        static constexpr uint32_t TOP_MASK = 0x00f000f0;
        // 0x6400 is fp16 1024: ORing a low nibble x into it yields fp16 (1024 + x); ORing a
        // top-of-byte nibble yields fp16 (1024 + 16x).
        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
        // elt_67 to fp16 without having to shift them to the bottom bits before hand.

        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
        // immediately before required.
        const uint32_t top_i4s = i4s >> 8;
        // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[1])
                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[2])
                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[3])
                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

        // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
        // half2 ctor. In this case, I chose performance reliability over code readability.

        // This is the half2 {1032, 1032} represented as an integer.
        // 1032 = 1024 (construction offset) + 8 (the uint4 bias), so one subtract unbiases the bottom nibbles.
        static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
        // This is the half2 {1 / 16, 1 / 16} represented as an integer.
        static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
        // This is the half2 {-72, -72} represented as an integer.
        // For the top nibbles: (1024 + 16x) * (1/16) - 72 = 64 + x - 72 = x - 8, removing offset and bias in one fma.
        static constexpr uint32_t NEG_72 = 0xd480d480;

        // Finally, we construct the output numbers.
        // Convert elt_01
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_23
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
        // Convert elt_45
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_67
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: N biased uint4 elements -> N half_t, by delegating to the 8-wide converter above.
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        // Reinterpret the N-wide arrays as N/8 contiguous 8-wide chunks and convert chunk by chunk.
        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: 8 biased uint4 elements -> 8 bfloat16_t. Requires Ampere (SM80) or newer.
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
    using result_type = Array<bfloat16_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        // immLut encodes the lop3 boolean function (a & b) | c, where a = input, b = mask, c = magic number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t MASK = 0x000f000f;
        // 0x4300 is bf16 128: ORing a low nibble x into its mantissa yields bf16 (128 + x).
        static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

        // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
        // No shift needed for first item.
        uint32_t i4s = source_i4s;
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
            // Shift the next pair of nibbles down into the bottom of each 16-bit lane.
            i4s >>= sizeof_bits<typename source_type::Element>::value;
            // (i4s & 0x000f000f) | 0x43004300
            asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                         : "=r"(h[ii])
                         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        }

        // This is the BF16 {-136, -136} represented as an integer.
        // 136 = 128 (construction offset) + 8 (the uint4 bias), so (128 + x) * 1 - 136 = x - 8.
        static constexpr uint32_t BF16_BIAS = 0xC308C308;
        static constexpr uint32_t BF16_ONE = 0x3F803F80;

        // Finally, we construct the output numbers.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
            // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
            asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
        arch::device_breakpoint();
        result.clear();  // Suppress compiler warning.
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

// Specialization: N biased uint4 elements -> N bfloat16_t, by delegating to the 8-wide converter above.
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        // Reinterpret the N-wide arrays as N/8 contiguous 8-wide chunks and convert chunk by chunk.
        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

///////////////////////////////////////////////////////////////////////////////////////////////// 426 427 } // namespace cutlass 428 429 ///////////////////////////////////////////////////////////////////////////////////////////////// 430