/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  File copied from <cutlass/epilogue/threadblock/epilogue.h>
  then modified to:
  (1) load 2 source fragments at the same time (pipelining)
  (2) support reading from a different dtype
  (3) pass the row id to the OutputOp if it takes it
      (see MemoryEfficientAttentionNormalize)
  Note that in general the fragment passed to the OutputOp could
  span multiple rows but it does not happen with the configurations we have
*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <cassert>
#endif

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/functional.h>
#include <cutlass/layout/tensor.h>
#include <cutlass/layout/vector.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_coord.h>

#include <cutlass/gemm/gemm.h>

#include <cutlass/transform/pitch_linear_thread_map.h>
#include <cutlass/transform/threadblock/regular_tile_iterator.h>

#include <cutlass/epilogue/threadblock/epilogue_base.h>
#include <cutlass/epilogue/threadblock/predicated_tile_iterator.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

template <typename Op>
struct ApplyEpilogueOp {
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum,
      typename Op::FragmentOutput const& source) {
    return output_op(accum, source);
  }
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum) {
    return output_op(accum);
  }
};

////////////////////////////////////////////////////////////////////////////////
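// Illustrative sketch only (not part of this header's interface): an output op
// that wants the per-row index can be paired with a specialization of
// ApplyEpilogueOp that forwards `row_id` to the functor instead of dropping it.
// `HypotheticalRowAwareOp` and its `operator()(int row, ...)` signature are
// assumptions made for this example; in this codebase the same idea is used for
// MemoryEfficientAttentionNormalize.
//
//   template <typename ElementOutput_, int ElementsPerAccess>
//   struct ApplyEpilogueOp<
//       HypotheticalRowAwareOp<ElementOutput_, ElementsPerAccess>> {
//     using Op = HypotheticalRowAwareOp<ElementOutput_, ElementsPerAccess>;
//     static CUTLASS_DEVICE typename Op::FragmentOutput apply(
//         Op const& output_op,
//         int row_id,
//         typename Op::FragmentAccumulator const& accum,
//         typename Op::FragmentOutput const& source) {
//       return output_op(row_id, accum, source);
//     }
//     static CUTLASS_DEVICE typename Op::FragmentOutput apply(
//         Op const& output_op,
//         int row_id,
//         typename Op::FragmentAccumulator const& accum) {
//       return output_op(row_id, accum);
//     }
//   };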
/// Epilogue operator
template <
    typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
    typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
                               ///< gemm::warp::MmaTensorOp)
    int PartitionsK, ///< Number of partitions of the K dimension
    typename OutputTileIterator_, ///< Tile iterator writing output tensors
    typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting
                                           ///< accumulators
    typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
                                ///< accumulators to SMEM
    typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading
                                  ///< from SMEM
    typename OutputOp_, ///< Output operator
    typename Padding_, ///< Padding added to SMEM allocation to avoid bank
                       ///< conflicts (concept: MatrixShape)
    int FragmentsPerPartition =
        1, ///< Used to coarsen the epilogue granularity
    int IterationsUnroll = ///< Used to reduce binary size when epilogue op is
                           ///< large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
    typename OutputTileSourceIterator_ =
        OutputTileIterator_ ///< Tile iterator reading tensors
    >
class EpiloguePipelined : public EpilogueBase<
                              Shape_,
                              typename WarpMmaOperator_::Shape,
                              PartitionsK,
                              AccumulatorFragmentIterator_,
                              WarpTileIterator_,
                              Padding_,
                              FragmentsPerPartition> {
 public:
  using Base = EpilogueBase<
      Shape_,
      typename WarpMmaOperator_::Shape,
      PartitionsK,
      AccumulatorFragmentIterator_,
      WarpTileIterator_,
      Padding_,
      FragmentsPerPartition>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using OutputTileSourceIterator = OutputTileSourceIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;
  using ElementSource = typename OutputTileSourceIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef =
      typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
      typename OutputTileIterator::Element,
      OutputTileIterator::kElementsPerAccess>;
  using SourceAccessType = Array<
      typename OutputTileSourceIterator::Element,
      OutputTileSourceIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<
      typename WarpTileIterator::Element,
      OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
      ? Base::kFragmentsPerIteration
      : kPartitionsK;
  static int constexpr kSmemPointerOffset =
      Base::SharedStorage::StorageShape::kCount / kSmemTiles;

 public:
  static_assert(
      OutputTileSourceIterator::Fragment::kElements ==
          OutputTileIterator::Fragment::kElements,
      "Mismatch between input tile and output tile iterator (kElements)");
  static_assert(
      OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
      "Mismatch between input tile and output tile iterator (kIterations)");
  static_assert(
      SharedLoadIterator::Fragment::kElements ==
          OutputTileIterator::Fragment::kElements,
      "Mismatch between shared load iterator and output tile iterator.");

  static_assert(
      OutputTileIterator::kElementsPerAccess,
      "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(
      !(OutputTileIterator::Fragment::kElements %
        OutputTileIterator::kElementsPerAccess),
      "Divisibility");

 private:
  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

 public:
  /// Constructor
  CUTLASS_DEVICE
  EpiloguePipelined(
      typename Base::SharedStorage& shared_storage, ///< Shared storage object
      int thread_idx, ///< ID of a thread within the threadblock
      int warp_idx, ///< ID of warp within threadblock
      int lane_idx ///< Id of thread within warp
      )
      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
        shared_load_iterator_(shared_storage.reference(), thread_idx) {}

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void operator()(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators, ///< Complete warp-level accumulator tile
      OutputTileSourceIterator
          source_iterator) { ///< Tile iterator for the source tensor
    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    } else {
      compute_source_needed_(
          output_op, destination_iterator, accumulators, source_iterator);
    }
  }
  CUTLASS_DEVICE
  void operator()(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators) { ///< Complete warp-level accumulator tile
    compute_source_not_needed_(output_op, destination_iterator, accumulators);
  }
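  // Usage sketch (illustrative only): the enclosing kernel is expected to own
  // the shared storage, build the iterators, and then run the epilogue once per
  // threadblock tile. `Epilogue` is an assumed alias for a concrete
  // EpiloguePipelined instantiation; `shared_storage`, `output_op`,
  // `destination_iterator`, `source_iterator` and `accumulators` stand in for
  // whatever the kernel defines.
  //
  //   Epilogue epilogue(shared_storage, thread_idx, warp_idx, lane_idx);
  //   if (output_op.is_source_needed()) {
  //     epilogue(output_op, destination_iterator, accumulators, source_iterator);
  //   } else {
  //     epilogue(output_op, destination_iterator, accumulators);
  //   }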

 private:
  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(
        AccumulatorFragmentIterator accum_fragment_iterator,
        WarpTileIterator& warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(
            kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }

    CUTLASS_DEVICE
    static void push(
        size_t pos,
        AccumulatorFragmentIterator const& iterator_begin,
        WarpTileIterator& warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(
               iterator_begin, warp_tile_iterator),
           0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };

  static_assert(
      kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
      "One of these must be exactly 1.");

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators ///< Complete warp-level accumulator tile
  ) {
    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

#pragma unroll( \
    IterationsUnroll \
        ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
        : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations;
         iter += Base::kFragmentsPerIteration) {
      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_not_needed<cutlass::make_index_sequence<
          OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
          push(iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename SharedLoadIterator::Fragment
            aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        } else if (kPartitionsK > 1) {
          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(
                aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          shared_load_iterator_.add_pointer_offset(
              (1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Compute the output result
        //

        typename OutputTileIterator::Fragment output_fragment;

        apply_output_operator_source_not_needed_(
            destination_iterator.thread_start_row(),
            output_fragment,
            output_op,
            aligned_accum_fragment[0]);

        //
        // Store the final result
        //

        destination_iterator.store(output_fragment);
        ++destination_iterator;
      }

      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(
            kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }
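  // The source-needed path below double-buffers the source tile: while the
  // accumulators for iteration `iter` are staged through shared memory and
  // combined with source_fragment[iter % 2], the source fragment for iteration
  // `iter + 1` is already being loaded from global memory into
  // source_fragment[(iter + 1) % 2] (see compute_source_needed_). This is the
  // pipelining described at the top of this file.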
  template <class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    template <int Advance>
    CUTLASS_DEVICE static void helper(
        AccumulatorFragmentIterator accum_fragment_iterator,
        WarpTileIterator& warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    CUTLASS_DEVICE
    static void push(
        size_t pos,
        AccumulatorFragmentIterator const& iterator_begin,
        WarpTileIterator& warp_tile_iterator) {
      int dummy[] = {
          (pos == Seq) &&
          (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
    }
  };

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators, ///< Complete warp-level accumulator tile
      OutputTileSourceIterator
          source_iterator ///< Tile iterator for the source tensor
  ) {
    typename OutputTileSourceIterator::Fragment source_fragment[2];

    source_fragment[0].clear();
    source_iterator.load(source_fragment[0]);
    ++source_iterator;
    source_fragment[1].clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
      if (iter > 0) {
        __syncthreads();
      }
      //
      // Load the source for the next iteration (pipelining)
      //

      if (iter + 1 < OutputTileIterator::kIterations) {
        source_iterator.load(source_fragment[(iter + 1) % 2]);
      }
      ++source_iterator;
      acc2smem_source_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
          push(iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment
          aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1, perform a reduction amongst the
      // k-slices
      if (kPartitionsK > 1) {
        plus<typename SharedLoadIterator::Fragment> add_fragments;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(
              aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_pointer_offset(
            (1 - kPartitionsK) * kSmemPointerOffset);
      }

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment;

      apply_output_operator_(
          destination_iterator.thread_start_row(),
          output_fragment,
          output_op,
          aligned_accum_fragment[0],
          source_fragment[iter % 2]);

      //
      // Store the final result
      //

      destination_iterator.store(output_fragment);
      ++destination_iterator;
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_(
      int begin_row,
      typename OutputTileIterator::Fragment& output_fragment,
      OutputOp const& output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
      typename OutputTileSourceIterator::Fragment const& source_fragment) {
    OutputAccessType* output_frag_ptr =
        reinterpret_cast<OutputAccessType*>(&output_fragment);

    AccumulatorAccessType const* compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);

    SourceAccessType const* source_frag_ptr =
        reinterpret_cast<SourceAccessType const*>(&source_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
        OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
          output_op,
          begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
          compute_frag_ptr[i],
          source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
      int begin_row,
      typename OutputTileIterator::Fragment& output_fragment,
      OutputOp const& output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const& aligned_accum_fragment) {
    OutputAccessType* output_frag_ptr =
        reinterpret_cast<OutputAccessType*>(&output_fragment);

    AccumulatorAccessType const* compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
        OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
          output_op,
          begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
          compute_frag_ptr[i]);
    }
  }

  // This should be constexpr, but it's only supported on C++14
  constexpr int CUTLASS_HOST_DEVICE getRowOffset(int i) {
    using ThreadMap = typename OutputTileIterator::ThreadMap;

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));
          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            int frag_idx = ThreadMap::kElementsPerAccess *
                (frag_row_idx * ThreadMap::Iterations::kColumn + column);
            if (i < frag_idx + ThreadMap::kElementsPerAccess) {
              return row_offset;
            }
          }
        }
      }
    }
    return -1;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////