xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_mma.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h>
4 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h>
5 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_mma_bf16.h>
6 
7 namespace cutlass {
8 namespace gemm {
9 namespace threadblock {
10 
11 ////////////////////////////////////////////////////////////////////////////////
12 
13 /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
14 template<
15     /// Layout type for A matrix operand
16     typename LayoutA,
17     /// Access granularity of A matrix in units of elements
18     int kAlignmentA,
19     /// Layout type for B matrix operand
20     typename LayoutB,
21     /// Access granularity of B matrix in units of elements
22     int kAlignmentB,
23     /// Element type for internal accumulation
24     typename ElementAccumulator,
25     /// Tag indicating architecture to tune for
26     typename ArchTag,
27     /// Threadblock-level tile size (concept: GemmShape)
28     typename ThreadblockShape,
29     /// Warp-level tile size (concept: GemmShape)
30     typename WarpShape,
31     /// Instruction-level tile size (concept: GemmShape)
32     typename InstructionShape,
33     /// Operation performed by GEMM
34     typename Operator>
35 struct DefaultMma<cutlass::half_t,
36                   LayoutA,
37                   kAlignmentA,
38                   uint8_t,
39                   LayoutB,
40                   kAlignmentB,
41                   ElementAccumulator,
42                   layout::RowMajor,
43                   arch::OpClassTensorOp,
44                   ArchTag,
45                   ThreadblockShape,
46                   WarpShape,
47                   InstructionShape,
48                   2,
49                   Operator> {
50 
51 private:
52     static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
53 
54     using Mma = DqMma<half_t,
55                       LayoutA,
56                       kAlignmentA,
57                       uint8_t,
58                       LayoutB,
59                       kAlignmentB,
60                       half_t,
61                       layout::RowMajor,
62                       kAlignmentScale,
63                       ElementAccumulator,
64                       layout::RowMajor,
65                       arch::OpClassTensorOp,
66                       ArchTag,
67                       ThreadblockShape,
68                       WarpShape,
69                       InstructionShape,
70                       2,
71                       Operator>;
72 
73 public:
74     // Define the MmaCore components
75     using MmaCore = typename Mma::MmaCore;
76 
77     // Define iterators over tiles from the A operand
78     using IteratorA = typename Mma::IteratorA;
79 
80     // Define iterators over tiles from the B operand
81     using IteratorB = typename Mma::IteratorB;
82 
83     // Define the threadblock-scoped pipelined matrix multiply
84     using ThreadblockMma = typename Mma::ThreadblockMma;
85 };
86 
87 ////////////////////////////////////////////////////////////////////////////////
88 /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
89 template<
90     /// Layout type for A matrix operand
91     typename LayoutA,
92     /// Access granularity of A matrix in units of elements
93     int kAlignmentA,
94     /// Layout type for B matrix operand
95     typename LayoutB,
96     /// Access granularity of B matrix in units of elements
97     int kAlignmentB,
98     /// Element type for internal accumulation
99     typename ElementAccumulator,
100     /// Tag indicating architecture to tune for
101     typename ArchTag,
102     /// Threadblock-level tile size (concept: GemmShape)
103     typename ThreadblockShape,
104     /// Warp-level tile size (concept: GemmShape)
105     typename WarpShape,
106     /// Instruction-level tile size (concept: GemmShape)
107     typename InstructionShape,
108     /// Operation performed by GEMM
109     typename Operator>
110 struct DefaultMma<cutlass::half_t,
111                   LayoutA,
112                   kAlignmentA,
113                   uint4b_t,
114                   LayoutB,
115                   kAlignmentB,
116                   ElementAccumulator,
117                   layout::RowMajor,
118                   arch::OpClassTensorOp,
119                   ArchTag,
120                   ThreadblockShape,
121                   WarpShape,
122                   InstructionShape,
123                   2,
124                   Operator> {
125 
126 private:
127     static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
128 
129     using Mma = DqMma<half_t,
130                       LayoutA,
131                       kAlignmentA,
132                       uint4b_t,
133                       LayoutB,
134                       kAlignmentB,
135                       half_t,
136                       layout::RowMajor,
137                       kAlignmentScale,
138                       ElementAccumulator,
139                       layout::RowMajor,
140                       arch::OpClassTensorOp,
141                       ArchTag,
142                       ThreadblockShape,
143                       WarpShape,
144                       InstructionShape,
145                       2,
146                       Operator>;
147 
148 public:
149     // Define the MmaCore components
150     using MmaCore = typename Mma::MmaCore;
151 
152     // Define iterators over tiles from the A operand
153     using IteratorA = typename Mma::IteratorA;
154 
155     // Define iterators over tiles from the B operand
156     using IteratorB = typename Mma::IteratorB;
157 
158     // Define the threadblock-scoped pipelined matrix multiply
159     using ThreadblockMma = typename Mma::ThreadblockMma;
160 };
161 
162 template<
163     /// Layout type for A matrix operand
164     typename LayoutA,
165     /// Access granularity of A matrix in units of elements
166     int kAlignmentA,
167     /// Layout type for B matrix operand
168     typename LayoutB,
169     /// Access granularity of B matrix in units of elements
170     int kAlignmentB,
171     /// Element type for internal accumulation
172     typename ElementAccumulator,
173     /// Tag indicating architecture to tune for
174     typename ArchTag,
175     /// Threadblock-level tile size (concept: GemmShape)
176     typename ThreadblockShape,
177     /// Warp-level tile size (concept: GemmShape)
178     typename WarpShape,
179     /// Instruction-level tile size (concept: GemmShape)
180     typename InstructionShape,
181     /// Operation performed by GEMM
182     typename Operator,
183     ///
184     int kStages,
185     /// Shared memory clear option
186     SharedMemoryClearOption SharedMemoryClear>
187 struct DefaultMma<cutlass::half_t,
188                   LayoutA,
189                   kAlignmentA,
190                   uint8_t,
191                   LayoutB,
192                   kAlignmentB,
193                   ElementAccumulator,
194                   layout::RowMajor,
195                   arch::OpClassTensorOp,
196                   ArchTag,
197                   ThreadblockShape,
198                   WarpShape,
199                   InstructionShape,
200                   kStages,
201                   Operator,
202                   false,
203                   SharedMemoryClear> {
204 
205 private:
206     static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
207 
208     using Mma = DqMma<half_t,
209                       LayoutA,
210                       kAlignmentA,
211                       uint8_t,
212                       LayoutB,
213                       kAlignmentB,
214                       half_t,
215                       layout::RowMajor,
216                       kAlignmentScale,
217                       ElementAccumulator,
218                       layout::RowMajor,
219                       arch::OpClassTensorOp,
220                       ArchTag,
221                       ThreadblockShape,
222                       WarpShape,
223                       InstructionShape,
224                       kStages,
225                       Operator,
226                       SharedMemoryClear>;
227 
228 public:
229     // Define the MmaCore components
230     using MmaCore = typename Mma::MmaCore;
231 
232     // Define iterators over tiles from the A operand
233     using IteratorA = typename Mma::IteratorA;
234 
235     // Define iterators over tiles from the B operand
236     using IteratorB = typename Mma::IteratorB;
237 
238     // Define the threadblock-scoped pipelined matrix multiply
239     using ThreadblockMma = typename Mma::ThreadblockMma;
240 };
241 
242 ////////////////////////////////////////////////////////////////////////////////
243 /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
244 template<
245     /// Layout type for A matrix operand
246     typename LayoutA,
247     /// Access granularity of A matrix in units of elements
248     int kAlignmentA,
249     /// Layout type for B matrix operand
250     typename LayoutB,
251     /// Access granularity of B matrix in units of elements
252     int kAlignmentB,
253     /// Element type for internal accumulation
254     typename ElementAccumulator,
255     /// Tag indicating architecture to tune for
256     typename ArchTag,
257     /// Threadblock-level tile size (concept: GemmShape)
258     typename ThreadblockShape,
259     /// Warp-level tile size (concept: GemmShape)
260     typename WarpShape,
261     /// Instruction-level tile size (concept: GemmShape)
262     typename InstructionShape,
263     /// Operation performed by GEMM
264     typename Operator,
265     ///
266     int kStages,
267     /// Shared memory clear option
268     SharedMemoryClearOption SharedMemoryClear>
269 struct DefaultMma<cutlass::half_t,
270                   LayoutA,
271                   kAlignmentA,
272                   uint4b_t,
273                   LayoutB,
274                   kAlignmentB,
275                   ElementAccumulator,
276                   layout::RowMajor,
277                   arch::OpClassTensorOp,
278                   ArchTag,
279                   ThreadblockShape,
280                   WarpShape,
281                   InstructionShape,
282                   kStages,
283                   Operator,
284                   false,
285                   SharedMemoryClear> {
286 
287 private:
288     static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
289 
290     using Mma = DqMma<half_t,
291                       LayoutA,
292                       kAlignmentA,
293                       uint4b_t,
294                       LayoutB,
295                       kAlignmentB,
296                       half_t,
297                       layout::RowMajor,
298                       kAlignmentScale,
299                       ElementAccumulator,
300                       layout::RowMajor,
301                       arch::OpClassTensorOp,
302                       ArchTag,
303                       ThreadblockShape,
304                       WarpShape,
305                       InstructionShape,
306                       kStages,
307                       Operator,
308                       SharedMemoryClear>;
309 
310 public:
311     // Define the MmaCore components
312     using MmaCore = typename Mma::MmaCore;
313 
314     // Define iterators over tiles from the A operand
315     using IteratorA = typename Mma::IteratorA;
316 
317     // Define iterators over tiles from the B operand
318     using IteratorB = typename Mma::IteratorB;
319 
320     // Define the threadblock-scoped pipelined matrix multiply
321     using ThreadblockMma = typename Mma::ThreadblockMma;
322 };
323 
324 // fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
325 // large tile when not enough shared mem is present to do 3+ stage
326 template<
327     /// Layout type for A matrix operand
328     typename LayoutA,
329     /// Access granularity of A matrix in units of elements
330     int kAlignmentA,
331     /// Layout type for B matrix operand
332     typename LayoutB,
333     /// Access granularity of B matrix in units of elements
334     int kAlignmentB,
335     /// Element type for internal accumulation
336     typename ElementAccumulator,
337     /// Threadblock-level tile size (concept: GemmShape)
338     typename ThreadblockShape,
339     /// Warp-level tile size (concept: GemmShape)
340     typename WarpShape,
341     /// Instruction-level tile size (concept: GemmShape)
342     typename InstructionShape,
343     /// Operation performed by GEMM
344     typename Operator,
345     /// Use zfill or predicate for out-of-bound cp.async
346     SharedMemoryClearOption SharedMemoryClear,
347     /// Gather operand A by using an index array
348     bool GatherA,
349     /// Gather operand B by using an index array
350     bool GatherB>
351 struct DefaultMma<half_t,
352                   LayoutA,
353                   kAlignmentA,
354                   half_t,
355                   LayoutB,
356                   kAlignmentB,
357                   ElementAccumulator,
358                   layout::RowMajor,
359                   arch::OpClassTensorOp,
360                   arch::Sm80,
361                   ThreadblockShape,
362                   WarpShape,
363                   InstructionShape,
364                   2,
365                   Operator,
366                   false,
367                   SharedMemoryClear,
368                   GatherA,
369                   GatherB> {
370 
371     // Define the MmaCore components
372     // 3 is used on purpose here to trigger components for mma multistage
373     using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
374                                                                         WarpShape,
375                                                                         InstructionShape,
376                                                                         half_t,
377                                                                         LayoutA,
378                                                                         half_t,
379                                                                         LayoutB,
380                                                                         ElementAccumulator,
381                                                                         layout::RowMajor,
382                                                                         arch::OpClassTensorOp,
383                                                                         3,
384                                                                         Operator>;
385 
386     // Define iterators over tiles from the A operand
387     using ThreadMapA  = typename MmaCore::IteratorThreadMapA;
388     using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
389     using IteratorA   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
390         cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
391         half_t,
392         LayoutA,
393         1,
394         ThreadMapA,
395         AccessTypeA,
396         GatherA>;
397 
398     // Define iterators over tiles from the B operand
399     using ThreadMapB  = typename MmaCore::IteratorThreadMapB;
400     using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
401     using IteratorB   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
402         cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
403         half_t,
404         LayoutB,
405         0,
406         ThreadMapB,
407         AccessTypeB,
408         GatherB>;
409 
410     // Define the threadblock-scoped multistage matrix multiply
411     using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
412                                                                      IteratorA,
413                                                                      typename MmaCore::SmemIteratorA,
414                                                                      MmaCore::kCacheOpA,
415                                                                      IteratorB,
416                                                                      typename MmaCore::SmemIteratorB,
417                                                                      MmaCore::kCacheOpB,
418                                                                      ElementAccumulator,
419                                                                      layout::RowMajor,
420                                                                      typename MmaCore::MmaPolicy,
421                                                                      2>;
422 };
423 
424 }  // namespace threadblock
425 }  // namespace gemm
426 }  // namespace cutlass
427