/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <cutlass/functional.h>
#include <cutlass/gemm/warp/mma_simt_tile_iterator.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h>
#include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h>
#include <cutlass/matrix_shape.h>

/*
TensorCores have different accumulator layouts.
This file provides a class to easily map the accumulator
i-th element with the corresponding matrix row/col.
*/
// Maps each element of an SM80 (mma.sync m16n8kX) accumulator fragment to its
// logical (row, col) coordinate within the warp tile.
//
// Template parameters:
//   T         - a CUTLASS MmaTensorOpAccumulatorTileIterator type (RowMajor only)
//   accum_t   - accumulator element type (unused here; kept for interface parity
//               with the Sm70/Simt variants)
//   kWarpSize - warp size (32 on NVIDIA hardware)
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm80 {
  static_assert(
      cutlass::platform::
          is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
      "only RowMajor is supported");

  using Policy = typename T::Policy;
  using InstructionShape = typename T::InstructionShape;
  using OpDelta = typename T::OpDelta;
  using Shape = typename T::Shape;
  // Per the SM80 mma fragment layout: each lane owns kN/4 contiguous columns
  // per instruction, in row groups of 8.
  static int const kElementsPerAccess = InstructionShape::kN / 4;
  static int const kRowsPerTile = 8;
  static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;

  // Returns the (row, col) offset of this lane's first accumulator element,
  // including the warp-tile offset `tile_offset` scaled by the warp Shape.
  // Lanes are grouped in quads: quad index selects the row, the position
  // within the quad selects a group of kElementsPerAccess columns.
  static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
      int8_t lane_id,
      int8_t warp_id,
      typename T::TensorCoord const& tile_offset) {
    int quad = (lane_id >> 2);
    int lane_in_quad = (lane_id & 3);
    return cutlass::MatrixCoord(
        quad + tile_offset.row() * Shape::kRow,
        lane_in_quad * kElementsPerAccess +
            tile_offset.column() * Shape::kColumn);
  }

  // Iterates over all accumulator elements owned by this lane, row by row.
  // For every logical row it calls `beginRow(accum_m)`, then
  // `op(accum_m, accum_n, idx)` for each element (where `idx` is the index
  // into the accumulator fragment), then `endRow(accum_m)`.
  template <typename FA, typename FB, typename FC>
  CUTLASS_DEVICE static void iterateRows(
      cutlass::MatrixCoord& lane_offset,
      FA beginRow,
      FB op,
      FC endRow) {
    // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
    CUTLASS_PRAGMA_UNROLL
    for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int row = 0; row < kAccumulatorRows; ++row) {
        int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
            row * kRowsPerTile + lane_offset.row();
        beginRow(accum_m);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
          // Fragment is laid out column-major across mma tiles:
          // tile (mma_m, mma_n) starts at this flat offset.
          int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
              (mma_n * Policy::MmaIterations::kRow + mma_m);
          CUTLASS_PRAGMA_UNROLL
          for (int col = 0; col < kElementsPerAccess; ++col) {
            int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
                col + lane_offset.column();
            int idx = mma_accum_start + row * kElementsPerAccess + col;
            op(accum_m, accum_n, idx);
          }
        }

        endRow(accum_m);
      }
    }
  }

  // Reduces `myValue` with `fn` across the 4 lanes of a quad that share the
  // same accumulator row. Returns true on the lane holding the final result
  // (lane_in_quad == 0). Requires all 32 lanes to participate (full mask).
  template <typename DT, typename F>
  CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
    // In each warp, 4 threads will work on the same row
    // - the ones with the same `quad`
    auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
    myValue = fn(myValue, otherV);
    otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
    myValue = fn(myValue, otherV);
    int lane_in_quad = (lane_id & 3);
    return lane_in_quad == 0;
  }
};
94 
// Maps each element of an SM70 (Volta mma.m8n8k4) accumulator fragment to its
// logical (row, col) coordinate within the warp tile.
// See the PTX docs for the 884 fragment layouts:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm70 {
  static_assert(
      cutlass::platform::
          is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
      "only RowMajor is supported");

  using Policy = typename T::Policy;
  using InstructionShape = typename T::InstructionShape;
  using OpDelta = typename T::OpDelta;
  using Shape = typename T::Shape;
  using Element = accum_t;

  static int const kElementsPerPartial = 4;
  // Per-lane element footprint of one partial: 2x2 for f32 accum, 1x4 for f16.
  using EleShapePerPatial = typename cutlass::platform::conditional<
      cutlass::platform::is_same<Element, float>::value,
      cutlass::MatrixShape<2, 2>,
      cutlass::MatrixShape<1, 4>>::type;
  static int const kElementsPerMma = 8;
  static int const kAccumulatorPatials = 2;
  using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;

  // Returns the (row, col) offset of this lane's first accumulator element.
  // The bit manipulations below decode the Volta quad interleaving, which
  // differs between f32 and f16 accumulators.
  static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
      int8_t lane_id,
      int8_t warp_id,
      typename T::TensorCoord const& tile_offset) {
    int quad = (lane_id >> 2);
    int lane_in_quad = (lane_id & 3);
    int accum_m, accum_n;

    if (cutlass::platform::is_same<Element, float>::value) {
      // (quad[2],quad[0])+lane_in_quad[0]
      accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
      // (quad[1])+lane_in_quad[1]
      accum_n =
          ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
          (lane_in_quad & 2);
    } else {
      accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
          lane_in_quad; // (quad[2],quad[0])
      accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
    }
    return cutlass::MatrixCoord(
        accum_m + tile_offset.row() * Shape::kRow,
        accum_n + tile_offset.column() * Shape::kColumn);
  }

  // Reduces `myValue` with `fn` across the lanes sharing the same accumulator
  // row (float accumulators only). Returns true on the lane holding the
  // result. Requires all 32 lanes to participate (full mask).
  template <typename DT, typename F>
  CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
    static_assert(
        cutlass::platform::is_same<Element, float>::value,
        "update to support non-float accum");
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
    // T0 & T2 share same line within a quad
    auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
    myValue = fn(myValue, otherV);
    // quad 0 and quad 2 are on the same lines
    otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
    myValue = fn(myValue, otherV);
    return (lane_id & ((1 << 1) | (1 << 3))) == 0;
  }

  // Iterates over all accumulator elements owned by this lane, row by row,
  // calling beginRow/op/endRow exactly as in AccumLambdaIteratorSm80.
  // The loop nest follows the Volta interleaved-tile fragment ordering.
  template <typename FA, typename FB, typename FC>
  CUTLASS_DEVICE static void iterateRows(
      cutlass::MatrixCoord& lane_offset,
      FA beginRow,
      FB op,
      FC endRow) {
    CUTLASS_PRAGMA_UNROLL
    for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
          int accum_m = tile_m * Policy::InterleavedTile::kRow +
              mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row();
          beginRow(accum_m);

          CUTLASS_PRAGMA_UNROLL
          for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn;
               ++tile_n) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
                 ++mma_n) {
              CUTLASS_PRAGMA_UNROLL
              for (int p = 0; p < kAccumulatorPatials; ++p) {
                CUTLASS_PRAGMA_UNROLL
                for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
                  // Flat fragment offset of mma tile
                  // (tile_m, tile_n, mma_m, mma_n).
                  int mma_accum_start =
                      (((tile_n * Policy::TileIterations::kRow + tile_m) *
                            Policy::MmaIterations::kColumn +
                        mma_n) *
                           Policy::MmaIterations::kRow +
                       mma_m) *
                      kElementsPerMma;
                  int accum_n = tile_n * Policy::InterleavedTile::kColumn +
                      mma_n * QuadShapePerPatialMma::kColumn +
                      p * Policy::InterleavedTile::kColumn / 2 + n +
                      lane_offset.column();
                  int idx = mma_accum_start + p * kElementsPerPartial +
                      m * EleShapePerPatial::kColumn + n;
                  op(accum_m, accum_n, idx);
                }
              }
            }
          }
          endRow(accum_m);
        }
      }
    }
  }
};
207 
// Maps each element of a SIMT (no tensor cores) accumulator fragment to its
// logical (row, col) coordinate within the warp tile.
// T must be a cutlass::gemm::warp::MmaSimtTileIterator over RowMajor C.
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSimt {
  using Policy = typename T::Policy;
  using Iterations = typename T::Iterations;
  using Element = typename T::Element;
  using Delta = typename T::Delta;
  using Shape = typename T::Shape;
  static_assert(
      cutlass::platform::
          is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
      "only RowMajor is supported");

  // Reduces `myValue` with `fn` across the WarpShape::kColumn lanes that share
  // the same accumulator row (butterfly shuffle; kColumn is assumed to be a
  // power of two by the mask math below). Returns true on the lane holding
  // the result. Requires all 32 lanes to participate (full mask).
  template <typename DT, typename F>
  CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
    CUTLASS_PRAGMA_UNROLL
    for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
      auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
      myValue = fn(myValue, otherV);
    }
    return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
  }

  // Iterates over all accumulator elements owned by this lane, row by row,
  // calling beginRow/op/endRow exactly as in AccumLambdaIteratorSm80.
  template <typename FA, typename FB, typename FC>
  CUTLASS_DEVICE static void iterateRows(
      cutlass::MatrixCoord& lane_offset,
      FA beginRow,
      FB op,
      FC endRow) {
    CUTLASS_PRAGMA_UNROLL
    for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
        int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
        beginRow(accum_m);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
          int accum_n =
              mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
              lane_offset.column();
          CUTLASS_PRAGMA_UNROLL
          for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
            // Fragment is ordered (mma_m, m, mma_n, n), innermost last.
            int idx = n +
                Policy::LaneMmaShape::kN *
                    (mma_n +
                     Iterations::kColumn *
                         (m + mma_m * Policy::LaneMmaShape::kM));
            op(accum_m, accum_n + n, idx);
          }
        }
        endRow(accum_m);
      }
    }
  }

  // Returns the (row, col) offset of this lane's first accumulator element,
  // derived from the policy's lane layout (must be RowMajorInterleaved<1>).
  static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(
      int8_t lane_id,
      int8_t warp_id,
      typename T::TensorCoord const& tile_offset) {
    static_assert(
        cutlass::platform::is_same<
            typename Policy::LaneLayout,
            cutlass::layout::RowMajorInterleaved<1>>::value,
        "");
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
        cutlass::MatrixCoord(Policy::LaneMmaShape::kM,
                             Policy::LaneMmaShape::kN);
    return lane_offset +
        tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
  }
};
281 
// Selects the accumulator lambda iterator matching a given warp-level C-tile
// iterator type. Specialized below for SIMT, Volta (SM70) and SM75+ tensor
// op iterators; unsupported iterator types fail to compile.
template <typename T, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator;

// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
    cutlass::gemm::warp::MmaSimtTileIterator<
        S,
        cutlass::gemm::Operand::kC,
        accum_t,
        cutlass::layout::RowMajor,
        P,
        1,
        1>,
    accum_t,
    kWarpSize> {
  using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
      S,
      cutlass::gemm::Operand::kC,
      accum_t,
      cutlass::layout::RowMajor,
      P,
      1,
      1>;
  using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};
308 
// TensorOp - Volta (SM70): maps the Volta accumulator tile iterator to
// AccumLambdaIteratorSm70.
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
    cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
        S1,
        accum_t,
        cutlass::layout::RowMajor,
        S2,
        cutlass::MatrixShape<1, 1>>,
    accum_t,
    kWarpSize> {
  using WarpIterator =
      typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
          S1,
          accum_t,
          cutlass::layout::RowMajor,
          S2,
          cutlass::MatrixShape<1, 1>>;
  using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};
329 
// TensorOp - Sm75+: maps the standard tensor-op accumulator tile iterator to
// AccumLambdaIteratorSm80 (layout is shared by SM75/SM80+ mma.sync).
template <
    typename S1,
    typename S2,
    typename S3,
    typename accum_t,
    int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
    cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
        S1,
        accum_t,
        cutlass::layout::RowMajor,
        S2,
        S3>,
    accum_t,
    kWarpSize> {
  using WarpIterator =
      typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
          S1,
          accum_t,
          cutlass::layout::RowMajor,
          S2,
          S3>;
  using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};