/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines dequantizers used by warp-level matrix multiply operations targeting Tensor Cores.
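           Each specialization loads per-column scales for the B operand from shared memory and
           multiplies them into the warp's B fragment in registers before the mma is issued.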
*/

#pragma once

#include <cutlass/cutlass.h>

#include <cutlass/array.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>

#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory_sm75.h>
#include <cutlass/gemm/gemm.h>

#include <cutlass/layout/matrix.h>
#include <cutlass/layout/pitch_linear.h>
#include <cutlass/layout/tensor.h>

#include <cutlass/functional.h>
#include <cutlass/platform/platform.h>

//#include <src/fastertransformer/utils/cuda_bf16_wrapper.h>
//#ifdef ENABLE_BF16
#include <cuda_bf16.h>
//#endif

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

template<
    /// Matrix multiply operator
    typename MmaOperator_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand,
    /// Data type of Scale elements
    typename Element_,
    /// Layout of operand
    typename Layout_,
    /// Number of threads participating in one matrix operation
    int Threads,
    ///
    typename Enable = void>
class MmaTensorOpDequantizer;

////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    bfloat16_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 80
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
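    // For example (illustrative, not guaranteed by this header): if IteratorB loads a B tile
    // whose K extent is twice ArchMmaOperator::Shape::kK, then kExpansionFactor == 2 and each
    // loaded fragment covers two mma instructions along K.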
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = bfloat16_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int quad = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
//#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "");

        const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");

            __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
            CUTLASS_PRAGMA_UNROLL
            for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
                operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in GEMM main loop.
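        // arch::device_breakpoint() triggers a device-side breakpoint trap, so hitting this
        // unsupported path fails loudly at runtime instead of silently producing unscaled B values.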
        arch::device_breakpoint();
#endif
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Turing & Ampere
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 75
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int quad = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "");

        multiplies<ExpandedMmaOperandB> mul_op;

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] =
                mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 8>;
    using AccessType = Array<ElementScale, 8>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int base_col = lane_idx & 0xF8;
        const int thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);

        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
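            // AccessType groups 8 contiguous fp16 scales, so each tile's scales are read with a
            // single 8-element copy starting at this thread's base column.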
            scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");

        multiplies<FragmentDequantizedOperand> mul_op;
        operand_frag = mul_op(operand_frag, scale_frag);
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 2>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int base_col = (lane_idx & 0xF8) + lane_idx % 4;
        const int thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            // For col major B, each thread will jump 4 cols to get its next value inside
            // of the super mma.
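            // That is two scalar reads per 32-column tile: the thread's base column and the
            // column 4 to its right, one scale per mma within the interleaved tile.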
            CUTLASS_PRAGMA_UNROLL
            for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
                scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
            }
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        static constexpr int total_n_mmas = 2 * TileNIterations;
        static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, "");

        multiplies<MmaOperandB> mul_op;

        MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
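////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative only). `WarpMma`, `smem_scales_ref`, `warp_idx_n`,
// `lane_idx` and `warp_frag_B` are placeholders for the warp-level MmaTensorOp,
// the TensorRef over the per-column scales in shared memory, and the mainloop
// state a surrounding kernel would provide; the sketch only shows the intended
// call order of load() and dequantize() around each warp-level mma.
//
//   using Dequantizer = cutlass::gemm::warp::MmaTensorOpDequantizer<
//       WarpMma, typename WarpMma::Shape, cutlass::gemm::Operand::kB,
//       cutlass::half_t, cutlass::layout::RowMajor, 32>;
//
//   Dequantizer dequantizer(smem_scales_ref, warp_idx_n, lane_idx);
//
//   typename Dequantizer::FragmentScale scale_frag;
//   dequantizer.load(scale_frag);
//
//   // Scale the freshly loaded B fragment in registers before issuing the mma.
//   dequantizer.dequantize(warp_frag_B, scale_frag);
//
////////////////////////////////////////////////////////////////////////////////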