1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <cutlass/functional.h> 11 #include <cutlass/gemm/warp/mma_simt_tile_iterator.h> 12 #include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h> 13 #include <cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h> 14 #include <cutlass/matrix_shape.h> 15 16 /* 17 TensorCores have different accumulator layouts. 18 This file provides a class to easily map the accumulator 19 i-th element with the corresponding matrix row/col. 20 */ 21 22 template <typename T, typename accum_t, int kWarpSize> 23 struct AccumLambdaIteratorSm80 { 24 static_assert( 25 cutlass::platform:: 26 is_same<typename T::Layout, cutlass::layout::RowMajor>::value, 27 "only RowMajor is supported"); 28 29 using Policy = typename T::Policy; 30 using InstructionShape = typename T::InstructionShape; 31 using OpDelta = typename T::OpDelta; 32 using Shape = typename T::Shape; 33 static int const kElementsPerAccess = InstructionShape::kN / 4; 34 static int const kRowsPerTile = 8; 35 static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; 36 get_lane_offsetAccumLambdaIteratorSm8037 static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( 38 int8_t lane_id, 39 int8_t warp_id, 40 typename T::TensorCoord const& tile_offset) { 41 int quad = (lane_id >> 2); 42 int lane_in_quad = (lane_id & 3); 43 return cutlass::MatrixCoord( 44 quad + tile_offset.row() * Shape::kRow, 45 lane_in_quad * kElementsPerAccess + 46 tile_offset.column() * Shape::kColumn); 47 } 48 49 template <typename FA, typename FB, typename FC> iterateRowsAccumLambdaIteratorSm8050 CUTLASS_DEVICE static void iterateRows( 51 cutlass::MatrixCoord& lane_offset, 52 FA beginRow, 53 FB op, 54 FC endRow) { 55 // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h 56 CUTLASS_PRAGMA_UNROLL 57 for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { 58 CUTLASS_PRAGMA_UNROLL 59 for (int row = 0; row < kAccumulatorRows; ++row) { 60 int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + 61 row * kRowsPerTile + lane_offset.row(); 62 beginRow(accum_m); 63 64 CUTLASS_PRAGMA_UNROLL 65 for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { 66 int mma_accum_start = kAccumulatorRows * kElementsPerAccess * 67 (mma_n * Policy::MmaIterations::kRow + mma_m); 68 CUTLASS_PRAGMA_UNROLL 69 for (int col = 0; col < kElementsPerAccess; ++col) { 70 int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + 71 col + lane_offset.column(); 72 int idx = mma_accum_start + row * kElementsPerAccess + col; 73 op(accum_m, accum_n, idx); 74 } 75 } 76 77 endRow(accum_m); 78 } 79 } 80 } 81 82 template <typename DT, typename F> reduceSameRowAccumLambdaIteratorSm8083 CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { 84 // In each warp, 4 threads will work on the same row 85 // - the ones with the same `quad` 86 auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); 87 myValue = fn(myValue, otherV); 88 otherV = __shfl_xor_sync(0xffffffff, myValue, 2); 89 myValue = fn(myValue, otherV); 90 int lane_in_quad = (lane_id & 3); 91 return lane_in_quad == 0; 92 } 93 }; 94 95 template <typename T, typename accum_t, int kWarpSize> 96 struct AccumLambdaIteratorSm70 { 97 static_assert( 98 cutlass::platform:: 99 is_same<typename T::Layout, cutlass::layout::RowMajor>::value, 100 "only RowMajor is supported"); 101 102 using Policy = typename T::Policy; 103 using InstructionShape = typename T::InstructionShape; 104 using OpDelta = typename T::OpDelta; 105 using Shape = typename T::Shape; 106 using Element = accum_t; 107 108 static int const kElementsPerPartial = 4; 109 using EleShapePerPatial = typename cutlass::platform::conditional< 110 cutlass::platform::is_same<Element, float>::value, 111 cutlass::MatrixShape<2, 2>, 112 cutlass::MatrixShape<1, 4>>::type; 113 static int const kElementsPerMma = 8; 114 static int const kAccumulatorPatials = 2; 115 using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; 116 get_lane_offsetAccumLambdaIteratorSm70117 static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( 118 int8_t lane_id, 119 int8_t warp_id, 120 typename T::TensorCoord const& tile_offset) { 121 int quad = (lane_id >> 2); 122 int lane_in_quad = (lane_id & 3); 123 int accum_m, accum_n; 124 125 if (cutlass::platform::is_same<Element, float>::value) { 126 // (quad[2],quad[0])+lane_in_quad[0] 127 accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); 128 // (quad[1])+lane_in_quad[1] 129 accum_n = 130 ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + 131 (lane_in_quad & 2); 132 } else { 133 accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + 134 lane_in_quad; // (quad[2],quad[0]) 135 accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; 136 } 137 return cutlass::MatrixCoord( 138 accum_m + tile_offset.row() * Shape::kRow, 139 accum_n + tile_offset.column() * Shape::kColumn); 140 } 141 142 template <typename DT, typename F> reduceSameRowAccumLambdaIteratorSm70143 CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { 144 static_assert( 145 cutlass::platform::is_same<Element, float>::value, 146 "update to support non-float accum"); 147 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 148 // T0 & T2 share same line within a quad 149 auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); 150 myValue = fn(myValue, otherV); 151 // quad 0 and quad 2 are on the same lines 152 otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); 153 myValue = fn(myValue, otherV); 154 return (lane_id & ((1 << 1) | (1 << 3))) == 0; 155 } 156 157 template <typename FA, typename FB, typename FC> iterateRowsAccumLambdaIteratorSm70158 CUTLASS_DEVICE static void iterateRows( 159 cutlass::MatrixCoord& lane_offset, 160 FA beginRow, 161 FB op, 162 FC endRow) { 163 CUTLASS_PRAGMA_UNROLL 164 for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { 165 CUTLASS_PRAGMA_UNROLL 166 for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { 167 CUTLASS_PRAGMA_UNROLL 168 for (int m = 0; m < EleShapePerPatial::kRow; ++m) { 169 int accum_m = tile_m * Policy::InterleavedTile::kRow + 170 mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); 171 beginRow(accum_m); 172 173 CUTLASS_PRAGMA_UNROLL 174 for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; 175 ++tile_n) { 176 CUTLASS_PRAGMA_UNROLL 177 for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; 178 ++mma_n) { 179 CUTLASS_PRAGMA_UNROLL 180 for (int p = 0; p < kAccumulatorPatials; ++p) { 181 CUTLASS_PRAGMA_UNROLL 182 for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { 183 int mma_accum_start = 184 (((tile_n * Policy::TileIterations::kRow + tile_m) * 185 Policy::MmaIterations::kColumn + 186 mma_n) * 187 Policy::MmaIterations::kRow + 188 mma_m) * 189 kElementsPerMma; 190 int accum_n = tile_n * Policy::InterleavedTile::kColumn + 191 mma_n * QuadShapePerPatialMma::kColumn + 192 p * Policy::InterleavedTile::kColumn / 2 + n + 193 lane_offset.column(); 194 int idx = mma_accum_start + p * kElementsPerPartial + 195 m * EleShapePerPatial::kColumn + n; 196 op(accum_m, accum_n, idx); 197 } 198 } 199 } 200 } 201 endRow(accum_m); 202 } 203 } 204 } 205 } 206 }; 207 208 template <typename T, typename accum_t, int kWarpSize> 209 struct AccumLambdaIteratorSimt { 210 using Policy = typename T::Policy; 211 using Iterations = typename T::Iterations; 212 using Element = typename T::Element; 213 using Delta = typename T::Delta; 214 using Shape = typename T::Shape; 215 static_assert( 216 cutlass::platform:: 217 is_same<typename T::Layout, cutlass::layout::RowMajor>::value, 218 "only RowMajor is supported"); 219 220 template <typename DT, typename F> reduceSameRowAccumLambdaIteratorSimt221 CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { 222 CUTLASS_PRAGMA_UNROLL 223 for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { 224 auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); 225 myValue = fn(myValue, otherV); 226 } 227 return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; 228 } 229 230 template <typename FA, typename FB, typename FC> iterateRowsAccumLambdaIteratorSimt231 CUTLASS_DEVICE static void iterateRows( 232 cutlass::MatrixCoord& lane_offset, 233 FA beginRow, 234 FB op, 235 FC endRow) { 236 CUTLASS_PRAGMA_UNROLL 237 for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { 238 CUTLASS_PRAGMA_UNROLL 239 for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { 240 int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); 241 beginRow(accum_m); 242 243 CUTLASS_PRAGMA_UNROLL 244 for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { 245 int accum_n = 246 mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + 247 lane_offset.column(); 248 CUTLASS_PRAGMA_UNROLL 249 for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { 250 int idx = n + 251 Policy::LaneMmaShape::kN * 252 (mma_n + 253 Iterations::kColumn * 254 (m + mma_m * Policy::LaneMmaShape::kM)); 255 op(accum_m, accum_n + n, idx); 256 } 257 } 258 endRow(accum_m); 259 } 260 } 261 } 262 get_lane_offsetAccumLambdaIteratorSimt263 static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( 264 int8_t lane_id, 265 int8_t warp_id, 266 typename T::TensorCoord const& tile_offset) { 267 static_assert( 268 cutlass::platform::is_same< 269 typename Policy::LaneLayout, 270 cutlass::layout::RowMajorInterleaved<1>>::value, 271 ""); 272 typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); 273 274 cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 275 cutlass::MatrixCoord(Policy::LaneMmaShape::kM, 276 Policy::LaneMmaShape::kN); 277 return lane_offset + 278 tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); 279 } 280 }; 281 282 template <typename T, typename accum_t, int kWarpSize> 283 struct DefaultMmaAccumLambdaIterator; 284 285 // Simt 286 template <typename S, typename P, typename accum_t, int kWarpSize> 287 struct DefaultMmaAccumLambdaIterator< 288 cutlass::gemm::warp::MmaSimtTileIterator< 289 S, 290 cutlass::gemm::Operand::kC, 291 accum_t, 292 cutlass::layout::RowMajor, 293 P, 294 1, 295 1>, 296 accum_t, 297 kWarpSize> { 298 using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< 299 S, 300 cutlass::gemm::Operand::kC, 301 accum_t, 302 cutlass::layout::RowMajor, 303 P, 304 1, 305 1>; 306 using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>; 307 }; 308 309 // TensorOp - Volta 310 template <typename S1, typename S2, typename accum_t, int kWarpSize> 311 struct DefaultMmaAccumLambdaIterator< 312 cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< 313 S1, 314 accum_t, 315 cutlass::layout::RowMajor, 316 S2, 317 cutlass::MatrixShape<1, 1>>, 318 accum_t, 319 kWarpSize> { 320 using WarpIterator = 321 typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< 322 S1, 323 accum_t, 324 cutlass::layout::RowMajor, 325 S2, 326 cutlass::MatrixShape<1, 1>>; 327 using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>; 328 }; 329 330 // TensorOp - Sm75+ 331 template < 332 typename S1, 333 typename S2, 334 typename S3, 335 typename accum_t, 336 int kWarpSize> 337 struct DefaultMmaAccumLambdaIterator< 338 cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< 339 S1, 340 accum_t, 341 cutlass::layout::RowMajor, 342 S2, 343 S3>, 344 accum_t, 345 kWarpSize> { 346 using WarpIterator = 347 typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< 348 S1, 349 accum_t, 350 cutlass::layout::RowMajor, 351 S2, 352 S3>; 353 using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>; 354 }; 355