/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a pipelined GEMM kernel with a floating-point A operand and a quantized
           (integer) B operand dequantized via per-column scales. Batched GEMM is not supported;
           serial split-K reduction is available when SplitKSerial is enabled.
*/

#pragma once

#include <cutlass/cutlass.h>

#include <cutlass/arch/arch.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/semaphore.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
         typename Epilogue_,            ///! Epilogue
         typename ThreadblockSwizzle_,  ///! Threadblock swizzling function
         typename KernelArch,  ///! The architecture this kernel is compiled for. Used since SIMT kernels lose the
                               ///  top-level arch.
         bool SplitKSerial     ///! If true, code supporting split-K via serial reduction is enabled.
         >
struct GemmFpAIntB {

    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Argument structure
    struct Arguments {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Controls serial split-K
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size,
                  typename Mma::IteratorA::TensorRef ref_A,
                  typename Mma::IteratorB::TensorRef ref_B,
                  typename Mma::IteratorScale::TensorRef ref_scale,
                  typename Epilogue::OutputTileIterator::TensorRef ref_C,
                  typename Epilogue::OutputTileIterator::TensorRef ref_D,
                  int serial_split_k_factor,
                  typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
                  int const* gather_A_indices = nullptr,
                  int const* gather_B_indices = nullptr,
                  int const* scatter_D_indices = nullptr):
            problem_size(problem_size),
            ref_A(ref_A),
            ref_B(ref_B),
            ref_scale(ref_scale),
            ref_C(ref_C),
            ref_D(ref_D),
            batch_count(serial_split_k_factor),
            output_op(output_op),
            gather_A_indices(gather_A_indices),
            gather_B_indices(gather_B_indices),
            scatter_D_indices(scatter_D_indices)
        {
        }
    };
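
    // Illustrative host-side sketch (not part of this kernel): constructing Arguments for a
    // concrete instantiation `GemmKernel = GemmFpAIntB<...>`. The problem size, tensor refs,
    // and split-K factor below are placeholders supplied by the caller.
    //
    //   typename GemmKernel::Arguments args(
    //       {m, n, k},                      // GEMM problem size
    //       ref_A,                          // activation operand (ElementA)
    //       ref_B,                          // quantized weight operand (ElementB)
    //       ref_scale,                      // dequantization scales, one per output column (N)
    //       ref_C,                          // source tensor read by the epilogue
    //       ref_D,                          // destination tensor
    //       /*serial_split_k_factor=*/1);   // 1 disables serial split-K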
    /// Parameters structure
    struct Params {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;
        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::Params params_B;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::Params params_scale;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Epilogue::OutputTileIterator::Params params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::Params params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        typename EpilogueOutputOp::Params output_op;
        int* semaphore;
        int gemm_k_size;
        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        Params(): swizzle_log_tile(0), semaphore(nullptr), gemm_k_size(0) {}

        Params(Arguments const& args,
               int device_sms,
               int sm_occupancy):
            problem_size(args.problem_size),
            params_A(args.ref_A.layout()),
            ref_A(args.ref_A),
            params_B(args.ref_B.layout()),
            ref_B(args.ref_B),
            params_scale(args.ref_scale.layout()),
            ref_scale(args.ref_scale),
            params_C(args.ref_C.layout()),
            ref_C(args.ref_C),
            params_D(args.ref_D.layout()),
            ref_D(args.ref_D),
            output_op(args.output_op),
            semaphore(nullptr),
            gather_A_indices(args.gather_A_indices),
            gather_B_indices(args.gather_B_indices),
            scatter_D_indices(args.scatter_D_indices)
        {
            ThreadblockSwizzle swizzle;
            grid_tiled_shape = swizzle.get_tiled_shape(
                args.problem_size,
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
                args.batch_count);

            // swizzle_log_tile depends on grid_tiled_shape, so it is computed here rather than in
            // the member initializer list, which would read grid_tiled_shape before it is assigned.
            swizzle_log_tile = swizzle.get_log_tile(grid_tiled_shape);

            gemm_k_size = args.problem_size.k();
        }

        size_t get_workspace_size() const
        {
            return 0;
        }

        Status init_workspace(void* workspace, cudaStream_t stream = nullptr)
        {
            return Status::kSuccess;
        }

        dim3 get_grid_dims() const
        {
            return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape);
        }
    };
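
    // Illustrative host-side sketch (not part of this kernel): turning Arguments into Params and
    // sizing the launch. `args`, `device_sms`, `sm_occupancy`, and `GemmKernel` are caller-provided
    // placeholders; `device_sms` and `sm_occupancy` are accepted for interface compatibility and
    // are not otherwise used by this Params constructor.
    //
    //   typename GemmKernel::Params params(args, device_sms, sm_occupancy);
    //   dim3 grid = params.get_grid_dims();
    //   dim3 block(GemmKernel::kThreadCount, 1, 1);
    //   int  smem_size = int(sizeof(typename GemmKernel::SharedStorage));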
    /// Shared memory storage structure
    union SharedStorage {
        typename Mma::SharedStorage main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntB() {}

    /// Determines whether kernel satisfies alignment
    CUTLASS_HOST_DEVICE
    static Status can_implement(Arguments const& args)
    {

        static int const kAlignmentA =
            (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ?
                32 :
            (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value) ?
                64 :
                Mma::IteratorA::AccessType::kElements;

        static int const kAlignmentB =
            (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ?
                32 :
            (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value) ?
                64 :
                Mma::IteratorB::AccessType::kElements;

        static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;

        static int const kAlignmentC =
            (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                               layout::ColumnMajorInterleaved<32>>::value) ?
                32 :
            (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                               layout::ColumnMajorInterleaved<64>>::value) ?
                64 :
                Epilogue::OutputTileIterator::kElementsPerAccess;

        if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
            return Status::kErrorMisalignedOperand;
        }

        return Status::kSuccess;
    }

    // The dummy template parameter is not used; it exists so that this code compiles with
    // standards earlier than C++17. Prior to C++17, fully specialized templates had to exist at
    // namespace scope, so the dummy parameter keeps the specialization below partial.
    template<bool B, typename dummy = void>
    struct KernelRunner {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    };

    template<typename dummy>
    struct KernelRunner<true, dummy> {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            using LayoutB = typename Mma::IteratorB::Layout;
            static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                              || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
                          "B must be row major/col major OR col major interleaved.");

            // Compute threadblock location
            ThreadblockSwizzle threadblock_swizzle;

            cutlass::gemm::GemmCoord threadblock_tile_offset =
                threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // Early exit if CTA is out of range
            if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
                || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

                return;
            }

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_tile_offset.m() * Mma::Shape::kM,
                threadblock_tile_offset.k() * params.gemm_k_size,
            };

            cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
                                             threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

            cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};

            // Problem size is a function of threadblock index in the K dimension
            int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Compute position within threadblock
            int thread_idx = threadIdx.x;
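
            // Note on the B offsets above (added for clarity): for an interleaved column-major B,
            // kInterleave = IteratorB::Shape::kRow / Shape::kK, so the physical extent of B is
            // {K * kInterleave, N / kInterleave}; the K tile offset is scaled up by kInterleave and
            // the N tile offset scaled down by it. For row-major B, kInterleave == 1 and the
            // offsets reduce to the usual {k, n} tile coordinates.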
            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(params.params_A,
                                               params.ref_A.data(),
                                               {params.problem_size.m(), problem_size_k},
                                               thread_idx,
                                               tb_offset_A,
                                               params.gather_A_indices);

            typename Mma::IteratorB iterator_B(params.params_B,
                                               params.ref_B.data(),
                                               {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
                                               thread_idx,
                                               tb_offset_B,
                                               params.gather_B_indices);

            typename Mma::IteratorScale iterator_scale(params.params_scale,
                                                       params.ref_scale.data(),
                                                       {1, params.problem_size.n()},
                                                       thread_idx,
                                                       tb_offset_scale);

            // Broadcast the warp_id computed by lane 0 to ensure dependent code
            // is compiled as warp-uniform.
            int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
            int lane_idx = threadIdx.x % 32;

            //
            // Main loop
            //

            // Construct thread-scoped matrix multiply
            Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            if (!kSplitKSerial || gemm_k_iterations > 0) {
                // Compute threadblock-scoped matrix multiply-add
                mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
            }

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            //
            // Masked tile iterators constructed from members
            //

            threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // assume identity swizzle
            MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
                                           threadblock_tile_offset.n() * Mma::Shape::kN);

            int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

            // Construct the semaphore.
            Semaphore semaphore(params.semaphore + block_idx, thread_idx);

            // If performing a reduction via split-K, fetch the initial synchronization
            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                // Fetch the synchronization lock initially but do not block.
                semaphore.fetch();

                // Indicate which position in a serial reduction the output operator is currently updating
                output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
            }

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(params.params_C,
                                                             params.ref_C.data(),
                                                             params.problem_size.mn(),
                                                             thread_idx,
                                                             threadblock_offset,
                                                             params.scatter_D_indices);

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(params.params_D,
                                                             params.ref_D.data(),
                                                             params.problem_size.mn(),
                                                             thread_idx,
                                                             threadblock_offset,
                                                             params.scatter_D_indices);

            Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

            // Wait on the semaphore - this latency may have been covered by iterator construction
            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
                if (threadblock_tile_offset.k()) {
                    iterator_C = iterator_D;
                }

                semaphore.wait(threadblock_tile_offset.k());
            }

            // Execute the epilogue operator to update the destination tensor.
            epilogue(output_op, iterator_D, accumulators, iterator_C);
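
            // Serial split-K summary (added for clarity): each threadblock along K waits until the
            // semaphore equals its K index, reads the partial result accumulated so far from 'D',
            // adds its own contribution, then releases the semaphore to k+1; the final threadblock
            // resets it to zero for subsequent launches.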
            //
            // Release the semaphore
            //

            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                int lock = 0;
                if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

                    // The final threadblock resets the semaphore for subsequent grids.
                    lock = 0;
                }
                else {
                    // Otherwise, the semaphore is incremented
                    lock = threadblock_tile_offset.k() + 1;
                }

                semaphore.release(lock);
            }
        }
    };

    CUTLASS_DEVICE
    static void invoke(Params const& params, SharedStorage& shared_storage)
    {
        GemmFpAIntB op;
        op(params, shared_storage);
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
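
// Illustrative launch sketch (host-side, not part of this header). Assumes a concrete `GemmKernel`
// instantiation of GemmFpAIntB, a populated `params`, and a `stream`; cutlass::Kernel<> is the
// generic global-function wrapper from <cutlass/device_kernel.h>, and any equivalent launcher that
// passes Params plus SharedStorage-sized dynamic shared memory works the same way. For tiles whose
// SharedStorage exceeds the default 48 KB limit, the launcher must also raise the kernel's dynamic
// shared memory limit via cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize).
//
//   int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//   cutlass::Kernel<GemmKernel>
//       <<<params.get_grid_dims(), GemmKernel::kThreadCount, smem_size, stream>>>(params);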