1 /***************************************************************************************************
2  * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3  * SPDX-License-Identifier: BSD-3-Clause
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice, this
9  * list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *
30  **************************************************************************************************/
31 /*!
32     \file
33     \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
34 */
35 
36 #pragma once
37 
38 #include <cutlass/arch/arch.h>
39 #include <cutlass/array.h>
40 #include <cutlass/half.h>
41 #include <cutlass/numeric_types.h>
42 
43 namespace cutlass {
44 
45 // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
46 // bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
47 // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
48 // This converter will uninterleave the data and subtract the bias while converting to the result type.
// Primary template is intentionally empty: only the (result type, source type, N)
// combinations specialized below are supported. Instantiating an unsupported
// combination yields a struct with no members, so any use of convert()/operator()
// fails at compile time.
template<typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter {
};
52 
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    // Uninterleaves 4 biased uint8 elements packed in one 32-bit register and converts
    // them to half_t, subtracting the 2**7 = 128 bias to recover the signed value.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        // View the output as two packed half2 registers and the source as a single u32.
        uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // prmt selectors pair each source byte with a 0x64 high byte so every fp16 lane
        // becomes 0x64XX, which is the fp16 value 1024 + XX (the byte lands entirely in
        // the mantissa). mask_for_elt_01 gathers source bytes 0 and 2 into h[0];
        // mask_for_elt_23 gathers bytes 1 and 3 into h[1], undoing the interleaved layout.
        static constexpr uint32_t mask_for_elt_01     = 0x5250;
        static constexpr uint32_t mask_for_elt_23     = 0x5351;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

        // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
        // 0x6480 is the fp16 encoding of 1152 = 1024 (construction offset) + 128 (bias).
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

        return result;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
86 
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint8_t, N>;

    // Converts N biased uint8 elements to half_t by delegating to the 4-wide
    // specialization above, one VEC_WIDTH-sized chunk at a time.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type converted;

        using VecConverter = FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, VEC_WIDTH>;
        using DstChunk     = Array<half_t, VEC_WIDTH>;
        using SrcChunk     = Array<uint8_t, VEC_WIDTH>;

        // Reinterpret the flat arrays as sequences of VEC_WIDTH-element chunks.
        DstChunk*       dst = reinterpret_cast<DstChunk*>(&converted);
        SrcChunk const* src = reinterpret_cast<SrcChunk const*>(&source);

        constexpr int kNumChunks = N / VEC_WIDTH;
        CUTLASS_PRAGMA_UNROLL
        for (int chunk = 0; chunk < kNumChunks; ++chunk) {
            dst[chunk] = VecConverter::convert(src[chunk]);
        }

        return converted;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
124 
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
    using result_type = Array<bfloat16_t, 4>;
    using source_type = Array<uint8_t, 4>;

    // Uninterleaves 4 biased uint8 elements and converts them to bfloat16_t,
    // subtracting the 2**7 = 128 bias. Requires SM80+; older architectures trap.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t*      bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s             = reinterpret_cast<uint32_t const&>(source);

        // 0x4B000000 is fp32 2**23 = 8388608; placing a byte XX in the low byte yields
        // the fp32 value 8388608 + XX (the byte lands entirely in the mantissa).
        static constexpr uint32_t fp32_base = 0x4B000000;
        float                     fp32_intermediates[4];

        // Construct FP32s, bfloat does not have enough mantissa for IADD trick.
        // The byte selectors (0, 2, 1, 3) also undo the interleaved register layout,
        // so the intermediates end up in logical element order.
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0]        = __byte_perm(i8s, fp32_base, 0x7650);
        fp32_intermediates_casted[1]        = __byte_perm(i8s, fp32_base, 0x7652);
        fp32_intermediates_casted[2]        = __byte_perm(i8s, fp32_base, 0x7651);
        fp32_intermediates_casted[3]        = __byte_perm(i8s, fp32_base, 0x7653);

        // Subtract out fp32_base + 128 (= 8388736) to make the unsigned integer signed.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 4; ++ii) {
            fp32_intermediates[ii] -= 8388736.f;
        }

        // Truncate the fp32 representation and pack up as bfloat16s: a bf16 is the
        // upper 16 bits of the corresponding fp32, so selector 0x7632 keeps the high
        // halves of two adjacent intermediates.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 2; ++ii) {
            bf16_result_ptr[ii] =
                __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
        result.clear();  // Suppress compiler warning
        arch::device_breakpoint();
#endif
        return result;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
176 
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint8_t, N>;

    // Converts N biased uint8 elements to bfloat16_t by delegating to the 4-wide
    // specialization above, one VEC_WIDTH-sized chunk at a time.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type converted;

        using VecConverter = FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, VEC_WIDTH>;
        using DstChunk     = Array<bfloat16_t, VEC_WIDTH>;
        using SrcChunk     = Array<uint8_t, VEC_WIDTH>;

        // Reinterpret the flat arrays as sequences of VEC_WIDTH-element chunks.
        DstChunk*       dst = reinterpret_cast<DstChunk*>(&converted);
        SrcChunk const* src = reinterpret_cast<SrcChunk const*>(&source);

        constexpr int kNumChunks = N / VEC_WIDTH;
        CUTLASS_PRAGMA_UNROLL
        for (int chunk = 0; chunk < kNumChunks; ++chunk) {
            dst[chunk] = VecConverter::convert(src[chunk]);
        }

        return converted;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
214 
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
    using result_type = Array<half_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    // Uninterleaves 8 biased uint4 elements packed in one 32-bit register and converts
    // them to half_t, subtracting the 2**3 = 8 bias to recover the signed value.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        // View the output as four packed half2 registers and the source as a single u32.
        uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        // immLut = 0xea encodes the boolean function (a & b) | c for lop3, i.e.
        // (input & mask) | magic in a single instruction.
        static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
        static constexpr uint32_t TOP_MASK              = 0x00f000f0;
        // 0x6400 is fp16 1024; OR-ing a nibble into the low mantissa bits yields 1024 + n.
        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
        // elt_67 to fp16 without having to shift them to the bottom bits before hand.

        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
        // immediately before required.
        const uint32_t top_i4s = i4s >> 8;
        // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[1])
                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[2])
                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[3])
                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

        // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
        // half2 ctor. In this case, I chose performance reliability over code readability.

        // This is the half2 {1032, 1032} represented as an integer.
        // Bottom-nibble lanes hold 1024 + n, so (1024 + n) - 1032 = n - 8.
        static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
        // This is the half2 {1 / 16, 1 / 16} represented as an integer.
        // Top-nibble lanes hold 1024 + 16*n, so scaling by 1/16 avoids a shift.
        static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
        // This is the half2 {-72, -72} represented as an integer.
        // (1024 + 16*n) / 16 - 72 = 64 + n - 72 = n - 8.
        static constexpr uint32_t NEG_72 = 0xd480d480;

        // Finally, we construct the output numbers.
        // Convert elt_01
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_23
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
        // Convert elt_45
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_67
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));

        return result;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
288 
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint4b_t, N>;

    // Converts N biased uint4 elements to half_t by delegating to the 8-wide
    // specialization above, one VEC_WIDTH-sized chunk at a time.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type converted;

        using VecConverter = FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, VEC_WIDTH>;
        using DstChunk     = Array<half_t, VEC_WIDTH>;
        using SrcChunk     = Array<uint4b_t, VEC_WIDTH>;

        // Reinterpret the flat arrays as sequences of VEC_WIDTH-element chunks.
        DstChunk*       dst = reinterpret_cast<DstChunk*>(&converted);
        SrcChunk const* src = reinterpret_cast<SrcChunk const*>(&source);

        constexpr int kNumChunks = N / VEC_WIDTH;
        CUTLASS_PRAGMA_UNROLL
        for (int chunk = 0; chunk < kNumChunks; ++chunk) {
            dst[chunk] = VecConverter::convert(src[chunk]);
        }

        return converted;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
326 
template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
    using result_type = Array<bfloat16_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    // Uninterleaves 8 biased uint4 elements and converts them to bfloat16_t,
    // subtracting the 2**3 = 8 bias. Requires SM80+; older architectures trap.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t*      h          = reinterpret_cast<uint32_t*>(&result);
        uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate bf16 number.
        // immLut = 0xea encodes the boolean function (a & b) | c for lop3, i.e.
        // (input & mask) | magic in a single instruction.
        static constexpr uint32_t immLut                 = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t MASK                   = 0x000f000f;
        // 0x4300 is bf16 128; OR-ing a nibble into the low mantissa bits yields 128 + n.
        static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

        // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
        // No shift needed for first item.
        uint32_t i4s = source_i4s;
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
            // Shift the next pair of nibbles (4 bits) down into the masked positions.
            i4s >>= sizeof_bits<typename source_type::Element>::value;
            // (i4s & 0x000f000f) | 0x43004300
            asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                         : "=r"(h[ii])
                         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        }

        // This is the BF16 {-136, -136} represented as an integer.
        // Each lane holds 128 + n, so n*1 + ... wait: (128 + n)*1 + (-136) = n - 8.
        static constexpr uint32_t BF16_BIAS = 0xC308C308;
        static constexpr uint32_t BF16_ONE  = 0x3F803F80;

        // Finally, we construct the output numbers.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
            // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
            asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
        arch::device_breakpoint();
        result.clear();  // Suppress compiler warning.
#endif
        return result;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
386 
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint4b_t, N>;

    // Converts N biased uint4 elements to bfloat16_t by delegating to the 8-wide
    // specialization above, one VEC_WIDTH-sized chunk at a time.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type converted;

        using VecConverter = FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, VEC_WIDTH>;
        using DstChunk     = Array<bfloat16_t, VEC_WIDTH>;
        using SrcChunk     = Array<uint4b_t, VEC_WIDTH>;

        // Reinterpret the flat arrays as sequences of VEC_WIDTH-element chunks.
        DstChunk*       dst = reinterpret_cast<DstChunk*>(&converted);
        SrcChunk const* src = reinterpret_cast<SrcChunk const*>(&source);

        constexpr int kNumChunks = N / VEC_WIDTH;
        CUTLASS_PRAGMA_UNROLL
        for (int chunk = 0; chunk < kNumChunks; ++chunk) {
            dst[chunk] = VecConverter::convert(src[chunk]);
        }

        return converted;
    }

    // Functor-style entry point; forwards to the static convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
424 
425 /////////////////////////////////////////////////////////////////////////////////////////////////
426 
427 }  // namespace cutlass
428 
429 /////////////////////////////////////////////////////////////////////////////////////////////////
430