/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>

#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

#include <cutlass/gemm/gemm.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
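
// Usage sketch (illustrative only): the enclosing kernel is expected to
// construct the mainloop from shared storage plus thread/warp/lane indices
// and invoke it once per threadblock tile, as the stock
// cutlass::gemm::kernel::Gemm does. `Mma`, `iterator_A`, `iterator_B`,
// `gemm_k_iterations`, and the index variables are assumptions standing in
// for whatever the caller defines:
//
//   __shared__ typename Mma::SharedStorage shared_storage;
//   Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);
//   typename Mma::FragmentC accum;
//   accum.clear();
//   mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);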

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_ = NumericArrayConverter<
        typename SmemIteratorA_::Element,
        typename IteratorA_::Element,
        IteratorA_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
 public:
  ///< Base class
  using Base = CustomMmaBase<Shape_, Policy_, 2>;

  using Shape =
      Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA =
      IteratorA_; ///< Iterates over tiles of A operand in global memory
  using IteratorB =
      IteratorB_; ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_; ///< Data type of accumulator matrix
  using LayoutC = LayoutC_; ///< Layout of accumulator matrix
  using Policy = Policy_; ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  using TransformA = TransformA_;
  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // Statically assert that kStages for MmaPipelined is two (double-buffered
  // pipeline)
  static_assert(
      (Base::kStages == 2),
      "MmaPipelined requires kStages set to value 2");

  static bool const kSmemContainsEntireMat = false;
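
  // With kStages == 2, shared memory holds two tiles of each operand: the
  // warps consume one stage while the next global-memory tile is staged into
  // the other, and the two roles swap every threadblock k-iteration.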

 private:
  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

 protected:
  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaPipelined(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      int thread_idx, ///< ID within the threadblock
      int warp_idx, ///< ID of warp
      int lane_idx ///< ID of each thread within a warp
      )
      : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storageA.ref(), thread_idx),
        smem_iterator_B_(shared_storageB.ref(), thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
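
    // Worked example (shape assumed purely for illustration): with
    // Base::WarpCount = {kM = 2, kN = 2, kK = 1} and warp_idx = 3:
    //   warp_idx_mn = 3 % 4 = 3, warp_idx_k = 3 / 4 = 0,
    //   warp_idx_m  = 3 % 2 = 1, warp_idx_n  = 3 / 2 = 1,
    // i.e. this warp owns the lower-right warp tile of the threadblock tile.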

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  CUTLASS_DEVICE
  CustomMmaPipelined(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage& st,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : CustomMmaPipelined(
            st.operand_A,
            st.operand_B,
            thread_idx,
            warp_idx,
            lane_idx) {}

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    // NOT IMPLEMENTED FOR PIPELINED
  }

  CUTLASS_DEVICE
  void set_zero_outside_bounds(bool value) {
    // NOT NEEDED FOR PIPELINED
    // shared memory will always be zero-filled
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorage& shared_storage,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    prologue<kLoadA, kLoadB>(
        shared_storage.operand_A,
        shared_storage.operand_B,
        iterator_A,
        iterator_B,
        thread_idx,
        problem_size_k);
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    // NOT IMPLEMENTED FOR PIPELINED
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      int gemm_k_iterations, ///< number of iterations of the mainloop
      FragmentC& accum, ///< destination accumulator tile
      IteratorA iterator_A, ///< iterator over A operand in global memory
      IteratorB iterator_B, ///< iterator over B operand in global memory
      FragmentC const& src_accum, ///< source accumulator tile
      TransformA transform_A =
          TransformA(), ///< transformation applied to A fragment
      TransformB transform_B =
          TransformB()) { ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // The last kblock is loaded in the prologue
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;

    this->smem_iterator_A_.store(transform_A(tb_frag_A));
    this->smem_iterator_B_.store(transform_B(tb_frag_B));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    iterator_A.clear_mask(gemm_k_iterations <= 1);
    iterator_B.clear_mask(gemm_k_iterations <= 1);

    // Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency
    // requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.
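
        // Double-buffer handoff (summary of the block below): on the last
        // warp-level k-group of this tile, the fragments fetched from global
        // memory are written to the other shared-memory stage; once both
        // stages of the circular buffer have been filled, the write (or warp
        // read) iterators are rewound so the buffer wraps around.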

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {
          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A(tb_frag_A));

          this->smem_iterator_B_.store(transform_B(tb_frag_B));

          __syncthreads();

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          } else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k == 0) {
          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_A.clear_mask(gemm_k_iterations <= 2);
          iterator_B.clear_mask(gemm_k_iterations <= 2);
        }

        warp_mma(
            accum,
            warp_frag_A[warp_mma_k % 2],
            warp_frag_B[warp_mma_k % 2],
            accum);
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////