xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2   This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
3   quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
4   to be consumed by CUTLASS.
5 
6   Note that for int4, ThreadBlockK MUST be 64.
7 
8  */
9 
10 #pragma once
11 
12 #include <cutlass/layout/matrix.h>
13 #include <cutlass/numeric_types.h>
14 
15 #include <cutlass/arch/arch.h>
16 #include <cutlass/arch/mma.h>
17 #include <cutlass/platform/platform.h>
18 
19 #include <ATen/native/cuda/cutlass_extensions/arch/mma.h>
20 #include <ATen/native/cuda/cutlass_extensions/tile_interleaved_layout.h>
21 
22 namespace cutlass {
23 namespace gemm {
24 namespace kernel {
25 
// Primary template: intentionally left empty so that instantiating
// LayoutDetailsB with an unsupported (TypeB, Arch) pair fails at compile
// time. Concrete layouts come from the partial specializations below;
// `Enable` (defaulted to void) is the SFINAE hook those specializations
// use to constrain on Arch::kMinComputeCapability.
template<typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB {
};
29 
// Volta specializations. Volta will dequantize before STS, so we need a different operator
31 template<typename TypeB>
32 struct LayoutDetailsB<TypeB, arch::Sm70> {
33     static constexpr int ThreadblockK      = 64;
34     using Layout                           = layout::RowMajor;
35     static constexpr int ElementsPerAccess = 8;
36     using Operator                         = cutlass::arch::OpMultiplyAdd;
37 };
38 
39 // Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
40 // TODO - Switch this to column major for weights since gemms should be more performant.
41 template<typename Arch>
42 struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
43     static constexpr int ThreadblockK      = 64;
44     using Layout                           = layout::RowMajor;
45     static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
46     using Operator                         = cutlass::arch::OpMultiplyAdd;
47 };
48 
49 template<typename Arch>
50 struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
51     static constexpr int ThreadblockK      = 64;
52     using Layout                           = layout::RowMajor;
53     static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
54     using Operator                         = cutlass::arch::OpMultiplyAdd;
55 };
56 
57 // Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
58 // which signals that we want to dequantize after loading from smem.
59 template<typename Arch>
60 struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
61     static constexpr int ThreadblockK = 64;
62 
63 private:
64     static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
65     static constexpr int ColumnsInterleaved   = ElementsPerCacheLine / ThreadblockK;
66 
67 public:
68     using Layout                           = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
69     static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
70     using Operator                         = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
71 };
72 
73 template<typename Arch>
74 struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
75     static constexpr int ThreadblockK = 64;
76 
77 private:
78     static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
79     static constexpr int ColumnsInterleaved   = ElementsPerCacheLine / ThreadblockK;
80 
81 public:
82     using Layout                           = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
83     static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
84     using Operator                         = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
85 };
86 
87 }  // namespace kernel
88 }  // namespace gemm
89 }  // namespace cutlass
90