/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Defines dequantizers used by warp-level matrix multiply operations targeting Tensor Cores.
*/

#pragma once

#include <cutlass/cutlass.h>

#include <cutlass/array.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>

#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory_sm75.h>
#include <cutlass/gemm/gemm.h>

#include <cutlass/layout/matrix.h>
#include <cutlass/layout/pitch_linear.h>
#include <cutlass/layout/tensor.h>

#include <cutlass/functional.h>
#include <cutlass/platform/platform.h>

//#include <src/fastertransformer/utils/cuda_bf16_wrapper.h>
//#ifdef ENABLE_BF16
#include <cuda_bf16.h>
//#endif
////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

template<
    /// Matrix multiply operator
    typename MmaOperator_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand,
    /// Data type of Scale elements
    typename Element_,
    /// Layout of operand
    typename Layout_,
    /// Number of threads participating in one matrix operation
    int Threads,
    /// SFINAE hook used to select partial specializations
    typename Enable = void>
class MmaTensorOpDequantizer;
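
// A MmaTensorOpDequantizer loads per-column scales for operand B from shared memory and
// multiplies them into the warp's B fragment just before the tensor core mma. A typical mainloop
// usage looks like the sketch below (names such as `dequantizer`, `warp_mma`, `frag_A`, `frag_B`
// and `accum` are illustrative, not part of this header):
//
//   FragmentScale scale_frag;
//   dequantizer.load(scale_frag);
//   dequantizer.dequantize(frag_B, scale_frag);
//   warp_mma(accum, frag_A, frag_B, accum);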

////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    bfloat16_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 80
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
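    // For example (illustrative numbers, not guaranteed by this header): if the load iterator for B
    // covers 32 rows of K per load instruction while the bf16 mma consumes kK = 16, the factor is 2
    // and each dequantized fragment spans two compute instructions along K.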
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = bfloat16_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need one scale element per mma iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

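    // Each group of four consecutive lanes (a quad, lane_idx / 4) owns one column of the 8-wide
    // mma tile for operand B, so a thread only needs the scale of that column; warp_idx_n * Shape::kN
    // selects this warp's slice of the scales in shared memory.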
    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset   = warp_idx_n * Shape::kN;
        const int quad          = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_                = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
//#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        using _MmaOperandB        = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "Expanded B fragments must exactly cover the dequantized operand fragment");

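        // Reinterpret the B fragment as packed __nv_bfloat162 values, broadcast each per-column
        // scale into a bf16x2 pair, and use __hmul2 so two elements are scaled per instruction.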
        const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "Packed bf16x2 math requires an even element count");

            __nv_bfloat162  scalex2            = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
            CUTLASS_PRAGMA_UNROLL
            for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
                operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in the GEMM main loop.
        arch::device_breakpoint();
#endif
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Turing & Ampere
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 75
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset   = warp_idx_n * Shape::kN;
        const int quad          = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_                = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using _MmaOperandB        = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "Expanded B fragments must exactly cover the dequantized operand fragment");

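        // cutlass::multiplies on an Array broadcasts the scalar scale across the fragment; for
        // half_t it maps to packed half2 math where the architecture supports it.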
        multiplies<ExpandedMmaOperandB> mul_op;

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value,
                  "This dequantizer assumes the 32x32x4 interleaved Volta tile");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile  = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale                  = Array<ElementScale, TileNIterations * 8>;
    using AccessType                     = Array<ElementScale, 8>;
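    // AccessType groups eight fp16 scales so each thread can copy its scales for a 32-column tile
    // with a single (potentially vectorized, 128-bit) shared memory access.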

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

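    // lane_idx & 0xF8 clears the low three bits of the lane index, so each group of eight lanes
    // starts from the same 8-column-aligned base within the 32-wide interleaved tile.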
    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset   = warp_idx_n * Shape::kN;
        const int base_col      = lane_idx & 0xF8;
        const int thread_offset = warp_offset + base_col;
        pointer_                = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);

        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements,
                      "Scale fragment must provide one scale per B element");

        multiplies<FragmentDequantizedOperand> mul_op;
        operand_frag = mul_op(operand_frag, scale_frag);
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value,
                  "This dequantizer assumes the 32x32x4 interleaved Volta tile");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B, but for column major B a thread's elements fall
    // into only two columns per tile, so two scales per tile are enough.
    static constexpr int ColsPerMmaTile  = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale                  = Array<ElementScale, TileNIterations * 2>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

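    // (lane_idx & 0xF8) picks the 8-aligned base column of the lane's group and lane_idx % 4 adds
    // the lane's position within its quad, giving the first of the two columns this thread scales.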
    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset   = warp_idx_n * Shape::kN;
        const int base_col      = (lane_idx & 0xF8) + lane_idx % 4;
        const int thread_offset = warp_offset + base_col;
        pointer_                = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            // For col major B, each thread will jump 4 cols to get its next value inside
            // of the super mma.
            CUTLASS_PRAGMA_UNROLL
            for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
                scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
            }
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using MmaOperandB                 = typename ArchMmaOperator::FragmentB;
        static constexpr int total_n_mmas = 2 * TileNIterations;
        static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements,
                      "Per-mma B fragments must exactly cover the dequantized operand fragment");

        multiplies<MmaOperandB> mul_op;

        MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////