/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>

#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

#include <cutlass/gemm/gemm.h>

#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_base.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h>
#include <ATen/native/cuda/cutlass_extensions/interleaved_numeric_conversion.h>

#include <ATen/native/cuda/cutlass_extensions/ft_gemm_configs.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product for a threadblock-scoped, double-buffered GEMM in which
/// the quantized B operand is dequantized in registers before the warp-level MMA.
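/// Fragments of B are converted once immediately after the global-memory load (TransformBAfterLDG,
/// before the store to shared memory) and once more after the shared-memory load (TransformBAfterLDS),
/// then scaled by per-column scales via warp::MmaTensorOpDequantizer ahead of each warp-level MMA.
/// Shared memory is double-buffered, which is why the base DqMmaBase is instantiated with kStages == 2.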
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterates over tiles of the scale operand in global memory
    typename IteratorScale_,
    /// Iterates over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;

    using Shape     = Shape_;      ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
    using ElementC  = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC   = LayoutC_;    ///< Layout of accumulator matrix
    using Policy    = Policy_;     ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale  = typename IteratorScale::Element;
    using LayoutScale   = typename IteratorScale::Layout;

    using SmemIteratorA     = SmemIteratorA_;
    using SmemIteratorB     = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    /// Warp-level dequantizer applying the per-column scales to fragments of B
    using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
                                                     typename Base::WarpGemm,
                                                     Operand::kB,
                                                     typename SmemIteratorScale::Fragment::Element,
                                                     LayoutScale,
                                                     32>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementB          = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave =
        layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
                  "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage&
                       shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
                   int thread_idx,      ///< ID within the threadblock
                   int warp_idx,        ///< ID of warp
                   int lane_idx         ///< ID of each thread within a warp
                   ):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
        warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
                          (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
                          lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {

        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k  = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(int gemm_k_iterations,         ///< number of iterations of the mainloop
                    FragmentC& accum,              ///< destination accumulator tile
                    IteratorA iterator_A,          ///< iterator over A operand in global memory
                    IteratorB iterator_B,          ///< iterator over B operand in global memory
                    IteratorScale iterator_scale,  ///< iterator over scale operand in global memory
                    FragmentC const& src_accum)    ///< source accumulator tile
    {

        //
        // Prologue
        //
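        // The prologue stages the first k-tile: fragments of A, B, and the scales are loaded from
        // global memory, converted to the element types expected in shared memory, and stored to
        // shared memory; the first warp-level fragments are then loaded so that math can overlap
        // the next global-memory loads in the mainloop.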
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA =
            NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
                                                     typename FragmentScale::Element,
                                                     FragmentScale::kElements>;

        // These transforms mainly handle the case where we have bfloat16 activations and weights in
        // GMEM and want to issue HMMA on architectures older than Ampere. We will convert to FP16
        // before STS.
        TransformA     transformA;
        TransformScale transformScale;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA     tb_frag_A;
        FragmentB     tb_frag_B;
        FragmentScale tb_frag_scales;

        using WarpFragmentScale = typename Dequantizer::FragmentScale;
        WarpFragmentScale warp_frag_scales;

        tb_frag_A.clear();
        tb_frag_B.clear();
        tb_frag_scales.clear();

        // The last kblock is loaded in the prologue
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);
        iterator_scale.load(tb_frag_scales);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
        this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        warp_dequantizer_.load(warp_frag_scales);

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations) {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

                // Load warp-level tiles from shared memory, wrapping to the k offset if this is
                // the last group, as the case may be.
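
                // warp_mma_k indexes the warp-level k-group within the current threadblock tile; on
                // the last k-group, the next tile (already staged in registers at warp_mma_k == 0)
                // is committed to the other shared-memory buffer of the two-stage pipeline.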

                if (warp_mma_k == Base::kWarpGemmIterations - 1) {

                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));

                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1) {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    }
                    else {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                const int warp_tileB_k_load_offset    = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
                    this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
                                                                 % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0) {

                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B =
                    lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////