#pragma once

#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_mma_bf16.h>

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight,
/// 2-stage pipeline. Forwards all type/iterator selection to the dequantizing DqMma builder.
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::half_t,
                  LayoutA,
                  kAlignmentA,
                  uint8_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator> {

private:
    // Widest (128-bit) access granularity, expressed in fp16 elements, used for the dequant scales.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    // Dequantizing MMA builder: fp16 A, uint8 B, fp16 scales, row-major output, 2 stages.
    using Mma = DqMma<half_t,
                      LayoutA,
                      kAlignmentA,
                      uint8_t,
                      LayoutB,
                      kAlignmentB,
                      half_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      2,
                      Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight,
/// 2-stage pipeline. Identical to the int8 specialization above except B is packed uint4b_t.
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::half_t,
                  LayoutA,
                  kAlignmentA,
                  uint4b_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator> {

private:
    // Widest (128-bit) access granularity, expressed in fp16 elements, used for the dequant scales.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    // Dequantizing MMA builder: fp16 A, uint4 B, fp16 scales, row-major output, 2 stages.
    using Mma = DqMma<half_t,
                      LayoutA,
                      kAlignmentA,
                      uint4b_t,
                      LayoutB,
                      kAlignmentB,
                      half_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      2,
                      Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight,
/// arbitrary stage count (multistage mainloop).
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the pipelined mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t,
                  LayoutA,
                  kAlignmentA,
                  uint8_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  kStages,
                  Operator,
                  false,
                  SharedMemoryClear> {

private:
    // Widest (128-bit) access granularity, expressed in fp16 elements, used for the dequant scales.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    // Dequantizing MMA builder: fp16 A, uint8 B, fp16 scales, kStages-deep mainloop.
    using Mma = DqMma<half_t,
                      LayoutA,
                      kAlignmentA,
                      uint8_t,
                      LayoutB,
                      kAlignmentB,
                      half_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      kStages,
                      Operator,
                      SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight,
/// arbitrary stage count (multistage mainloop).
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the pipelined mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t,
                  LayoutA,
                  kAlignmentA,
                  uint4b_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  kStages,
                  Operator,
                  false,
                  SharedMemoryClear> {

private:
    // Widest (128-bit) access granularity, expressed in fp16 elements, used for the dequant scales.
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    // Dequantizing MMA builder: fp16 A, uint4 B, fp16 scales, kStages-deep mainloop.
    using Mma = DqMma<half_t,
                      LayoutA,
                      kAlignmentA,
                      uint4b_t,
                      LayoutB,
                      kAlignmentB,
                      half_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      kStages,
                      Operator,
                      SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB>
struct DefaultMma<half_t,
                  LayoutA,
                  kAlignmentA,
                  half_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  arch::Sm80,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator,
                  false,
                  SharedMemoryClear,
                  GatherA,
                  GatherB> {

    // Define the MmaCore components
    // 3 is used on purpose here to trigger components for mma multistage
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                                        WarpShape,
                                                                        InstructionShape,
                                                                        half_t,
                                                                        LayoutA,
                                                                        half_t,
                                                                        LayoutB,
                                                                        ElementAccumulator,
                                                                        layout::RowMajor,
                                                                        arch::OpClassTensorOp,
                                                                        3,
                                                                        Operator>;

    // Define iterators over tiles from the A operand
    // (AdvanceRank = 1: the A iterator advances along the K dimension of an M-by-K tile)
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
        half_t,
        LayoutA,
        1,
        ThreadMapA,
        AccessTypeA,
        GatherA>;

    // Define iterators over tiles from the B operand
    // (AdvanceRank = 0: the B iterator advances along the K dimension of a K-by-N tile)
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
        half_t,
        LayoutB,
        0,
        ThreadMapB,
        AccessTypeB,
        GatherB>;

    // Define the threadblock-scoped multistage matrix multiply.
    // Note: MmaCore was built for 3 stages (above) but the mainloop itself runs with 2.
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
                                                                     IteratorA,
                                                                     typename MmaCore::SmemIteratorA,
                                                                     MmaCore::kCacheOpA,
                                                                     IteratorB,
                                                                     typename MmaCore::SmemIteratorB,
                                                                     MmaCore::kCacheOpB,
                                                                     ElementAccumulator,
                                                                     layout::RowMajor,
                                                                     typename MmaCore::MmaPolicy,
                                                                     2>;
};

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass