/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/default_epilogue_simt.h>
#include <cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
#include <cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h>
#include <cutlass/functional.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cutlass/platform/platform.h>
#include <cutlass/transform/threadblock/vector_iterator.h>

#include <cutlass/epilogue/threadblock/epilogue_smem_accumulator.h>
#include <cutlass/gemm/threadblock/mma_base.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h>
#include <cutlass/gemm/threadblock/mma_pipelined.h>
#include <cutlass/gemm/threadblock/mma_multistage.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/// Shared storage object needed by accumulator
/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h
template <
    typename Shape_,
    typename Element_,
    typename Layout_,
    typename Padding_>
class AccumulatorSharedStorage {
 public:
  //
  // Type definitions
  //
  using Shape = Shape_;
  using Element = Element_;
  using Layout = Layout_;
  using Padding = Padding_;

  /// Tensor reference to the accumulator
  using TensorRefAccum = cutlass::TensorRef<Element, Layout>;

  /// Shape of the accumulator matrix in shared memory
  using ShapeAccum = cutlass::
      MatrixShape<Shape::kM + Padding::kRow, Shape::kN + Padding::kColumn>;

 public:
  //
  // Data members
  //

  /// Buffer for accumulator
  cutlass::AlignedBuffer<Element, ShapeAccum::kCount> accum;

 public:
  //
  // Methods
  //

  /// Returns a layout object for the Accum matrix
  CUTLASS_DEVICE
  static Layout LayoutAccum() {
    return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn});
  }

  /// Returns a TensorRef to the Accumulator
  CUTLASS_HOST_DEVICE
  TensorRefAccum accum_ref() {
    return TensorRefAccum{accum.data(), LayoutAccum()};
  }
};
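
// A minimal usage sketch (illustrative only, not part of this header's API):
// the storage is meant to live in the kernel's shared memory, and its
// accum_ref() is what the B2bGemm helpers below write into. Names such as
// `MyThreadblockShape` and `storage` are hypothetical.
//
//   using Storage = AccumulatorSharedStorage<
//       MyThreadblockShape,           // e.g. cutlass::gemm::GemmShape<64, 64, 32>
//       cutlass::half_t,              // element type kept in shared memory
//       cutlass::layout::RowMajor,    // layout of the accumulator tile
//       cutlass::MatrixShape<0, 0>>;  // padding
//   __shared__ Storage storage;
//   cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> ref =
//       storage.accum_ref();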

////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h
////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    // Maximum K dimension - also the dimension of the shared-memory
    // holding `OperandA`
    int kMaxK_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Layout in shared-memory of operand A
    typename SmemLayoutA,
    /// Used for partial specialization
    typename Enable = bool>
class MmaBaseFromSharedMemory {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  static constexpr int kMaxK = kMaxK_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<
      Shape::kM / WarpGemm::kM,
      Shape::kN / WarpGemm::kN,
      Shape::kK / WarpGemm::kK>;
  using WarpCount1 = WarpCount;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);
  static int const kWarpGemmIterations1 = kWarpGemmIterations;

  /// Number of stages
  static int const kStages = Stages;

  /// If this is true, we fill the entire shmem buffer at start
  /// and don't need to iterate through it in a circular fashion
  static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;

  /// Tensor reference to the A operand
  using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;

  /// Tensor reference to the B operand
  using TensorRefB =
      TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    /// Shape of the B matrix operand in shared memory
    using ShapeB = MatrixShape<
        Shape::kK * kStages + Policy::SmemPaddingB::kRow,
        Shape::kN + Policy::SmemPaddingB::kColumn>;

   public:
    //
    // Data members
    //

    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

   public:
    //
    // Methods
    //

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }
  };

 protected:
  //
  // Data members
  //

  // /// Iterator to load a warp-scoped tile of A operand from shared memory
  // typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaBaseFromSharedMemory(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      TensorRefB& b_tile,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_B_(b_tile, lane_idx) {}
};

namespace {

// has necessary trait compliance with WarpIteratorFromSmem but doesn't do
// anything, can be default initialized, and uses fragment that takes up
// (almost) no space. this warp iterator is selected at compile time when
// elementwise on-the-fly scaling for operand A is disabled, in which case
// operations related to loading scale factors for operand A get wiped out by
// the compiler.
template <typename TensorRef>
class NoOpWarpIteratorScale {
 public:
  // in pipelined+multistage MMA implementations we keep an array of fragments.
  // if we aren't using scaling we don't want to waste registers on fragments
  // of scale elements, so ideally this would be sized 0.
  // Since arrays of zero-sized objects are not allowed, using size as 1.
  // The compiler will most likely wipe it out anyways.
  using Fragment = cutlass::Array<char, 1>;

  CUTLASS_HOST_DEVICE
  NoOpWarpIteratorScale() {}

  CUTLASS_HOST_DEVICE
  NoOpWarpIteratorScale(TensorRef const&, int) {}

  CUTLASS_HOST_DEVICE
  NoOpWarpIteratorScale& add_tile_offset(
      typename TensorRef::TensorCoord const&) {
    return *this;
  }

  CUTLASS_HOST_DEVICE
  NoOpWarpIteratorScale& operator++() {
    return *this;
  }

  CUTLASS_DEVICE
  void load(Fragment&) const {}
};

// if scaling is enabled, performs fragment elementwise multiplication between
// fragment and its scaling factor.
template <typename Fragment, typename FragmentScale, bool ScalingEnabled>
class FragmentElementwiseScaler;

// specialization for scaling being enabled.
template <typename Fragment, typename FragmentScale>
class FragmentElementwiseScaler<Fragment, FragmentScale, true> {
 public:
  // cast scale_frag to correct type then apply elementwise to fragment
  CUTLASS_DEVICE
  static Fragment apply(Fragment frag, FragmentScale const& scale_frag) {
    Fragment converted_scale_frag = cutlass::NumericArrayConverter<
        typename Fragment::Element,
        typename FragmentScale::Element,
        FragmentScale::kElements>()(scale_frag);
    return cutlass::multiplies<Fragment>()(frag, converted_scale_frag);
  }
};

// specialization for scaling being disabled. doesn't do anything and should
// just get wiped out by the compiler.
template <typename Fragment, typename FragmentScale>
class FragmentElementwiseScaler<Fragment, FragmentScale, false> {
 public:
  CUTLASS_DEVICE
  static Fragment apply(Fragment frag, FragmentScale const&) {
    return frag;
  }
};
} // namespace
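
// Illustrative sketch of how the two helpers above are meant to compose
// (hypothetical fragment type names; this mirrors the pattern used by the MMA
// classes below rather than adding new functionality):
//
//   using Scaler = FragmentElementwiseScaler<
//       WarpFragmentA, WarpFragmentAScale, /*ScalingEnabled=*/false>;
//   // apply() is the identity when scaling is disabled, so the multiply
//   // (and the NoOpWarpIteratorScale loads feeding it) compile away:
//   auto scaled_frag_A = Scaler::apply(frag_A, frag_A_scale);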

////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    // BEGIN smem
    /// Iterates over the intermediate accumulator tile in shared memory
    typename WarpIteratorA_,
    /// whether or not to perform elementwise multiplication of A
    // by another matrix (A_scale) that is also kept in shared memory prior
    // to matmul A @ B
    bool ScaleOperandA_,
    /// Max GEMM problem size in K dimension
    int MaxK,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
                                         Shape_,
                                         MaxK,
                                         Policy_,
                                         2,
                                         typename WarpIteratorA_::Layout> {
 public:
  ///< Base class
  using Base = MmaBaseFromSharedMemory<
      Shape_,
      MaxK,
      Policy_,
      2,
      typename WarpIteratorA_::Layout>;

  using Shape =
      Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  static constexpr bool ScaleOperandA = ScaleOperandA_;

  using WarpIteratorA = WarpIteratorA_;
  ///< loads fragments of A_scale from shared memory if operand A scaling is
  ///< enabled. otherwise no-op.
  using WarpIteratorAScale = typename cutlass::platform::conditional<
      ScaleOperandA,
      WarpIteratorA,
      NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;

  using IteratorB =
      IteratorB_; ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_; ///< Data type of accumulator matrix
  using LayoutC = LayoutC_; ///< Layout of accumulator matrix
  using Policy = Policy_; ///< Policy describing tuning details

  using SmemIteratorB = SmemIteratorB_;

  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert(
      (Base::kStages == 2),
      "MmaPipelined requires kStages set to value 2");

 private:
  using WarpFragmentA = typename Operator::FragmentA;

  /// fragment type of OperandA elementwise scaling matrix. (almost) empty
  /// if operand A scaling is disabled.
  using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;

  using WarpFragmentB = typename Operator::FragmentB;

  /// applies scaling factor to operand A fragment if operand A scaling is
  /// enabled. otherwise no-op.
  using FragmentAScaler = FragmentElementwiseScaler<
      WarpFragmentA,
      WarpFragmentAScale,
      ScaleOperandA>;

 protected:
  // /// Iterator to write threadblock-scoped tile of A operand to shared memory
  // SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

  /// Iterator to load a warp-scoped tile of A operand from intermediate
  /// accumulator tile
  WarpIteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of A_scale from intermediate
  /// accumulator tile (only used if ScaleOperandA_ is true)
  WarpIteratorAScale warp_tile_iterator_A_scale_;

 public:
  /// constructor for MMA with operand A scaling enabled.
  CUTLASS_DEVICE
  MmaPipelinedFromSharedMemory(
      typename Base::TensorRefA a, // Operand A in shared memory
      typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
      typename Base::TensorRefB
          b_staging, // staging memory for loading tiles of B
      int thread_idx,
      int warp_idx,
      int lane_idx)
      : Base(b_staging, thread_idx, warp_idx, lane_idx),
        warp_tile_iterator_A_(a, lane_idx),
        warp_tile_iterator_A_scale_(a_scale, lane_idx),
        smem_iterator_B_(b_staging, thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension
    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_A_scale_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaPipelinedFromSharedMemory(
      typename Base::TensorRefA a, ///< Operand A in shared memory
      typename Base::TensorRefB b_staging, ///< staging memory for loading B
      int thread_idx, ///< ID within the threadblock
      int warp_idx, ///< ID of warp
      int lane_idx) ///< ID of each thread within a warp
      : Base(b_staging, thread_idx, warp_idx, lane_idx),
        warp_tile_iterator_A_(a, lane_idx),
        smem_iterator_B_(b_staging, thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  // For API compatibility with MmaMultistageFromSharedMemory
  // but not supported as it worsens perf: older gpus < sm80 don't
  // support async transfers and have to waste registers
  CUTLASS_DEVICE
  void set_prologue_done(bool value) {}

  CUTLASS_DEVICE
  static void prologue(
      typename Base::SharedStorage& shared_storage,
      IteratorB iterator_B1,
      int thread_idx,
      int problem_size_0_n) {}

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      int gemm_k_iterations, ///< number of iterations of the mainloop
      FragmentC& accum, ///< destination accumulator tile
      // IteratorA iterator_A, ///< iterator over A operand in global memory
      IteratorB iterator_B, ///< iterator over B operand in global memory
      FragmentC const& src_accum, ///< source accumulator tile
      // TransformA transform_A = TransformA(), ///< transformation applied to A fragment
      TransformB transform_B =
          TransformB()) { ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentB tb_frag_B;

    tb_frag_B.clear();

    // The last kblock is loaded in the prolog
    iterator_B.set_residual_tile(gemm_k_iterations == 1);
    iterator_B.load(tb_frag_B);

    ++iterator_B;

    this->smem_iterator_B_.store(transform_B(tb_frag_B));

    ++this->smem_iterator_B_;

    __syncthreads();

    // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op
    // if scaling is disabled.

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentAScale warp_frag_A_scale[2];
    WarpFragmentB warp_frag_B[2];
    warp_frag_A[0].clear();
    warp_frag_A_scale[0].clear();
    warp_frag_B[0].clear();

    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_A_scale_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    iterator_B.set_residual_tile(gemm_k_iterations == 2);
    iterator_B.clear_mask(gemm_k_iterations <= 1);

    // Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency
    // requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.
        bool hasNext = true;

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {
          if (gemm_k_iterations > 1) {
            // Write fragments to shared memory
            this->smem_iterator_B_.store(transform_B(tb_frag_B));
          }

          __syncthreads();

          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          // SMEM: Don't reset iterator A, as we are continuing our iteration
          // at this point
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          } else {
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
          hasNext = gemm_k_iterations > 1;
        }

        // Only read the next if we need to
        if (hasNext) {
          this->warp_tile_iterator_B_.set_kgroup_index(
              (warp_mma_k + 1) % Base::kWarpGemmIterations);

          this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_A_scale_.load(
              warp_frag_A_scale[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

          ++this->warp_tile_iterator_A_;
          ++this->warp_tile_iterator_A_scale_;
          ++this->warp_tile_iterator_B_;

          if (warp_mma_k == 0) {
            iterator_B.load(tb_frag_B);

            ++iterator_B;

            // Avoid reading out of bounds if this was the last loop iteration
            iterator_B.set_residual_tile(gemm_k_iterations == 3);
            iterator_B.clear_mask(gemm_k_iterations <= 2);
          }
        }

        warp_mma(
            accum,
            FragmentAScaler::apply(
                warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]),
            warp_frag_B[warp_mma_k % 2],
            accum);
      }
    }
  }
};
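
// A minimal usage sketch for the pipelined variant above (illustrative only;
// `MyMmaPipelinedFromSmem`, `accum_storage`, `b_shared_storage` and
// `iterator_B` are hypothetical names for objects a kernel would already have):
//
//   MyMmaPipelinedFromSmem mma(
//       accum_storage.accum_ref(),         // operand A already in shared memory
//       b_shared_storage.operand_B_ref(),  // staging buffer for B tiles
//       thread_idx, warp_idx, lane_idx);
//   typename MyMmaPipelinedFromSmem::FragmentC accum;
//   accum.clear();
//   int gemm_k_iterations =
//       (problem_k + MyMmaPipelinedFromSmem::Shape::kK - 1) /
//       MyMmaPipelinedFromSmem::Shape::kK;
//   mma(gemm_k_iterations, accum, iterator_B, accum);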

////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape1_,
    /// Iterates over the intermediate accumulator tile in shared memory
    typename WarpIteratorA1_,
    /// whether or not to perform elementwise multiplication of A
    // by another matrix (A_scale) that is also kept in shared memory prior
    // to matmul A @ B
    bool ScaleOperandA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB1_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB1_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB1,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy1_,
    /// Number of stages
    int Stages_,
    int kMaxK_,
    /// Used for partial specialization
    typename Enable = bool>
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
                                          Shape1_,
                                          kMaxK_,
                                          Policy1_,
                                          Stages_,
                                          typename WarpIteratorA1_::Layout> {
 public:
  ///< Base class
  using Base = MmaBaseFromSharedMemory<
      Shape1_,
      kMaxK_,
      Policy1_,
      Stages_,
      typename WarpIteratorA1_::Layout>;

  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape1 = Shape1_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB1 = IteratorB1_;
  using IteratorB = IteratorB1;
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate
                                          ///< accumulator tile in shared memory
  static constexpr bool ScaleOperandA = ScaleOperandA_;

  ///< warp level iterator over A_scale matrix tile kept in shared memory.
  ///< if elementwise A scaling is disabled then everything this does is no-op.
  using WarpIteratorAScale = typename cutlass::platform::conditional<
      ScaleOperandA,
      WarpIteratorA1,
      NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
  static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;
  using FragmentC = FragmentC1;

  /// Warp-level Mma
  using Operator1 = typename Policy1::Operator;

  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Internal structure exposed for introspection.
  struct Detail {
    static_assert(
        Base::kWarpGemmIterations1 > 1,
        "The pipelined structure requires at least two warp-level "
        "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand B
    static int const TBLoadIterationsB1 =
        IteratorB1::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB1 =
        (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) /
        Base::kWarpGemmIterations1;
  };

  static constexpr int kNumStagesConcurrentLoad =
      kSmemContainsEntireB ? Base::kStages : Base::kStages - 1;

 private:
  using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
  /// fragment of OperandA scale matrix. if operand A scaling is disabled this
  /// is (almost) empty.
  using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
  using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
  using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
  using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;

  /// applies elementwise scaling to fragment of A. if operand A scaling is
  /// disabled this is a no-op.
  using FragmentAScaler = FragmentElementwiseScaler<
      WarpLoadedFragmentA1,
      WarpLoadedFragmentA1Scale,
      ScaleOperandA>;

 private:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A1 operand from intermediate
  /// accumulator tile
  WarpIteratorA1 warp_tile_iterator_A1_;

  /// Iterator to load a warp-scoped tile of A1_scale operand from shared
  /// memory. if operand A scaling is disabled everything this does is a no-op.
  WarpIteratorAScale warp_tile_iterator_A1_scale_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB1 smem_iterator_B1_;

  bool prologue_done_;

 public:
  /// constructor for MMA with operand A scaling enabled.
  CUTLASS_DEVICE
  MmaMultistageFromSharedMemory(
      typename Base::TensorRefA a,
      typename Base::TensorRefA a_scale,
      typename Base::TensorRefB b_tile,
      int thread_idx,
      int warp_idx,
      int lane_idx)
      : Base(b_tile, thread_idx, warp_idx, lane_idx),
        warp_tile_iterator_A1_(a, lane_idx),
        warp_tile_iterator_A1_scale_(a_scale, lane_idx),
        smem_iterator_B1_(b_tile, thread_idx),
        prologue_done_(false) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension
    int warp_idx_mn_1 =
        warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
    int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
    int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
    int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;

    // Add per-warp offsets in units of warp-level tiles
    warp_tile_iterator_A1_.add_tile_offset(
        {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
    warp_tile_iterator_A1_scale_.add_tile_offset(
        {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
  }

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaMultistageFromSharedMemory(
      typename Base::TensorRefA a,
      typename Base::TensorRefB b_tile,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : Base(b_tile, thread_idx, warp_idx, lane_idx),
        warp_tile_iterator_A1_(a, lane_idx),
        smem_iterator_B1_(b_tile, thread_idx),
        prologue_done_(false) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn_1 =
        warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
    int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);

    int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
    int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;

    // Add per-warp offsets in units of warp-level tiles
    warp_tile_iterator_A1_.add_tile_offset(
        {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
  }

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    prologue_done_ = value;
  }

  CUTLASS_DEVICE
  static void prologue(
      typename Base::SharedStorage& shared_storage,
      IteratorB iterator_B1,
      int thread_idx,
      int problem_size_0_n) {
    SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx);
    _prologue(
        iterator_B1,
        (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK,
        smem_iterator_B1);
  }

  CUTLASS_DEVICE
  void copy_tiles_and_advance_1(
      IteratorB1& iterator_B1,
      int group_start_B1 = 0) {
    iterator_B1.set_iteration_index(
        group_start_B1 * IteratorB1::kAccessesPerVector);
    this->smem_iterator_B1_.set_iteration_index(group_start_B1);

    // Load for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
      if (group_start_B1 + j < Detail::TBLoadIterationsB1) {
        typename IteratorB1::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB1::AccessType*>(
                this->smem_iterator_B1_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
            IteratorB1::ThreadMap::kElementsPerAccess /
            IteratorB1::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B1.get();

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
              dst_ptr + v, gmem_ptr, iterator_B1.valid());

          ++iterator_B1;
        }
        ++this->smem_iterator_B1_;
      }
    }
  }

  CUTLASS_DEVICE
  static void _prologue(
      IteratorB& iterator_B1,
      int32_t gemm_k_iterations_1,
      SmemIteratorB1& smem_iterator_B1_) {
    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < kNumStagesConcurrentLoad;
         ++stage, --gemm_k_iterations_1) {
      iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1);
      iterator_B1.clear_mask(gemm_k_iterations_1 == 0);

      iterator_B1.set_iteration_index(0);
      smem_iterator_B1_.set_iteration_index(0);

      // Load for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) {
        typename IteratorB1::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB1::AccessType*>(
                smem_iterator_B1_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorB1::Element>::value *
              IteratorB1::ThreadMap::kElementsPerAccess /
              IteratorB1::kAccessesPerVector / 8;

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
              dst_ptr + v, iterator_B1.get(), iterator_B1.valid());

          ++iterator_B1;
        }

        ++smem_iterator_B1_;
      }

      // Move to the next stage
      iterator_B1.add_tile_offset({1, 0});

      smem_iterator_B1_.add_tile_offset({1, 0});

      // Defines the boundary of a stage of cp.async.
      cutlass::arch::cp_async_fence();
    }
    iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1);
    iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations_1_,
      ///< destination accumulator tile
      FragmentC1& accum,
      ///< iterator over B1 operand in global memory
      IteratorB1 iterator_B1,
      ///< initial value of accumulator
      FragmentC1 const& src_accum) {
    // 2nd Gemm

    //
    // Prologue
    //
    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    if (!prologue_done_) {
      _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_);
    } else if (!kSmemContainsEntireB) {
      // Restore the iterators increments

      int gemm_k_iterations_1 = gemm_k_iterations_1_;
      // Issue several complete stages
      CUTLASS_PRAGMA_UNROLL
      for (int stage = 0; stage < kNumStagesConcurrentLoad;
           ++stage, --gemm_k_iterations_1) {
        iterator_B1.set_iteration_index(0);
        this->smem_iterator_B1_.set_iteration_index(0);

        // Load for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) {
          CUTLASS_PRAGMA_UNROLL
          for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
            ++iterator_B1;
          }
          ++this->smem_iterator_B1_;
        }
        iterator_B1.add_tile_offset({1, 0});
        this->smem_iterator_B1_.add_tile_offset({1, 0});
      }
      iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1);
      iterator_B1.clear_mask(gemm_k_iterations_1 <= 0);
    }

    // DEPBAR+SYNC
    cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
    __syncthreads();

    // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty
    // if scaling is disabled.

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
    WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2];
    WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
    WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
    WarpTransformedFragmentB1 warp_transformed_frag_B1[2];

    Operator1 warp_mma1;

    warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
    ++warp_tile_iterator_A1_;

    warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
    ++warp_tile_iterator_A1_scale_;

    this->warp_tile_iterator_B_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]);
    ++this->warp_tile_iterator_B_;

    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

    warp_mma1.transform(
        warp_transformed_frag_A1[0],
        warp_transformed_frag_B1[0],
        FragmentAScaler::apply(
            warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]),
        warp_loaded_frag_B1[0]);

    // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
    // accumulator and this temporary accumulator is added to the final
    // accumulator once in every mainloop iteration.
    plus<FragmentC1> plus_accum;

    FragmentC1 tmp_accum;

    if (platform::is_same<
            typename Operator1::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator1::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      tmp_accum.clear();
    }

    //
    // Mainloop
    //

    CUTLASS_PRAGMA_UNROLL
    for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1);
         gemm_k_iterations_1 > (-Base::kStages + 1);
         gemm_k_iterations_1--) {
      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
           ++warp_mma_k) {
        // Load warp-level tile from accumulator fragment (A)
        // or shared memory (operand B)
        this->warp_tile_iterator_B_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations1);
        // skip warp tile loading for the last kgroup (we are out of the buf)
        if (gemm_k_iterations_1 > (-Base::kStages + 2) ||
            warp_mma_k < Base::kWarpGemmIterations1 - 1) {
          warp_tile_iterator_A1_.load(
              warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
          warp_tile_iterator_A1_scale_.load(
              warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_B_.load(
              warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
        }
        ++warp_tile_iterator_A1_;
        ++warp_tile_iterator_A1_scale_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k > 0)
          warp_mma1.transform(
              warp_transformed_frag_A1[warp_mma_k % 2],
              warp_transformed_frag_B1[warp_mma_k % 2],
              FragmentAScaler::apply(
                  warp_loaded_frag_A1[warp_mma_k % 2],
                  warp_loaded_frag_A1_scale[warp_mma_k % 2]),
              warp_loaded_frag_B1[warp_mma_k % 2]);

        if (platform::is_same<
                typename Operator1::MathOperator,
                arch::OpMultiplyAddFastF32>::value ||
            platform::is_same<
                typename Operator1::MathOperator,
                arch::OpMultiplyAddComplexFastF32>::value) {
          warp_mma1(
              tmp_accum,
              warp_transformed_frag_A1[warp_mma_k % 2],
              warp_transformed_frag_B1[warp_mma_k % 2],
              tmp_accum);

          if (warp_mma_k == 0) {
            accum = plus_accum(accum, tmp_accum);
            tmp_accum.clear();
          }
        } else {
          warp_mma1(
              accum,
              warp_transformed_frag_A1[warp_mma_k % 2],
              warp_transformed_frag_B1[warp_mma_k % 2],
              accum);
        }

        // Issue global->shared copies for this stage
        if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
          int group_start_iteration_B1;

          group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;

          if (!kSmemContainsEntireB) {
            copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
          }
        }

        if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
          int group_start_iteration_B1;
          group_start_iteration_B1 =
              (warp_mma_k + 1) * Detail::kAccessesPerGroupB1;

          if (!kSmemContainsEntireB) {
            copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
          }

          // Inserts a memory fence between stages of cp.async instructions.
          cutlass::arch::cp_async_fence();

          // Waits until kStages-2 stages have committed.
          arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
          __syncthreads();

          // Move to the next stage
          iterator_B1.add_tile_offset({1, 0});

          this->smem_iterator_B1_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (!kSmemContainsEntireB) {
            if (smem_write_stage_idx == (Base::kStages - 1)) {
              this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
              smem_write_stage_idx = 0;
            } else {
              ++smem_write_stage_idx;
            }

            if (smem_read_stage_idx == (Base::kStages - 1)) {
              this->warp_tile_iterator_B_.add_tile_offset(
                  {-Base::kStages * Policy1::kPartitionsK *
                       Base::kWarpGemmIterations1,
                   0});
              smem_read_stage_idx = 0;
            } else {
              ++smem_read_stage_idx;
            }
          }

          iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2);
          iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
        }

        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
          warp_mma1.transform(
              warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
              warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
              FragmentAScaler::apply(
                  warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
                  warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]),
              warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
      }
    }

    if (platform::is_same<
            typename Operator1::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator1::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      accum = plus_accum(accum, tmp_accum);
    }
  }
};
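
// A minimal usage sketch for the multistage (SM80+) variant above
// (illustrative only; `MyMmaFromSmem`, `shared_storage`, `accum_storage`,
// `iterator_B1` and `problem_k` are hypothetical names a kernel would
// already define):
//
//   // Optionally kick off the async B loads early:
//   MyMmaFromSmem::prologue(shared_storage, iterator_B1, thread_idx, problem_k);
//
//   MyMmaFromSmem mma(
//       accum_storage.accum_ref(),        // operand A already in shared memory
//       shared_storage.operand_B_ref(),   // circular staging buffer for B
//       thread_idx, warp_idx, lane_idx);
//   mma.set_prologue_done(true); // tell the mainloop B is already staged
//   typename MyMmaFromSmem::FragmentC accum;
//   accum.clear();
//   int gemm_k_iterations =
//       (problem_k + MyMmaFromSmem::Shape1::kK - 1) / MyMmaFromSmem::Shape1::kK;
//   mma(gemm_k_iterations, accum, iterator_B1, accum);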

// Converts a "regular" Mma into their counterpart from shared memory
template <
    typename Mma_,
    int kMaxK,
    typename WarpIteratorA_,
    /// whether or not to apply elementwise multiplication of operand A by
    /// another matrix in shared memory before usage in A @ B
    bool kScaleOperandA,
    bool kTransposeA = false>
struct DefaultMmaFromSharedMemory;

// Mma pipelined
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    typename WarpIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_,
    /// Transformation applied to B operand
    typename TransformB_,
    // Max MMA problem size K
    int kMaxK,
    /// whether or not to apply elementwise multiplication of operand A by
    /// another matrix in shared memory before usage in A @ B
    bool kScaleOperandA,
    bool kTransposeA>
struct DefaultMmaFromSharedMemory<
    MmaPipelined<
        Shape_,
        IteratorA_,
        SmemIteratorA_,
        IteratorB_,
        SmemIteratorB_,
        ElementC_,
        LayoutC_,
        Policy_,
        TransformA_,
        TransformB_>,
    kMaxK,
    WarpIteratorA_,
    kScaleOperandA,
    kTransposeA> {
  using RegularMma = MmaPipelined<
      Shape_,
      IteratorA_,
      SmemIteratorA_,
      IteratorB_,
      SmemIteratorB_,
      ElementC_,
      LayoutC_,
      Policy_,
      TransformA_,
      TransformB_>;

  using WarpShape = typename Policy_::Operator::Shape;
  using InstructionShape = typename Policy_::Operator::InstructionShape;
  using ArchMmaOperator = typename Policy_::Operator;

  static constexpr bool kIsTransposedA = false;
  using WarpIteratorA = WarpIteratorA_;
  using IteratorB =
      typename cutlass::transform::threadblock::MakeIteratorResidualLast<
          IteratorB_>::Iterator;

  using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory<
      Shape_,
      WarpIteratorA,
      kScaleOperandA,
      kMaxK,
      IteratorB,
      SmemIteratorB_,
      ElementC_,
      LayoutC_,
      Policy_>;
};
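
// Illustrative sketch of the intended conversion (hypothetical aliases only;
// `RegularThreadblockMma` stands for whatever MmaPipelined/MmaMultistage type
// a default GEMM configuration produced, and `WarpIteratorFromAccum` for a
// warp iterator over the shared-memory accumulator):
//
//   using Converted = DefaultMmaFromSharedMemory<
//       RegularThreadblockMma,     // the "regular" threadblock-scoped Mma
//       kMaxK,                     // K extent held in shared memory
//       WarpIteratorFromAccum,     // warp iterator over the smem accumulator
//       /*kScaleOperandA=*/false>;
//   using Mma2 = typename Converted::Mma; // reads operand A from shared memory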

template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    typename WarpIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    int kMaxK,
    /// whether or not to apply elementwise multiplication of operand A by
    /// another matrix in shared memory before usage in A @ B
    bool kScaleOperandA,
    bool kTransposeA>
struct DefaultMmaFromSharedMemory<
    MmaMultistage<
        Shape_,
        IteratorA_,
        SmemIteratorA_,
        CacheOpA,
        IteratorB_,
        SmemIteratorB_,
        CacheOpB,
        ElementC_,
        LayoutC_,
        Policy_,
        Stages,
        SharedMemoryClear>,
    kMaxK,
    WarpIteratorA_,
    kScaleOperandA,
    kTransposeA> {
  using RegularMma = MmaMultistage<
      Shape_,
      IteratorA_,
      SmemIteratorA_,
      CacheOpA,
      IteratorB_,
      SmemIteratorB_,
      CacheOpB,
      ElementC_,
      LayoutC_,
      Policy_,
      Stages,
      SharedMemoryClear>;

  using WarpShape = typename Policy_::Operator::Shape;
  using InstructionShape = typename Policy_::Operator::InstructionShape;
  using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
  static constexpr bool kIsTransposedA =
      WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
  using WarpIteratorA = typename platform::conditional<
      kIsTransposedA,
      typename WarpIteratorTranspose::Iterator,
      WarpIteratorA_>::type;

  // Reduce the number of stages if we don't need that many
  static int constexpr kStagesMax =
      (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
  static int constexpr kStages = cutlass::const_min(Stages, kStagesMax);

  using IteratorB =
      typename cutlass::transform::threadblock::MakeIteratorResidualLast<
          IteratorB_>::Iterator;
  using Mma =
      typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
          Shape_,
          WarpIteratorA,
          kScaleOperandA,
          IteratorB,
          SmemIteratorB_,
          RegularMma::kCacheOpB,
          ElementC_,
          LayoutC_,
          Policy_,
          kStages,
          kMaxK>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    typename IteratorC,
    typename Operator,
    typename scalar_t,
    typename WarpShape_,
    typename ThreadblockShape_>
struct B2bGemm;

// Tensor Cores >= Sm75 specialization (Ampere ...)
template <
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Element type
    typename Element_,
    /// Layout of operand in memory
    typename Layout_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions, concept: MatrixShape)
    typename OpDelta_,
    typename Operator,
    typename scalar_t,
    typename WarpShape_,
    typename ThreadblockShape_>
struct B2bGemm<
    cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
        Shape_,
        Element_,
        Layout_,
        InstructionShape_,
        OpDelta_>,
    Operator,
    scalar_t,
    WarpShape_,
    ThreadblockShape_> {
  using IteratorC =
      typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
          Shape_,
          Element_,
          Layout_,
          InstructionShape_,
          OpDelta_>;
  using FragmentC = typename IteratorC::Fragment;
  using InstructionShape = InstructionShape_;
  using WarpShape = WarpShape_;
  using ThreadblockShape = ThreadblockShape_;
  using accum_t = Element_;
  using lse_scalar_t = float;

  using SmemAccumulatorLayout = cutlass::layout::RowMajor;

  // Iterator to load accumulators (results of matmul in registers)
  using FragmentIteratorAccumulator =
      cutlass::epilogue::warp::FragmentIteratorTensorOp<
          WarpShape,
          InstructionShape,
          accum_t,
          typename Operator::Policy::Operator::FragmentC,
          cutlass::layout::RowMajor>;

  // Iterator to store to shared-memory
  using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp<
      WarpShape,
      InstructionShape,
      scalar_t, // accum_t,
      SmemAccumulatorLayout>;
  using AccumulatorSharedStorage =
      cutlass::gemm::threadblock::AccumulatorSharedStorage<
          ThreadblockShape,
          typename SmemIteratorD0::Element,
          typename SmemIteratorD0::TensorLayout,
          typename SmemIteratorD0::Padding>;
  // We need to provide an operation for the epilogue. Let's create an
  // operation that does nothing (ScaleType::Nothing), just converts
  // from accum_t (float) -> scalar_t (can be half)
  using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination<
      typename SmemIteratorD0::Element, // ElementOutput
      FragmentIteratorAccumulator::Fragment::kElements,
      accum_t, // ElementAccumulator
      typename SmemIteratorD0::Element, // ElementCompute
      cutlass::epilogue::thread::ScaleType::Nothing>;
  using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
      SmemIteratorD0,
      FragmentIteratorAccumulator,
      SmemIteratorD0, // ScaleBiasIterator - not used
      OutputOpNoOp>;

  // Epilogue 2: with LSE (for backwards pass)
  static int const kElementsPerAccess = 2; // TODO: Why 2?
  using IteratorAccumulatorLSE =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              // Shape
              cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>,
              // WarpShape
              cutlass::MatrixShape<WarpShape::kM, WarpShape::kN>,
              lse_scalar_t,
              cutlass::layout::RowMajor,
              kElementsPerAccess>>;
  using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp<
      scalar_t, // ElementOutput_
      lse_scalar_t, // ElementLSE_
      accum_t, // ElementAccumulator_
      accum_t, // ElementCompute_
      128 / cutlass::sizeof_bits<scalar_t>::value
      // FragmentIteratorAccumulator::Fragment::kElements
      // InstructionShape::kM * InstructionShape::kN / 32
      >;
  using EpilogueWithLSE =
      cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
          SmemIteratorD0,
          FragmentIteratorAccumulator,
          IteratorAccumulatorLSE,
          EpilogueOpApplyLSE>;

  static void CUTLASS_DEVICE accumToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC const& accum,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
    smem_iterator_attn.add_tile_offset(
        tile_coords *
        cutlass::MatrixCoord{
            SmemIteratorD0::TileIterations::kRow,
            SmemIteratorD0::TileIterations::kColumn});
    Epilogue epilogue;
    epilogue(OutputOpNoOp({}), smem_iterator_attn, accum);
  }

  static void CUTLASS_DEVICE accumApplyLSEToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC& accum,
      lse_scalar_t const* lse,
      int32_t lse_extents,
      int thread_id,
      int warp_id,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    constexpr int32_t kAlignLSE = 32;
    IteratorAccumulatorLSE iterator_lse(
        lse,
        {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE},
        thread_id,
        warp_id,
        cutlass::MatrixCoord{0, 0} // offset
    );

    SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
    smem_iterator_attn.add_tile_offset(
        tile_coords *
        cutlass::MatrixCoord{
            SmemIteratorD0::TileIterations::kRow,
            SmemIteratorD0::TileIterations::kColumn});
    EpilogueWithLSE epilogue;
    EpilogueOpApplyLSE minus_lse_exp({});
    epilogue(
        minus_lse_exp,
        smem_iterator_attn,
        accum,
        // scale - unused
        iterator_lse,
        // bias
        iterator_lse);
  }
};
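
// Illustrative sketch of how the specialization above is meant to be driven
// (hypothetical aliases; `MyB2bGemm` would be B2bGemm instantiated with the
// accumulator tile iterator of the first GEMM, and `accum_shared_storage` its
// AccumulatorSharedStorage living in shared memory):
//
//   // After the first warp-level GEMM has produced `accum` in registers,
//   // spill it to shared memory so it can serve as operand A of the 2nd GEMM:
//   MyB2bGemm::accumToSmem(accum_shared_storage, accum, lane_id, tile_coords);
//   __syncthreads();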

// Volta specialization: only supported for f16
template <typename Operator, typename WarpShape_, typename ThreadblockShape_>
struct B2bGemm<
    cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
        cutlass::MatrixShape<32, 32>,
        float,
        cutlass::layout::RowMajor,
        cutlass::gemm::GemmShape<16, 16, 4>,
        cutlass::MatrixShape<1, 1>>,
    Operator,
    cutlass::half_t,
    WarpShape_,
    ThreadblockShape_> {
  using IteratorC =
      cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
          cutlass::MatrixShape<32, 32>,
          float,
          cutlass::layout::RowMajor,
          cutlass::gemm::GemmShape<16, 16, 4>,
          cutlass::MatrixShape<1, 1>>;
  using scalar_t = cutlass::half_t;
  using accum_t = IteratorC::Element;
  using WarpShape = WarpShape_;
  using ThreadblockShape = ThreadblockShape_;
  using FragmentC = IteratorC::Fragment;
  using lse_scalar_t = float;

  // Storage in shared-memory for Q.Kt
  using SmemAccumulatorLayout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
  using AccumulatorSharedStorage =
      cutlass::gemm::threadblock::AccumulatorSharedStorage<
          ThreadblockShape,
          scalar_t,
          SmemAccumulatorLayout,
          cutlass::MatrixShape<0, 0> // Padding
          >;
  using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
  using Policy = typename IteratorC::Policy;
  using Element = accum_t;
  // These are private fields of MmaVoltaTensorOpAccumulatorTileIterator;
  // we copy their values here.
  static int const kElementsPerPartial = 4;
  using EleShapePerPatial = typename cutlass::platform::conditional<
      cutlass::platform::is_same<Element, float>::value,
      cutlass::MatrixShape<2, 2>,
      cutlass::MatrixShape<1, 4>>::type;
  static int const kElementsPerMma = 8;
  static int const kAccumulatorPatials = 2;
  using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
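  // Roughly: each Volta mma.sync accumulator holds kElementsPerMma (8) values
  // per thread, split into kAccumulatorPatials (2) partials of
  // kElementsPerPartial (4) elements, laid out as EleShapePerPatial.
  // accumToSmem below uses these constants to compute each element's
  // (row, column) position inside the 32x32 tile when scattering the
  // fragment to shared memory.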

  static void CUTLASS_DEVICE accumToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC const& accum,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    // ctor - from MmaVoltaTensorOpAccumulatorTileIterator
    TensorRef ref_(shared_storage.accum_ref());
    int quad = (lane_id >> 2);
    int lane_in_quad = (lane_id & 3);
    int accum_m, accum_n;

    if (cutlass::platform::is_same<Element, float>::value) {
      // (quad[2],quad[0])+lane_in_quad[0]
      accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
      // (quad[1])+lane_in_quad[1]
      accum_n =
          ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
          (lane_in_quad & 2);
    } else {
      accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
          lane_in_quad; // (quad[2],quad[0])
      accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
    }
    cutlass::MatrixCoord lane_offset(accum_m, accum_n);

    // Tile offset
    ref_.add_coord_offset(
        tile_coords *
        cutlass::MatrixCoord(
            {IteratorC::Shape::kRow, IteratorC::Shape::kColumn}));

    using AccessType = cutlass::Array<scalar_t, EleShapePerPatial::kColumn>;

    // store - from MmaVoltaTensorOpAccumulatorTileIterator
    CUTLASS_PRAGMA_UNROLL
    for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
          CUTLASS_PRAGMA_UNROLL
          for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
            int mma_accum_start =
                (((tile_n * Policy::TileIterations::kRow + tile_m) *
                      Policy::MmaIterations::kColumn +
                  mma_n) *
                     Policy::MmaIterations::kRow +
                 mma_m) *
                kElementsPerMma;

            CUTLASS_PRAGMA_UNROLL
            for (int p = 0; p < kAccumulatorPatials; ++p) {
              CUTLASS_PRAGMA_UNROLL
              for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
                int accum_m = tile_m * Policy::InterleavedTile::kRow +
                    mma_m * QuadShapePerPatialMma::kRow + m * 2;
                int accum_n = tile_n * Policy::InterleavedTile::kColumn +
                    mma_n * QuadShapePerPatialMma::kColumn +
                    p * Policy::InterleavedTile::kColumn / 2;
                int r = (accum_m + lane_offset.row());
                AccessType to_store;
                CUTLASS_PRAGMA_UNROLL
                for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
                  int idx = mma_accum_start + p * kElementsPerPartial +
                      m * EleShapePerPatial::kColumn + n;
                  to_store[n] = scalar_t(accum[idx]);
                }
                int c = (accum_n + lane_offset.column());
                assert(r < 32);
                assert(c < 32);
                *reinterpret_cast<AccessType*>(
                    ref_.data() + ref_.offset({r, c})) = to_store;
              }
            }
          }
        }
      }
    }
  }
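
  // accumApplyLSEToSmem recomputes attention probabilities from raw scores:
  // each element becomes exp(accum - lse[col]), where lse holds the
  // log-sum-exp of the corresponding softmax row (accum is attn.T, so rows of
  // attn show up here as columns), typically the values saved by the forward
  // pass. Columns at or beyond lse_extent read +inf and therefore become
  // exp(-inf) == 0. The result is then written to shared memory via
  // accumToSmem.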
  static void CUTLASS_DEVICE accumApplyLSEToSmem(
      AccumulatorSharedStorage& shared_storage,
      typename IteratorC::Fragment& accum,
      lse_scalar_t const* lse,
      int lse_extent,
      int thread_id,
      int warp_id,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    // Non-optimized way to apply LSE to registers
    // NOTE: accum is attn.T
    // TODO: Optimize for each architecture
    static constexpr int WarpSize = 32;
    using AccumLambdaIterator =
        typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
            Iterator;
    auto lane_offset =
        AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);

    cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
    lse_prefetched.clear();
    int rowIdx = 0;
    int colIdx = 0;
    AccumLambdaIterator::iterateRows(
        lane_offset,
        [&](int accum_m) {
          ++rowIdx;
          colIdx = 0;
        },
        [&](int accum_m, int accum_n, int idx) {
          if (rowIdx == 1) {
            lse_prefetched[colIdx] = accum_n < lse_extent
                ? lse[accum_n]
                : platform::numeric_limits<accum_t>::infinity();
          }
          accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
          ++colIdx;
        },
        [&](int accum_m) {});
    accumToSmem(shared_storage, accum, lane_id, tile_coords);
  }
};

// SIMT specialization: for f32 on Sm70-Sm75, and for f16/f32 below Sm70

template <
    typename Operator,
    typename OperatorPolicy,
    typename scalar_t,
    typename WarpShape_,
    typename ThreadblockShape_>
struct B2bGemm<
    cutlass::gemm::warp::MmaSimtTileIterator<
        cutlass::MatrixShape<32, 32>,
        cutlass::gemm::Operand::kC,
        float,
        cutlass::layout::RowMajor,
        OperatorPolicy,
        1,
        1>,
    Operator,
    scalar_t,
    WarpShape_,
    ThreadblockShape_> {
  using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
      cutlass::MatrixShape<32, 32>,
      cutlass::gemm::Operand::kC,
      float,
      cutlass::layout::RowMajor,
      OperatorPolicy,
      1,
      1>;
  using accum_t = typename IteratorC::Element;
  using WarpShape = WarpShape_;
  using ThreadblockShape = ThreadblockShape_;
  using FragmentC = typename IteratorC::Fragment;
  using lse_scalar_t = float;

  // Storage in shared-memory for Q.Kt
  using AccumulatorSharedStorage =
      cutlass::gemm::threadblock::AccumulatorSharedStorage<
          ThreadblockShape,
          scalar_t,
          cutlass::layout::ColumnMajor,
          cutlass::MatrixShape<0, 0> // Padding
          >;

  static void CUTLASS_DEVICE accumToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC const& accum,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    using Policy = typename IteratorC::Policy;
    using Element = typename IteratorC::Element;
    using Iterations = typename IteratorC::Iterations;
    using Delta = typename IteratorC::Delta;

    auto ref_ = shared_storage.accum_ref();
    // ctor - MmaSimtTileIterator
    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
        MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);

    ref_.add_coord_offset(lane_offset);

    // Tile offset
    ref_.add_coord_offset(
        tile_coords *
        cutlass::MatrixCoord(
            {IteratorC::Shape::kRow, IteratorC::Shape::kColumn}));

    // store - MmaSimtTileIterator
    CUTLASS_PRAGMA_UNROLL
    for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
            int r =
                Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) +
                m;
            int c = mma_n * Delta::kColumn + n;
            int idx = n +
                Policy::LaneMmaShape::kN *
                    (mma_n +
                     Iterations::kColumn *
                         (m + mma_m * Policy::LaneMmaShape::kM));
            ref_.at({r, c}) = scalar_t(accum[idx]);
          }
        }
      }
    }
  }

  static void CUTLASS_DEVICE accumApplyLSEToSmem(
      AccumulatorSharedStorage& shared_storage,
      typename IteratorC::Fragment& accum,
      lse_scalar_t const* lse,
      int lse_extent,
      int thread_id,
      int warp_id,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    // Non-optimized way to apply LSE to registers
    // NOTE: accum is attn.T
    // TODO: Optimize for each architecture
    static constexpr int WarpSize = 32;
    using AccumLambdaIterator =
        typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
            Iterator;
    auto lane_offset =
        AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);

    cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
    lse_prefetched.clear();
    int rowIdx = 0;
    int colIdx = 0;
    AccumLambdaIterator::iterateRows(
        lane_offset,
        [&](int accum_m) {
          ++rowIdx;
          colIdx = 0;
        },
        [&](int accum_m, int accum_n, int idx) {
          if (rowIdx == 1) {
            lse_prefetched[colIdx] = accum_n < lse_extent
                ? lse[accum_n]
                : platform::numeric_limits<accum_t>::infinity();
          }
          accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
          ++colIdx;
        },
        [&](int accum_m) {});
    accumToSmem(shared_storage, accum, lane_id, tile_coords);
  }
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
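
// Backward-pass usage sketch (hypothetical names, for illustration only):
// given the per-row log-sum-exp values saved by the forward pass, a kernel can
// convert the raw Q.Kt accumulator into attention probabilities and stage them
// in shared memory in a single call, roughly as:
//
//   using B2bGemm_ = cutlass::gemm::threadblock::B2bGemm<
//       typename Mma::Operator::IteratorC,
//       typename Mma::Operator,
//       scalar_t,
//       typename Mma::Operator::Shape, // warp shape
//       typename Mma::Shape>;          // threadblock shape
//   B2bGemm_::accumApplyLSEToSmem(
//       shared_storage.attn_shared_storage, // AccumulatorSharedStorage
//       accum,                              // warp accumulator fragment
//       logsumexp_ptr,                      // one lse value per query
//       num_queries,                        // columns beyond this become 0
//       thread_id,
//       warp_id,
//       lane_id,
//       tile_coords);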