#pragma once

#include <cutlass/gemm/threadblock/default_mma.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h>

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB>
struct DefaultMma<bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  bfloat16_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator,
                  false,
                  SharedMemoryClear,
                  GatherA,
                  GatherB> {

private:
    // Conversions are only needed pre-Ampere. The pre-Ampere path uses the pipelined mma, so we convert before STS.
    static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
    using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
    using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;

public:
    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                                        WarpShape,
                                                                        InstructionShape,
                                                                        MmaElementA,
                                                                        LayoutA,
                                                                        MmaElementB,
                                                                        LayoutB,
                                                                        ElementAccumulator,
                                                                        layout::RowMajor,
                                                                        arch::OpClassTensorOp,
                                                                        2,
                                                                        Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
        bfloat16_t,
        LayoutA,
        1,
        typename MmaCore::IteratorThreadMapA,
        kAlignmentA,
        GatherA>;

    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
        bfloat16_t,
        LayoutB,
        0,
        typename MmaCore::IteratorThreadMapB,
        kAlignmentB,
        GatherB>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
                                                                    IteratorA,
                                                                    typename MmaCore::SmemIteratorA,
                                                                    IteratorB,
                                                                    typename MmaCore::SmemIteratorB,
                                                                    ElementAccumulator,
                                                                    layout::RowMajor,
                                                                    typename MmaCore::MmaPolicy>;
};
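
// Example (illustrative sketch, not part of the original header): how this bf16 x bf16
// specialization might be selected. The tile shapes, alignments, and arch tag below are
// hypothetical choices for illustration only.
//
//   using ThreadblockMma = cutlass::gemm::threadblock::DefaultMma<
//       cutlass::bfloat16_t, cutlass::layout::RowMajor, 8,       // A: element, layout, alignment
//       cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,    // B: element, layout, alignment
//       float, cutlass::layout::RowMajor,                        // accumulator, C layout
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,     // tensor op class, arch tag
//       cutlass::gemm::GemmShape<128, 128, 32>,                  // threadblock tile
//       cutlass::gemm::GemmShape<64, 64, 32>,                    // warp tile
//       cutlass::gemm::GemmShape<16, 8, 8>,                      // instruction tile
//       2, cutlass::arch::OpMultiplyAdd>::ThreadblockMma;        // two stages, mma operator
//
// On pre-Ampere arch tags such as Sm75 above, operands are converted to half_t before the
// shared-memory store; with arch::Sm80 and 2 stages, the dedicated multistage specialization
// below is chosen instead.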

// bf16 x bf16 specialization on Ampere that uses mma multistage for the 2-stage case. This helps
// avoid register spills on large tiles when there is not enough shared memory for 3+ stages.
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operand A by using an index array
    bool GatherA,
    /// Gather operand B by using an index array
    bool GatherB>
struct DefaultMma<bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  bfloat16_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  arch::Sm80,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator,
                  false,
                  SharedMemoryClear,
                  GatherA,
                  GatherB> {

    // Define the MmaCore components
    // 3 is used on purpose here to trigger components for mma multistage
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                                        WarpShape,
                                                                        InstructionShape,
                                                                        bfloat16_t,
                                                                        LayoutA,
                                                                        bfloat16_t,
                                                                        LayoutB,
                                                                        ElementAccumulator,
                                                                        layout::RowMajor,
                                                                        arch::OpClassTensorOp,
                                                                        3,
                                                                        Operator>;

    // Define iterators over tiles from the A operand
    using ThreadMapA  = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
        bfloat16_t,
        LayoutA,
        1,
        ThreadMapA,
        AccessTypeA,
        GatherA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB  = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
        bfloat16_t,
        LayoutB,
        0,
        ThreadMapB,
        AccessTypeB,
        GatherB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
                                                                     IteratorA,
                                                                     typename MmaCore::SmemIteratorA,
                                                                     MmaCore::kCacheOpA,
                                                                     IteratorB,
                                                                     typename MmaCore::SmemIteratorB,
                                                                     MmaCore::kCacheOpB,
                                                                     ElementAccumulator,
                                                                     layout::RowMajor,
                                                                     typename MmaCore::MmaPolicy,
                                                                     2>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint8_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator> {

private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint8_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      2,
                      Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};
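
// Example (illustrative sketch): instantiating this mixed-input specialization for a
// weight-only-quantized GEMM with bf16 activations and int8 weights that are dequantized to
// bf16 inside the mainloop via DqMma. The layouts, alignments, tile shapes, and the operator
// tag below are assumptions for illustration; the real choices come from the calling kernel.
//
//   using ThreadblockMma = cutlass::gemm::threadblock::DefaultMma<
//       cutlass::bfloat16_t, cutlass::layout::RowMajor, 8,       // A: bf16 activations
//       uint8_t, cutlass::layout::ColumnMajor, 16,               // B: int8 weights
//       float, cutlass::layout::RowMajor,                        // accumulator, C layout
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//       cutlass::gemm::GemmShape<128, 128, 64>,
//       cutlass::gemm::GemmShape<64, 64, 64>,
//       cutlass::gemm::GemmShape<16, 8, 16>,
//       2, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::ThreadblockMma;
//       // ^ dequantizing operator tag assumed to come from this extension's arch headers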

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint4b_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  2,
                  Operator> {

private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint4b_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      2,
                      Operator>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

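////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight,
/// multi-stage (kStages) variant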
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the pipelined mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint8_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  kStages,
                  Operator,
                  false,
                  SharedMemoryClear> {

private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint8_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      kStages,
                      Operator,
                      SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight,
/// multi-stage (kStages) variant
template<
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the pipelined mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
                  LayoutA,
                  kAlignmentA,
                  uint4b_t,
                  LayoutB,
                  kAlignmentB,
                  ElementAccumulator,
                  layout::RowMajor,
                  arch::OpClassTensorOp,
                  ArchTag,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  kStages,
                  Operator,
                  false,
                  SharedMemoryClear> {

private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t,
                      LayoutA,
                      kAlignmentA,
                      uint4b_t,
                      LayoutB,
                      kAlignmentB,
                      bfloat16_t,
                      layout::RowMajor,
                      kAlignmentScale,
                      ElementAccumulator,
                      layout::RowMajor,
                      arch::OpClassTensorOp,
                      ArchTag,
                      ThreadblockShape,
                      WarpShape,
                      InstructionShape,
                      kStages,
                      Operator,
                      SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass