1 #pragma once
2 
3 #include <cutlass/gemm/threadblock/default_mma.h>
4 #include <ATen/native/cuda/cutlass_extensions/arch/mma.h>
5 
6 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h>
7 #include <ATen/native/cuda/cutlass_extensions/gemm/warp/default_mma_tensor_op.h>
8 #include <ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h>
9 #include <ATen/native/cuda/cutlass_extensions/tile_interleaved_layout.h>
10 
11 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma.h>
12 
13 namespace cutlass {
14 namespace gemm {
15 namespace threadblock {
16 
17 ////////////////////////////////////////////////////////////////////////////////
18 
19 template<
20     /// Type for element A
21     typename ElementA,
22     /// Layout type for A matrix operand
23     typename LayoutA,
24     /// Access granularity of A matrix in units of elements
25     int kAlignmentA,
26     /// Type for element B
27     typename ElementB,
28     /// Layout type for B matrix operand
29     typename LayoutB,
30     /// Access granularity of B matrix in units of elements
31     int kAlignmentB,
32     /// Element type for the input scale
33     typename ElementScale,
34     /// Layout for the scale operand
35     typename LayoutScale,
36     /// Access granularity of Scales in unit of elements
37     int kAlignmentScale,
38     /// Element type for internal accumulation
39     typename ElementAccumulator,
40     /// Operator class tag
41     typename OperatorClass,
42     /// Tag indicating architecture to tune for
43     typename ArchTag,
44     /// Threadblock-level tile size (concept: GemmShape)
45     typename ThreadblockShape,
46     /// Warp-level tile size (concept: GemmShape)
47     typename WarpShape,
48     /// Instruction-level tile size (concept: GemmShape)
49     typename InstructionShape,
50     /// Operation performed by GEMM
51     typename Operator>
52 struct DqMma<ElementA,
53              LayoutA,
54              kAlignmentA,
55              ElementB,
56              LayoutB,
57              kAlignmentB,
58              ElementScale,
59              LayoutScale,
60              kAlignmentScale,
61              ElementAccumulator,
62              layout::RowMajor,
63              OperatorClass,
64              ArchTag,
65              ThreadblockShape,
66              WarpShape,
67              InstructionShape,
68              2,
69              Operator,
70              SharedMemoryClearOption::kNone,
71              typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
72 
73     static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
74                   "Element A must be fp16 or bf16");
75 
76     static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
77                   "Element B must be uint8 or uint4");
78 
79     static constexpr bool DqAfterLDG        = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
80     static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
81     using MmaCoreElementA                   = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
82     using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
83 
84     // Define the MmaCore components
85     using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
86                                                                         WarpShape,
87                                                                         InstructionShape,
88                                                                         MmaCoreElementA,
89                                                                         LayoutA,
90                                                                         MmaCoreElementB,
91                                                                         LayoutB,
92                                                                         ElementAccumulator,
93                                                                         layout::RowMajor,
94                                                                         OperatorClass,
95                                                                         2,
96                                                                         Operator>;
97 
98     // Define iterators over tiles from the A operand
99     using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
100         cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
101         ElementA,
102         LayoutA,
103         1,
104         typename MmaCore::IteratorThreadMapA,
105         kAlignmentA>;
106 
107     // Define iterators over tiles from the B operand
108     using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
109         cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
110         ElementB,
111         LayoutB,
112         0,
113         typename MmaCore::IteratorThreadMapB,
114         kAlignmentB>;
115 
116     // ThreadMap for scale iterator
117     static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
118     using IteratorScaleThreadMap =
119         transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
120                                                   MmaCore::Shape::kN / kAlignmentScale,
121                                                   kAlignmentScale>;
122 
123     // Define iterators over tiles from the scale operand
124     using IteratorScale =
125         cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
126                                                                 ElementScale,
127                                                                 LayoutScale,
128                                                                 0,
129                                                                 IteratorScaleThreadMap,
130                                                                 kAlignmentScale>;
131 
132     using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
133     using SmemIteratorScale =
134         cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
135                                                                 SmemScaleType,
136                                                                 LayoutScale,
137                                                                 0,
138                                                                 IteratorScaleThreadMap,
139                                                                 kAlignmentScale>;
140 
141     using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
142 
143     // Define the threadblock-scoped pipelined matrix multiply
144     using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
145                                                                       IteratorA,
146                                                                       typename MmaCore::SmemIteratorA,
147                                                                       IteratorB,
148                                                                       typename MmaCore::SmemIteratorB,
149                                                                       IteratorScale,
150                                                                       SmemIteratorScale,
151                                                                       ElementAccumulator,
152                                                                       layout::RowMajor,
153                                                                       typename MmaCore::MmaPolicy,
154                                                                       typename Converters::TransformAfterLDG,
155                                                                       typename Converters::TransformAfterLDS>;
156 };
157 
158 // Specialization to handle column major interleave B
159 template<
160     /// Type for element A
161     typename ElementA,
162     /// Layout type for A matrix operand
163     typename LayoutA,
164     /// Access granularity of A matrix in units of elements
165     int kAlignmentA,
166     /// Type for element B
167     typename ElementB,
168     /// Access granularity of B matrix in units of elements
169     int kAlignmentB,
170     /// Element type for the input scale
171     typename ElementScale,
172     /// Layout for the scale operand
173     typename LayoutScale,
174     /// Access granularity of Scales in unit of elements
175     int kAlignmentScale,
176     /// Element type for internal accumulation
177     typename ElementAccumulator,
178     /// Operator class tag
179     typename OperatorClass,
180     /// Tag indicating architecture to tune for
181     typename ArchTag,
182     /// Threadblock-level tile size (concept: GemmShape)
183     typename ThreadblockShape,
184     /// Warp-level tile size (concept: GemmShape)
185     typename WarpShape,
186     /// Instruction-level tile size (concept: GemmShape)
187     typename InstructionShape,
188     /// Operation performed by GEMM
189     typename Operator,
190     ///
191     int RowsPerTile,
192     ///
193     int ColumnsInterleaved>
194 struct DqMma<ElementA,
195              LayoutA,
196              kAlignmentA,
197              ElementB,
198              layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
199              kAlignmentB,
200              ElementScale,
201              LayoutScale,
202              kAlignmentScale,
203              ElementAccumulator,
204              layout::RowMajor,
205              OperatorClass,
206              ArchTag,
207              ThreadblockShape,
208              WarpShape,
209              InstructionShape,
210              2,
211              Operator,
212              SharedMemoryClearOption::kNone,
213              typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
214 
215     static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
216                   "Element A must be fp16 or bf16");
217 
218     static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
219                   "Element B must be uint8 or uint4");
220 
221     static constexpr bool DqAfterLDG        = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
222     static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
223     using MmaCoreElementA                   = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
224     using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
225 
226     // Define the MmaCore components
227     using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
228                                                                         WarpShape,
229                                                                         InstructionShape,
230                                                                         MmaCoreElementA,
231                                                                         LayoutA,
232                                                                         MmaCoreElementB,
233                                                                         layout::ColumnMajor,
234                                                                         ElementAccumulator,
235                                                                         layout::RowMajor,
236                                                                         OperatorClass,
237                                                                         2,
238                                                                         Operator>;
239 
240     // Define iterators over tiles from the A operand
241     using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
242         cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
243         ElementA,
244         LayoutA,
245         1,
246         typename MmaCore::IteratorThreadMapA,
247         kAlignmentA>;
248 
249 private:
250     static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
251     static_assert(RowsPerTile == MmaCore::Shape::kK, "");
252 
253     using OriginalThreadMap       = typename MmaCore::IteratorThreadMapB;
254     using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
255     static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
256 
257     using GmemIteratorShape =
258         MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
259     using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
260         layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
261         OriginalThreadMap::kThreads,
262         layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
263                                  OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
264         MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
265 
266 public:
267     // Define iterators over tiles from the B operand
268     using IteratorB = cutlass::transform::threadblock::
269         PredicatedTileIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
270 
271     // ThreadMap for scale iterator
272     static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
273     using IteratorScaleThreadMap =
274         transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
275                                                   MmaCore::Shape::kN / kAlignmentScale,
276                                                   kAlignmentScale>;
277 
278     // Define iterators over tiles from the scale operand
279     using IteratorScale =
280         cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
281                                                                 ElementScale,
282                                                                 LayoutScale,
283                                                                 0,
284                                                                 IteratorScaleThreadMap,
285                                                                 kAlignmentScale>;
286 
287     using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
288     using SmemIteratorScale =
289         cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
290                                                                 SmemScaleType,
291                                                                 LayoutScale,
292                                                                 0,
293                                                                 IteratorScaleThreadMap,
294                                                                 kAlignmentScale>;
295 
296     using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
297 
298     // Define the threadblock-scoped pipelined matrix multiply
299     using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
300                                                                       IteratorA,
301                                                                       typename MmaCore::SmemIteratorA,
302                                                                       IteratorB,
303                                                                       typename MmaCore::SmemIteratorB,
304                                                                       IteratorScale,
305                                                                       SmemIteratorScale,
306                                                                       ElementAccumulator,
307                                                                       layout::RowMajor,
308                                                                       typename MmaCore::MmaPolicy,
309                                                                       typename Converters::TransformAfterLDG,
310                                                                       typename Converters::TransformAfterLDS>;
311 };
312 
313 }  // namespace threadblock
314 }  // namespace gemm
315 }  // namespace cutlass
316