/*
  This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
  quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
  to be consumed by CUTLASS.

  Note that for int4, ThreadblockK MUST be 64.
*/

#pragma once

#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>

#include <cutlass/arch/arch.h>
#include <cutlass/arch/mma.h>
#include <cutlass/platform/platform.h>

#include <ATen/native/cuda/cutlass_extensions/arch/mma.h>
#include <ATen/native/cuda/cutlass_extensions/tile_interleaved_layout.h>

namespace cutlass {
namespace gemm {
namespace kernel {

template<typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB {
};

// Volta specializations. Volta will dequantize before STS, so we need a different operator.
template<typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70> {
    static constexpr int ThreadblockK      = 64;
    using Layout                           = layout::RowMajor;
    static constexpr int ElementsPerAccess = 8;
    using Operator                         = cutlass::arch::OpMultiplyAdd;
};

// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template<typename Arch>
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK      = 64;
    using Layout                           = layout::RowMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator                         = cutlass::arch::OpMultiplyAdd;
};

template<typename Arch>
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK      = 64;
    using Layout                           = layout::RowMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
    using Operator                         = cutlass::arch::OpMultiplyAdd;
};
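// Added note (not in the original header): for the FP16/BF16 specializations above,
// ElementsPerAccess = 128 / sizeof_bits<T>::value = 8, i.e. each access moves 128 bits (16 bytes)
// of B, which matches CUTLASS's usual vectorized memory access width.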
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template<typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK = 64;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved   = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout                           = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
    using Operator                         = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template<typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
    static constexpr int ThreadblockK = 64;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
    static constexpr int ColumnsInterleaved   = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout                           = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
    using Operator                         = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
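// ---------------------------------------------------------------------------------------------
// Illustrative sketch (added, not part of the original header). It shows the values the
// quantized-B specializations above resolve to, assuming cutlass::arch::Sm80 defines
// kMinComputeCapability = 80 as in upstream CUTLASS. ElementsPerCacheLine is derived from a
// 128-byte (128 * 8 bit) cache line:
//
//   uint8_t : ElementsPerCacheLine = 1024 / 8 = 128  ->  ColumnsInterleaved = 128 / 64 = 2
//   uint4b_t: ElementsPerCacheLine = 1024 / 4 = 256  ->  ColumnsInterleaved = 256 / 64 = 4
//
// A downstream consumer could check its expectations like this (kept commented out so the
// header itself is unchanged; the alias name DetailsW4 is hypothetical):
//
//   using DetailsW4 = cutlass::gemm::kernel::LayoutDetailsB<cutlass::uint4b_t, cutlass::arch::Sm80>;
//   static_assert(DetailsW4::ThreadblockK == 64, "int4 weights require ThreadblockK == 64");
//   static_assert(cutlass::platform::is_same<DetailsW4::Layout,
//                     cutlass::layout::ColumnMajorTileInterleave<64, 4>>::value,
//                 "int4 weights use a column-major layout with 4-way tile interleaving");
// ---------------------------------------------------------------------------------------------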