#pragma once

#include <cutlass/gemm/threadblock/default_mma.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h>

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB>
struct DefaultMma<bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  bfloat16_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator,
                  false,
                  SharedMemoryClear,
                  GatherA,
                  GatherB> {

private:
    // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
    static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
    using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
    using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;

public:
    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                                        WarpShape,
                                                                        InstructionShape,
                                                                        MmaElementA,
                                                                        LayoutA,
                                                                        MmaElementB,
                                                                        LayoutB,
                                                                        ElementAccumulator,
                                                                        layout::RowMajor,
                                                                        arch::OpClassTensorOp,
                                                                        2,
                                                                        Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
        bfloat16_t,
        LayoutA,
        1,
        typename MmaCore::IteratorThreadMapA,
        kAlignmentA,
        GatherA>;

    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
        bfloat16_t,
        LayoutB,
        0,
        typename MmaCore::IteratorThreadMapB,
        kAlignmentB,
        GatherB>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
                                                                    IteratorA,
                                                                    typename MmaCore::SmemIteratorA,
                                                                    IteratorB,
                                                                    typename MmaCore::SmemIteratorB,
                                                                    ElementAccumulator,
                                                                    layout::RowMajor,
                                                                    typename MmaCore::MmaPolicy>;
};

// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB>
struct DefaultMma<bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  bfloat16_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  arch::Sm80,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator,
                  false,
                  SharedMemoryClear,
                  GatherA,
                  GatherB> {

    // Define the MmaCore components
    // 3 is used on purpose here to trigger components for mma multistage
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                                        WarpShape,
                                                                        InstructionShape,
                                                                        bfloat16_t,
                                                                        LayoutA,
                                                                        bfloat16_t,
                                                                        LayoutB,
                                                                        ElementAccumulator,
                                                                        layout::RowMajor,
                                                                        arch::OpClassTensorOp,
                                                                        3,
                                                                        Operator>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
        bfloat16_t,
        LayoutA,
        1,
        ThreadMapA,
        AccessTypeA,
        GatherA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
        bfloat16_t,
        LayoutB,
        0,
        ThreadMapB,
        AccessTypeB,
        GatherB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
                                                                     IteratorA,
                                                                     typename MmaCore::SmemIteratorA,
                                                                     MmaCore::kCacheOpA,
                                                                     IteratorB,
                                                                     typename MmaCore::SmemIteratorB,
                                                                     MmaCore::kCacheOpB,
                                                                     ElementAccumulator,
                                                                     layout::RowMajor,
                                                                     typename MmaCore::MmaPolicy,
                                                                     2>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint8_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator> {

private:
    // Scales are loaded 128 bits at a time, so alignment is expressed in bf16 elements.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    // Dequantizing MMA: int8 weights are upconverted to bf16 before the tensor-core MMA.
    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint8_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      2,
                      Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint4b_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator> {

private:
    // Scales are loaded 128 bits at a time, so alignment is expressed in bf16 elements.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    // Dequantizing MMA: int4 weights are upconverted to bf16 before the tensor-core MMA.
    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint4b_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      2,
                      Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight,
/// arbitrary stage count (multistage mainloop)
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint8_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  kStages,
                  Operator,
                  false,
                  SharedMemoryClear> {

private:
    // Scales are loaded 128 bits at a time, so alignment is expressed in bf16 elements.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    // Dequantizing MMA: int8 weights are upconverted to bf16 before the tensor-core MMA.
    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint8_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      kStages,
                      Operator,
                      SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight,
/// arbitrary stage count (multistage mainloop)
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint4b_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  kStages,
                  Operator,
                  false,
                  SharedMemoryClear> {

private:
    // Scales are loaded 128 bits at a time, so alignment is expressed in bf16 elements.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    // Dequantizing MMA: int4 weights are upconverted to bf16 before the tensor-core MMA.
    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint4b_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      kStages,
                      Operator,
                      SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass