1 #pragma once 2 3 #include <cutlass/gemm/threadblock/default_mma.h> 4 #include <ATen/native/cuda/cutlass_extensions/arch/mma.h> 5 6 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h> 7 #include <ATen/native/cuda/cutlass_extensions/gemm/warp/default_mma_tensor_op.h> 8 #include <ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h> 9 #include <ATen/native/cuda/cutlass_extensions/tile_interleaved_layout.h> 10 11 #include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma.h> 12 13 namespace cutlass { 14 namespace gemm { 15 namespace threadblock { 16 17 //////////////////////////////////////////////////////////////////////////////// 18 19 template< 20 /// Type for element A 21 typename ElementA, 22 /// Layout type for A matrix operand 23 typename LayoutA, 24 /// Access granularity of A matrix in units of elements 25 int kAlignmentA, 26 /// Type for element B 27 typename ElementB, 28 /// Layout type for B matrix operand 29 typename LayoutB, 30 /// Access granularity of B matrix in units of elements 31 int kAlignmentB, 32 /// Element type for the input scale 33 typename ElementScale, 34 /// Layout for the scale operand 35 typename LayoutScale, 36 /// Access granularity of Scales in unit of elements 37 int kAlignmentScale, 38 /// Element type for internal accumulation 39 typename ElementAccumulator, 40 /// Operator class tag 41 typename OperatorClass, 42 /// Tag indicating architecture to tune for 43 typename ArchTag, 44 /// Threadblock-level tile size (concept: GemmShape) 45 typename ThreadblockShape, 46 /// Warp-level tile size (concept: GemmShape) 47 typename WarpShape, 48 /// Instruction-level tile size (concept: GemmShape) 49 typename InstructionShape, 50 /// Operation performed by GEMM 51 typename Operator> 52 struct DqMma<ElementA, 53 LayoutA, 54 kAlignmentA, 55 ElementB, 56 LayoutB, 57 kAlignmentB, 58 ElementScale, 59 LayoutScale, 60 kAlignmentScale, 61 ElementAccumulator, 62 layout::RowMajor, 63 OperatorClass, 64 ArchTag, 65 ThreadblockShape, 66 WarpShape, 67 InstructionShape, 68 2, 69 Operator, 70 SharedMemoryClearOption::kNone, 71 typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { 72 73 static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value, 74 "Element A must be fp16 or bf16"); 75 76 static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value, 77 "Element B must be uint8 or uint4"); 78 79 static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value; 80 static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; 81 using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type; 82 using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type; 83 84 // Define the MmaCore components 85 using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, 86 WarpShape, 87 InstructionShape, 88 MmaCoreElementA, 89 LayoutA, 90 MmaCoreElementB, 91 LayoutB, 92 ElementAccumulator, 93 layout::RowMajor, 94 OperatorClass, 95 2, 96 Operator>; 97 98 // Define iterators over tiles from the A operand 99 using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< 100 cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, 101 ElementA, 102 LayoutA, 103 1, 104 typename MmaCore::IteratorThreadMapA, 105 kAlignmentA>; 106 107 // Define iterators over tiles from the B operand 108 using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< 109 cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, 110 ElementB, 111 LayoutB, 112 0, 113 typename MmaCore::IteratorThreadMapB, 114 kAlignmentB>; 115 116 // ThreadMap for scale iterator 117 static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); 118 using IteratorScaleThreadMap = 119 transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>, 120 MmaCore::Shape::kN / kAlignmentScale, 121 kAlignmentScale>; 122 123 // Define iterators over tiles from the scale operand 124 using IteratorScale = 125 cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>, 126 ElementScale, 127 LayoutScale, 128 0, 129 IteratorScaleThreadMap, 130 kAlignmentScale>; 131 132 using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type; 133 using SmemIteratorScale = 134 cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>, 135 SmemScaleType, 136 LayoutScale, 137 0, 138 IteratorScaleThreadMap, 139 kAlignmentScale>; 140 141 using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>; 142 143 // Define the threadblock-scoped pipelined matrix multiply 144 using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, 145 IteratorA, 146 typename MmaCore::SmemIteratorA, 147 IteratorB, 148 typename MmaCore::SmemIteratorB, 149 IteratorScale, 150 SmemIteratorScale, 151 ElementAccumulator, 152 layout::RowMajor, 153 typename MmaCore::MmaPolicy, 154 typename Converters::TransformAfterLDG, 155 typename Converters::TransformAfterLDS>; 156 }; 157 158 // Specialization to handle column major interleave B 159 template< 160 /// Type for element A 161 typename ElementA, 162 /// Layout type for A matrix operand 163 typename LayoutA, 164 /// Access granularity of A matrix in units of elements 165 int kAlignmentA, 166 /// Type for element B 167 typename ElementB, 168 /// Access granularity of B matrix in units of elements 169 int kAlignmentB, 170 /// Element type for the input scale 171 typename ElementScale, 172 /// Layout for the scale operand 173 typename LayoutScale, 174 /// Access granularity of Scales in unit of elements 175 int kAlignmentScale, 176 /// Element type for internal accumulation 177 typename ElementAccumulator, 178 /// Operator class tag 179 typename OperatorClass, 180 /// Tag indicating architecture to tune for 181 typename ArchTag, 182 /// Threadblock-level tile size (concept: GemmShape) 183 typename ThreadblockShape, 184 /// Warp-level tile size (concept: GemmShape) 185 typename WarpShape, 186 /// Instruction-level tile size (concept: GemmShape) 187 typename InstructionShape, 188 /// Operation performed by GEMM 189 typename Operator, 190 /// 191 int RowsPerTile, 192 /// 193 int ColumnsInterleaved> 194 struct DqMma<ElementA, 195 LayoutA, 196 kAlignmentA, 197 ElementB, 198 layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>, 199 kAlignmentB, 200 ElementScale, 201 LayoutScale, 202 kAlignmentScale, 203 ElementAccumulator, 204 layout::RowMajor, 205 OperatorClass, 206 ArchTag, 207 ThreadblockShape, 208 WarpShape, 209 InstructionShape, 210 2, 211 Operator, 212 SharedMemoryClearOption::kNone, 213 typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { 214 215 static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value, 216 "Element A must be fp16 or bf16"); 217 218 static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value, 219 "Element B must be uint8 or uint4"); 220 221 static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value; 222 static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; 223 using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type; 224 using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type; 225 226 // Define the MmaCore components 227 using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, 228 WarpShape, 229 InstructionShape, 230 MmaCoreElementA, 231 LayoutA, 232 MmaCoreElementB, 233 layout::ColumnMajor, 234 ElementAccumulator, 235 layout::RowMajor, 236 OperatorClass, 237 2, 238 Operator>; 239 240 // Define iterators over tiles from the A operand 241 using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< 242 cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, 243 ElementA, 244 LayoutA, 245 1, 246 typename MmaCore::IteratorThreadMapA, 247 kAlignmentA>; 248 249 private: 250 static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); 251 static_assert(RowsPerTile == MmaCore::Shape::kK, ""); 252 253 using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; 254 using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; 255 static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); 256 257 using GmemIteratorShape = 258 MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>; 259 using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< 260 layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, 261 OriginalThreadMap::kThreads, 262 layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved, 263 OriginalWarpArrangement::kStrided / ColumnsInterleaved>, 264 MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>; 265 266 public: 267 // Define iterators over tiles from the B operand 268 using IteratorB = cutlass::transform::threadblock:: 269 PredicatedTileIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>; 270 271 // ThreadMap for scale iterator 272 static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); 273 using IteratorScaleThreadMap = 274 transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>, 275 MmaCore::Shape::kN / kAlignmentScale, 276 kAlignmentScale>; 277 278 // Define iterators over tiles from the scale operand 279 using IteratorScale = 280 cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>, 281 ElementScale, 282 LayoutScale, 283 0, 284 IteratorScaleThreadMap, 285 kAlignmentScale>; 286 287 using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type; 288 using SmemIteratorScale = 289 cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>, 290 SmemScaleType, 291 LayoutScale, 292 0, 293 IteratorScaleThreadMap, 294 kAlignmentScale>; 295 296 using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>; 297 298 // Define the threadblock-scoped pipelined matrix multiply 299 using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, 300 IteratorA, 301 typename MmaCore::SmemIteratorA, 302 IteratorB, 303 typename MmaCore::SmemIteratorB, 304 IteratorScale, 305 SmemIteratorScale, 306 ElementAccumulator, 307 layout::RowMajor, 308 typename MmaCore::MmaPolicy, 309 typename Converters::TransformAfterLDG, 310 typename Converters::TransformAfterLDS>; 311 }; 312 313 } // namespace threadblock 314 } // namespace gemm 315 } // namespace cutlass 316