1 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
2 #include <cuda_bf16.h>
3 #include <cuda_fp16.h>
4 #include <cuda_runtime.h>
5 #if !defined(USE_ROCM)
6 #include <mma.h>
7 #endif
8 #endif
9 #include <ATen/ATen.h>
10 #include <ATen/core/Tensor.h>
11 #include <ATen/cuda/CUDAContext.h>
12 #include <ATen/DeviceGuard.h>
13 #include <c10/cuda/CUDAGuard.h>
14 
15 
16 namespace at::native {
17 
18 template <typename U, typename V>
19 constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
20   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
21   return (a / b);
22 }
23 
24 template <typename U, typename V>
25 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
26   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
27   // Overflow safe variant of (a + b - 1) / b
28   const uint64_t blocks = a / b + (a % b != 0);
29   return blocks;
30 }
31 
32 template <typename U, typename V>
33 constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
34   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
35   return divDown(a, b) * b;
36 }
37 
38 template <typename U, typename V>
39 constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
40   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
41   return divUp(a, b) * b;
42 }
43 
44 template <typename U, typename V>
45 constexpr __host__ __device__ bool isEvenDivisor(U a, V b) {
46   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
47   return (a % V(b) == 0) && ((a / V(b)) >= 1);
48 }
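
// Illustrative compile-time sanity checks for the rounding helpers above, in
// the same spirit as the pow2/log2/isPowerOf2 checks below (example values
// chosen arbitrarily):
static_assert(divDown(7, 2) == 3, "divDown");
static_assert(divUp(7, 2) == 4 && divUp(8, 2) == 4, "divUp");
static_assert(roundDown(10, 4) == 8 && roundUp(10, 4) == 12, "round");
static_assert(isEvenDivisor(12, 4) && !isEvenDivisor(12, 5), "isEvenDivisor");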
49 
50 template <class T>
51 constexpr __host__ __device__ T pow(T n, int power) {
52   return (power > 0 ? n * pow(n, power - 1) : 1);
53 }
54 
55 template <class T>
56 constexpr __host__ __device__ T pow2(int power) {
57   return pow(2, power);
58 }
59 
60 static_assert(pow2<int>(8) == 256, "pow2");
61 
62 template <typename T>
63 constexpr __host__ __device__ int log2(T n, int p = 0) {
64   return (n <= 1) ? p : log2(n / 2, p + 1);
65 }
66 
67 static_assert(log2(2) == 1, "log2");
68 static_assert(log2(3) == 1, "log2");
69 static_assert(log2(4) == 2, "log2");
70 
71 template <typename T>
72 constexpr __host__ __device__ bool isPowerOf2(T v) {
73   static_assert(std::is_integral<T>::value, "");
74   return (v && !(v & (v - 1)));
75 }
76 
77 static_assert(isPowerOf2(2048), "isPowerOf2");
78 static_assert(!isPowerOf2(3333), "isPowerOf2");
79 
80 template <typename T>
81 constexpr __host__ __device__ T nextHighestPowerOf2(T v) {
82   static_assert(std::is_integral<T>::value, "");
83   return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1)));
84 }
85 
86 static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2");
87 static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2");
88 static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2");
89 static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2");
90 
91 static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2");
92 static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2");
93 static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2");
94 
95 static_assert(
96     nextHighestPowerOf2(1536000000u) == 2147483648u,
97     "nextHighestPowerOf2");
98 static_assert(
99     nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL,
100     "nextHighestPowerOf2");
101 
102 template <typename T>
103 constexpr __host__ __device__ T nextLowestPowerOf2(T v) {
104   static_assert(std::is_integral<T>::value, "");
105   return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v))));
106 }
107 
108 static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2");
109 static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2");
110 static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2");
111 static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2");
112 
113 static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2");
114 static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2");
115 static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2");
116 
117 inline __host__ __device__ bool isPointerAligned(const void* p, int align) {
118   return reinterpret_cast<uintptr_t>(p) % align == 0;
119 }
120 
121 // Returns the increment needed to align the pointer to the next highest
122 // aligned address
123 template <int Align>
124 inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
125   static_assert(isPowerOf2(Align), "");
126   const uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1));
127   return diff == 0 ? 0 : uint32_t(Align) - diff;
128 }
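
// Usage sketch (illustrative; the local names below are hypothetical): to
// switch from scalar to 16-byte vectorized loads, a caller could first consume
// the unaligned prefix:
//   uint32_t skipBytes = getAlignmentRoundUp<16>(ptr);
//   uint32_t skipElems = skipBytes / sizeof(__nv_bfloat16);
//   // handle the first `skipElems` elements one by one, then load 16B chunks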
129 
130 #if defined(USE_ROCM)
131 // TODO: Support RDNA
132 constexpr int32_t kWarpSize = 64;
133 
134 template<typename T, uint32_t Rank>
135 using VecT = T __attribute__((ext_vector_type(Rank)));
136 
137 static bool isCDNA2orLater(int index) {
138     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
139     std::string device_arch = prop->gcnArchName;
140     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
141     for (std::string arch : archs) {
142         size_t substring = device_arch.find(arch);
143         if (substring != std::string::npos) {
144             return true;
145         }
146     }
147     return false;
148 }
149 
150 #else
151 constexpr int32_t kWarpSize = 32;
152 #endif
153 
154 #if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
155 #define CDNA2_OR_LATER 1
156 #else
157 #define CDNA2_OR_LATER 0
158 #endif
159 
160 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
161 
162 // f16 vector types
163 struct __align__(2) f16x1 {
164   __half vals[1];
165 };
166 
167 struct __align__(4) f16x2 {
168   __half vals[2];
169 };
170 
171 struct __align__(8) f16x4 {
172   __half vals[4];
173 };
174 
175 struct __align__(16) f16x8 {
176   __half vals[8];
177 };
178 
179 // bf16 vector types
180 struct __align__(2) bf16x1 {
181   __nv_bfloat16 vals[1];
182 };
183 
184 struct __align__(4) bf16x2 {
185   __nv_bfloat16 vals[2];
186 };
187 
188 struct __align__(8) bf16x4 {
189   __nv_bfloat16 vals[4];
190 };
191 
192 struct __align__(16) bf16x8 {
193   __nv_bfloat16 vals[8];
194 };
195 
196 // bf162 vector types
197 struct __align__(4) bf16x2x1 {
198   __nv_bfloat162 vals[1];
199 };
200 
201 struct __align__(8) bf16x2x2 {
202   __nv_bfloat162 vals[2];
203 };
204 
205 struct __align__(16) bf16x2x4 {
206   __nv_bfloat162 vals[4];
207 };
208 
209 struct __align__(16) bf16x2x4_u32 {
210 #if defined(USE_ROCM)
211   VecT<short, 4> val[2];
212 #else
213   uint32_t vals[4];
214 #endif
215 };
216 
217 struct __align__(8) bf16x2x2_u32 {
218 #if defined(USE_ROCM)
219   VecT<short, 4> val;
220 #else
221   uint32_t vals[2];
222 #endif
223 };
224 
225 struct __align__(4) bf16x2x1_u32 {
226   uint32_t vals[1];
227 };
228 
229 template <typename T, int N>
230 struct __align__(sizeof(T) * N) VectorType {
231   T vals[N];
232 };
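
// Illustrative sanity check: the bf16x2x4 struct and its integer view
// bf16x2x4_u32 are both 16-byte fragments; the type-punning memcpy in
// BLayout_TC_int4::load below relies on them having the same size.
static_assert(sizeof(bf16x2x4) == 16 && sizeof(bf16x2x4) == sizeof(bf16x2x4_u32), "");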
233 
234 // from
235 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
236 inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
237   bf16x2x4 result;
238   constexpr int kElements = 8;
239 
240   uint32_t* h = reinterpret_cast<uint32_t*>(&result);
241   uint32_t const source_i4s = source;
242 
243   // First, we extract the i4s and construct an intermediate fp16 number.
244 #if !defined(USE_ROCM)
245   static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
246 #endif
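  // (immLut = (0xf0 & 0xcc) | 0xaa is the lop3 lookup-table encoding of the
  // ternary function (a & b) | c; both the lop3.b32 and v_and_or_b32 paths
  // below therefore compute (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM in a single
  // instruction.)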
247   static constexpr uint32_t MASK = 0x000f000f;
248   static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
249 
250   // We don't have enough mantissa to remove as much shift overhead as FP16, so
251   // we must loop. No shift needed for first item.
252   uint32_t i4s = source_i4s;
253 
254 #if defined(USE_ROCM)
255   asm volatile("v_and_or_b32 %0, %1, %2, %3"
256                : "=v"(h[0])
257                : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
258 #else
259   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
260                : "=r"(h[0])
261                : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
262 #endif
263 
264 #pragma unroll
265   for (int ii = 1; ii < kElements / 2; ++ii) {
266     i4s >>= 4; // shift by 4 so the 0x000f000f mask selects the next nibble pair
267     // (i4s & 0x000f000f) | 0x43004300
268 #if defined(USE_ROCM)
269     asm volatile("v_and_or_b32 %0, %1, %2, %3"
270         : "=v"(h[ii])
271         : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
272 #else
273     asm volatile(
274         "lop3.b32 %0, %1, %2, %3, %4;\n"
275         : "=r"(h[ii])
276         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
277 #endif
278   }
279 
280   // This is the BF16 {-136, -136} represented as an integer.
281 #if defined(USE_ROCM)
282 #if ROCM_VERSION >= 60200
283   auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
284   auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
285 #else
286   auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
287   auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
288 #endif
289 #else
290   static constexpr uint32_t BF16_BIAS = 0xC308C308;
291   static constexpr uint32_t BF16_ONE = 0x3F803F80;
292 #endif
293 
294 // Finally, we construct the output numbers.
295 #pragma unroll
296   for (int ii = 0; ii < kElements / 2; ++ii) {
297     // Since this section is for Ampere+, we use bf16 fma to do the bias
298     // subtraction
299 #if defined(USE_ROCM)
300      result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
301 #else
302     asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
303         : "=r"(h[ii])
304         : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
305 #endif
306   }
307 
308   return result;
309 }
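
// Worked example of the conversion above (for reference): OR-ing a nibble
// n in [0, 15] into the low mantissa bits of 0x4300 (bf16 128.0) yields the
// bf16 value 128 + n, and the fma with BF16_ONE and BF16_BIAS (-136.0) then
// produces (128 + n) - 136 = n - 8, i.e. the signed int4 range [-8, 7].
// For instance, nibble 0x0 maps to -8 and nibble 0xF maps to 7.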
310 
311 
312 
313 enum class KReductionType {
314   // No k-reduction is needed between blocks as the number of k-tiles processed
315   // per block are exact and we can directly write the output
316   None,
317 };
318 
319 // Loads the A matrix in 16-bit standard m x k row major layout, and writes
320 // the C matrix in 16-bit standard m x n row major layout:
321 //
322 // size [m][k]
323 template <KReductionType ReduceType>
324 struct ALayout_RM {
325   static constexpr int32_t kMTileSize = 16;
326 #if defined(USE_ROCM)
327   static constexpr int32_t kNTileSize = 16;
328 #else
329   static constexpr int32_t kNTileSize = 8;
330 #endif
331   static constexpr int32_t kKTileSize = 16;
332 
333   template <int KTilesToLoad>
334   static __device__ void load(
335       const void* A,
336       int32_t m,
337       int32_t k,
338       int32_t mTiles,
339       int32_t mTile,
340       int32_t kTiles,
341       int32_t kTileStart,
342       int32_t laneId,
343 #if defined(USE_ROCM)
344       bf16x2x2_u32 out[KTilesToLoad]
345 #else
346       bf16x2x4_u32 out[KTilesToLoad]
347 #endif
348   ) {
349 #if defined(USE_ROCM)
350     const auto mLane = mTile * kMTileSize + (laneId % kMTileSize);
351     const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4;
352 #else
353     const auto mLane = mTile * kMTileSize + (laneId / 4);
354     const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
355 #endif
356 
357     // access
358     // [mTile * kMTileSize + (laneId / 4)]
359     // [kTileStart * kKTileSize + (laneId % 4) * 2]
360     auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) + mLane * k + kLane;
361     bool m0InBounds = mLane < m;
362 
363 #if !defined(USE_ROCM)
364     auto aPtrPlus8Rows = aPtr + 8 * k;
365 
366     bool m1InBounds = (mLane + 8) < m;
367 #endif
368 
369 #pragma unroll
370     for (int i = 0; i < KTilesToLoad; ++i) {
371 #if defined(USE_ROCM)
372       out[i].val = m0InBounds ? *((VecT<short, 4> *)(aPtr + i * kKTileSize)) : VecT<short, 4>{0, 0, 0, 0};
373 #else
374       out[i].vals[0] = m0InBounds
375           ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
376           : uint32_t(0);
377       out[i].vals[1] = m1InBounds
378           ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize)
379           : uint32_t(0);
380 
381       out[i].vals[2] = m0InBounds
382           ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 8)
383           : uint32_t(0);
384       out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
385                                         aPtrPlus8Rows + i * kKTileSize + 8)
386                                   : uint32_t(0);
387 #endif
388     }
389   }
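
  // Worked example of the CUDA lane mapping above (illustrative): for
  // laneId == 5, mLane = mTile * 16 + 1 and kLane = kTileStart * 16 + 2, so
  // within the tile vals[0] holds the two bf16 at (row 1, cols 2..3), vals[1]
  // the same columns at row 1 + 8, and vals[2]/vals[3] repeat this at
  // columns 10..11, matching the m16n8k16 mma A-fragment ownership pattern.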
390 
391   static __device__ void store(
392       void* C,
393       int32_t m,
394       int32_t n,
395       int32_t mOutTiles,
396       int32_t mTile,
397       int32_t nOutTiles,
398       int32_t nTile,
399       int32_t laneId,
400       const float4& out) {
401     static_assert(ReduceType == KReductionType::None, "");
402 
403     if constexpr (ReduceType == KReductionType::None) {
404 #if defined(USE_ROCM)
405       const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4;
406       const int outCol = nTile * kNTileSize + (laneId % kNTileSize);
407 #else
408       // sum.x / sum.y are written at
409       // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
410       // sum.z / sum.w are written at
411       // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
412       // i.e., same columns, different row.
413       const int outRow = mTile * kMTileSize + (laneId / 4);
414       const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
415 #endif
416 
417       // Pointer where sum.x / sum.y is written
418       auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol;
419 
420 #if defined(USE_ROCM)
421       if (outRow < m)
422         cPtr[0] = __float2bfloat16(out.x);
423       if ((outRow + 1) < m)
424         cPtr[n] = __float2bfloat16(out.y);
425       if ((outRow + 2) < m)
426         cPtr[2*n] = __float2bfloat16(out.z);
427       if ((outRow + 3) < m)
428         cPtr[3*n] = __float2bfloat16(out.w);
429 #else
430       auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
431       auto v23 = __float22bfloat162_rn(float2{out.z, out.w});
432 
433       if (outRow < m) {
434         *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01;
435       }
436 
437       // sum.z, sum.w at +8 rows from cPtr
438       if (outRow + 8 < m) {
439         *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23;
440       }
441 #endif
442     }
443   }
444 };
445 
446 template <int InnerKTiles, int QGroupSize>
447 struct BLayout_TC_int4 {
448   static constexpr int32_t kInnerKTiles = InnerKTiles;
449   static constexpr int32_t kMTileSize = 16;
450 #if defined(USE_ROCM)
451   static constexpr int32_t kNTileSize = 16;
452 #else
453   static constexpr int32_t kNTileSize = 8;
454 #endif
455   static constexpr int32_t kKTileSize = 16;
456 
457   template <int KTilesToLoad>
458   static __device__ void load(
459       // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
460       // n-tiles: n / 8 for NV, n / 16 for AMD
461       // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD)
462       // values per warp lane: 32 for NV, 64 for AMD
463       // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile.
464       // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest
465       // value), 4 k-tiles packed is a uint32x2 (64 bits), 8 k-tiles packed is a
466       // uint32x4 (128 bits)
467       const void* __restrict__ B,
468       // size [k / qGroupSize][n][2]
469       // Contains the scale and zero point of each of the quantized int4 values
470       // within B
471       // v_reconstructed = (bf16(B_int4_val) * scale) + zero
472       const void* __restrict__ quantizationInfo,
473       int32_t n,
474       int32_t k,
475       int32_t nTiles,
476       int32_t nTile,
477       int32_t kTiles,
478       int32_t kTileStart,
479       int32_t laneId,
480       bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) {
481     // offset [nTile][kTileStart / InnerKTiles][laneId][0]
482     auto bPtr = reinterpret_cast<const int32_t*>(B) +
483         (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) *
484           kWarpSize) +
485          laneId) *
486             (InnerKTiles / 2);
487 
488     int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2];
489 
490 #pragma unroll
491     for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) {
492       auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2);
493 
494       if constexpr (InnerKTiles == 2) {
495         b_int4[i][0] = bPtrCur[0];
496       }
497 
498       if constexpr (InnerKTiles == 4) {
499         // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
500         //              : "=r"(b_int4[i][0]), "=r"(b_int4[i][1])
501         //              : "l"(bPtrCur));
502 
503         int2 load8 = reinterpret_cast<const int2*>(bPtrCur)[0];
504         b_int4[i][0] = load8.x;
505         b_int4[i][1] = load8.y;
506       }
507 
508       if constexpr (InnerKTiles == 8) {
509         // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
510         //              : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]),
511         //              "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur));
512 
513         int4 load16 = reinterpret_cast<const int4*>(bPtrCur)[0];
514         b_int4[i][0] = load16.x;
515         b_int4[i][1] = load16.y;
516         b_int4[i][2] = load16.z;
517         b_int4[i][3] = load16.w;
518       }
519     }
520 
521     // Load needed info for dequantization
522 
523     static_assert(isPowerOf2(QGroupSize), "");
524     static_assert(isEvenDivisor(QGroupSize, kKTileSize), "");
525     // smallest quantization group size is 32 (2 k-tiles are packed in an int32)
526     static_assert(QGroupSize >= kKTileSize * 2, "");
527     constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize);
528     // a q-group could be larger than what we are handling in a single warp
529     constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1
530         ? 1
531         : (KTilesToLoad / kKTilesPerQGroup);
532 
533     __nv_bfloat162 qScaleAndZero[kNumQGroups];
534     {
535 #if defined(USE_ROCM)
536       int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize);
537 #else
538       int32_t laneN = nTile * kNTileSize + (laneId / 4);
539 #endif
540       int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;
541 
542       int32_t n = nTiles * kNTileSize;
543 
544       // offset [qScale_kGroup][qScale_n][0]
545       auto qInfoPtr = reinterpret_cast<const __nv_bfloat16*>(quantizationInfo) +
546           (groupStart * n + laneN) * 2;
547 
548 #pragma unroll
549       for (int i = 0; i < kNumQGroups; ++i) {
550         qScaleAndZero[i] =
551             *reinterpret_cast<const __nv_bfloat162*>(qInfoPtr + i * n * 2);
552       }
553     }
554 
555     //
556     // De-quantize int4 values to bf16. Values are dequantized as truly int4
557     // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
558     //
559     {
560       // FIXME: does this negatively affect register counts, or will nvcc
561       // move this expansion (and data loads above) closer to the point of use?
562       __nv_bfloat162 qScale[kNumQGroups];
563       __nv_bfloat162 qZero[kNumQGroups];
564 
565 #pragma unroll
566       for (int i = 0; i < kNumQGroups; ++i) {
567         qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x);
568         qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y);
569       }
570 
571 #pragma unroll
572       for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) {
573 #pragma unroll
574         for (int j = 0; j < InnerKTiles / 2; ++j) {
575           bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]);
576 
577           int curKTile = i * InnerKTiles + j * 2;
578           int curQGroup = (curKTile * kKTileSize) / QGroupSize;
579 
580           // The dequantized values in `v` for a given lane have the same n
581           // dimension (the B tensor core layout has all values in the same
582           // thread along the same n) but different k dimension, but all are
583           // guaranteed to occur within the same quantization group, so we need
584           // only load a single scale + zero to cover what this lane has
585 #pragma unroll
586           for (int k = 0; k < 4; ++k) {
587             v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]);
588           }
589 
590           // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and
591           // can't be used as a 32-bit asm register argument for `mma`
592           static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
593           std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
594         }
595       }
596     }
597   }
598 };
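
// Illustrative sizing (hypothetical values): for k = 4096, n = 4096,
// InnerKTiles = 8 and QGroupSize = 128 on NV, B has shape [512][32][32][4]
// (int32) and a single load<8>() call covers 8 k-tiles, i.e. exactly one
// 128-wide quantization group (kKTilesPerQGroup = 8, kNumQGroups = 1).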
599 
600 template <
601     typename ALayout,
602     typename BLayout,
603     typename CLayout,
604     int Warps,
605     int KTilesPerIteration>
606 __global__
607 __launch_bounds__(Warps * kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
608     // Data for the A matrix, loaded as per ALayout
609     const void* const __restrict__ A,
610 
611     // Data for the B matrix, loaded as per BLayout
612     const void* const __restrict__ B,
613 
614     // Optional quantization data for dequantizing B, loaded as per BLayout
615     const void* const __restrict__ B_quantizationInfo,
616 
617     // Output data for the C matrix, stored as per CLayout
618     void* __restrict__ C,
619 
620     // The size of the matrix multiplication
621     int32_t m,
622     int32_t n,
623     int32_t k,
624 
625     // The size of the matrix multiplication, in multiples of our TC tile size
626     int32_t mTiles,
627     int32_t nTiles,
628     int32_t kTiles) {
629   constexpr int32_t kMTileSize = 16;
630 #if defined(USE_ROCM)
631   constexpr int32_t kNTileSize = 16;
632 #else
633   constexpr int32_t kNTileSize = 8;
634 #endif
635   constexpr int32_t kKTileSize = 16;
636 
637 #if !defined(USE_ROCM) || CDNA2_OR_LATER
638 
639   static_assert(
640       ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
641           ALayout::kKTileSize == kKTileSize,
642       "");
643 
644   static_assert(
645       BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize &&
646           BLayout::kKTileSize == kKTileSize,
647       "");
648 
649   static_assert(
650       CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize &&
651           CLayout::kKTileSize == kKTileSize,
652       "");
653 
654   constexpr int kInnerKTiles = BLayout::kInnerKTiles;
655 
656   // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads
657   static_assert(
658       kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, "");
659 
660   // We always process at least kInnerKTiles k-tiles back to back in a warp
661   static_assert(
662       KTilesPerIteration >= kInnerKTiles &&
663           isEvenDivisor(KTilesPerIteration, kInnerKTiles),
664       "");
665 
666   auto warpId = threadIdx.y;
667   auto laneId = threadIdx.x;
668 
669   int32_t mTile = blockIdx.z;
670   int32_t nTile = blockIdx.y;
671 
672 #if defined(USE_ROCM)
673   VecT<float, 4> c{0.0f, 0.0f, 0.0f, 0.0f};
674 #else
675   float4 c{0.0f, 0.0f, 0.0f, 0.0f};
676 #endif
677 
678   // First, handle whole multiples of KTilesPerIteration
679   auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
680 
681   // Each warp handles a set of KTilesPerIteration under the above limit
682   for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration;
683        kTileBase < kTilesLimit;
684        kTileBase += Warps * KTilesPerIteration) {
685     //
686     // Load data from A
687     //
688 #if defined(USE_ROCM)
689     bf16x2x2_u32 a[KTilesPerIteration];
690 #else
691     bf16x2x4_u32 a[KTilesPerIteration];
692 #endif
693     ALayout::template load<KTilesPerIteration>(
694         A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
695 
696     //
697     // Load data from B and de-quantize as needed
698     // Each k-tile is bf16x2x2
699     //
700     bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2];
701     BLayout::template load<KTilesPerIteration>(
702         B,
703         B_quantizationInfo,
704         n,
705         k,
706         nTiles,
707         nTile,
708         kTiles,
709         kTileBase,
710         laneId,
711         b);
712 
713     //
714     // Now, perform the matrix multiplication
715     //
716 
717     // We accumulate across k-tiles here
718 #pragma unroll
719     for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) {
720       static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, "");
721 #pragma unroll
722       for (int j = 0; j < kInnerKTiles / 2; ++j) {
723         // We don't simply accumulate into `c` as this creates a too-strong
724         // execution dependency. Instead, we only periodically accumulate into
725         // `c`
726 #if defined(USE_ROCM)
727         VecT<float, 4> cTmp[2];
728 #else
729         float4 cTmp[2];
730 #endif
731 
732 #pragma unroll
733         for (int k = 0; k < 2; ++k) {
734 #if defined(USE_ROCM)
735           cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
736 #else
737           cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
738 #endif
739         }
740 
741 #pragma unroll
742         for (int k = 0; k < 2; ++k) {
743 #if defined(USE_ROCM)
744           cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
745               a[i * kInnerKTiles + j * 2 + k].val,
746               b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
747               cTmp[k], 0, 0, 0);
748 #else
749           asm volatile(
750               "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
751               "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
752               : "=f"(cTmp[k].x),
753                 "=f"(cTmp[k].y),
754                 "=f"(cTmp[k].z),
755                 "=f"(cTmp[k].w)
756               : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]),
757                 "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]),
758                 "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]),
759                 "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]),
760                 "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
761                 "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
762                 "f"(cTmp[k].x),
763                 "f"(cTmp[k].y),
764                 "f"(cTmp[k].z),
765                 "f"(cTmp[k].w));
766 #endif
767         }
768 
769 #pragma unroll
770         for (int k = 0; k < 2; ++k) {
771 #if defined(USE_ROCM)
772           c[0] += cTmp[k][0];
773           c[1] += cTmp[k][1];
774           c[2] += cTmp[k][2];
775           c[3] += cTmp[k][3];
776 #else
777           c.x += cTmp[k].x;
778           c.y += cTmp[k].y;
779           c.z += cTmp[k].z;
780           c.w += cTmp[k].w;
781 #endif
782         }
783       }
784     }
785   } // for all tiles under kTilesLimit
786 
787   // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles
788   // remaining. We guarantee that the number of warps is >= KTilesPerIteration /
789   // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its
790   // thing without needing more warps
791   static_assert(Warps >= KTilesPerIteration / kInnerKTiles, "");
792 
793   auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles;
794 
795   // If we have any remainder k-tiles, some warps will handle them, processing
796   // kInnerKTiles k-tiles at a time
797   if (kTileBaseRemaining < kTiles) {
798 #if defined(USE_ROCM)
799     bf16x2x2_u32 a[kInnerKTiles];
800 #else
801     bf16x2x4_u32 a[kInnerKTiles];
802 #endif
803     ALayout::template load<kInnerKTiles>(
804         A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a);
805 
806     bf16x2x4_u32 b[1][kInnerKTiles / 2];
807     BLayout::template load<kInnerKTiles>(
808         B,
809         B_quantizationInfo,
810         n,
811         k,
812         nTiles,
813         nTile,
814         kTiles,
815         kTileBaseRemaining,
816         laneId,
817         b);
818 
819 #pragma unroll
820     for (int j = 0; j < kInnerKTiles / 2; ++j) {
821       // We don't simply accumulate into `c` as this creates a too-strong
822       // execution dependency. Instead, we only periodically accumulate into
823       // `c`
824 #if defined(USE_ROCM)
825       VecT<float, 4> cTmp[2];
826 #else
827       float4 cTmp[2];
828 #endif
829 
830 #pragma unroll
831       for (int k = 0; k < 2; ++k) {
832 #if defined(USE_ROCM)
833         cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
834 #else
835         cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
836 #endif
837       }
838 
839 #pragma unroll
840       for (int k = 0; k < 2; ++k) {
841 #if defined(USE_ROCM)
842         cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
843           a[j * 2 + k].val,
844           b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
845           cTmp[k], 0, 0, 0);
846 #else
847         asm volatile(
848             "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
849             "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
850             : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
851             : "r"(a[j * 2 + k].vals[0]),
852               "r"(a[j * 2 + k].vals[1]),
853               "r"(a[j * 2 + k].vals[2]),
854               "r"(a[j * 2 + k].vals[3]),
855               "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
856               "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
857               "f"(cTmp[k].x),
858               "f"(cTmp[k].y),
859               "f"(cTmp[k].z),
860               "f"(cTmp[k].w));
861 #endif
862       }
863 
864 #pragma unroll
865       for (int k = 0; k < 2; ++k) {
866 #if defined(USE_ROCM)
867         c[0] += cTmp[k][0];
868         c[1] += cTmp[k][1];
869         c[2] += cTmp[k][2];
870         c[3] += cTmp[k][3];
871 #else
872         c.x += cTmp[k].x;
873         c.y += cTmp[k].y;
874         c.z += cTmp[k].z;
875         c.w += cTmp[k].w;
876 #endif
877       }
878     }
879   }
880 
881   //
882   // Reduce independent k-tiles (same m/n) across warps
883   //
884   __shared__ float4 smem_sum[Warps][kWarpSize];
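  // Footprint note (illustrative): with the Warps = 8 instantiation used by
  // _weight_int4pack_mm_cuda below, smem_sum occupies
  // 8 * kWarpSize * sizeof(float4) bytes, i.e. 4 KiB on NV (32 lanes) or
  // 8 KiB on ROCm (64 lanes).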
885 
886   // FIXME: this likely doesn't need to be a true reduction tree, can just be a
887   // serial sum, maybe (unless nvcc/ptxas goes back to its old ways)
888   // smem_sum[warpId][laneId] = TreeReduce4<KTilesPerIteration>::reduce(c);
889 #if defined(USE_ROCM)
890   smem_sum[warpId][laneId].x = c[0];
891   smem_sum[warpId][laneId].y = c[1];
892   smem_sum[warpId][laneId].z = c[2];
893   smem_sum[warpId][laneId].w = c[3];
894 #else
895   smem_sum[warpId][laneId] = c;
896 #endif
897 
898   __syncthreads();
899 
900   if (warpId == 0) {
901     float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
902 
903     // Reduce across the block in the first warp
904     for (int i = 0; i < Warps; ++i) {
905       float4 v = smem_sum[i][laneId];
906       sum_f32.x += v.x;
907       sum_f32.y += v.y;
908       sum_f32.z += v.z;
909       sum_f32.w += v.w;
910     }
911 
912     // Write the reduced result (in the first warp) into the output
913     CLayout::store(
914         C,
915         m,
916         n,
917         mTiles,
918         mTile,
919         // n for C output becomes k for A input, so for m16n8k16,
920         // we need to halve the tiles
921         nTiles / 2,
922         nTile,
923         laneId,
924         sum_f32);
925   }
926 #else
927     printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n");
928 #endif
929 }
930 
931 
932 template <
933     typename ALayout,
934     typename BLayout,
935     typename CLayout,
936     int Warps,
937     int KTilesPerWarp>
938 void launch_tinygemm_kernel(
939     const at::Tensor& A,
940     const at::Tensor& B,
941     const at::Tensor* qScaleAndZeros, /* optional */
942     at::Tensor& C_final,
943     int32_t mTiles,
944     int32_t nTiles,
945     int32_t kTiles,
946     int32_t m,
947     int32_t n,
948     int32_t k,
949     cudaStream_t stream) {
950   // The chunking kernel requires that kTiles is a multiple of kInnerKTiles
951   TORCH_CHECK(
952       kTiles >= BLayout::kInnerKTiles &&
953       isEvenDivisor(kTiles, BLayout::kInnerKTiles));
954 
955   TORCH_CHECK(
956       KTilesPerWarp >= BLayout::kInnerKTiles &&
957       isEvenDivisor(KTilesPerWarp, BLayout::kInnerKTiles));
958 
959   // After intra-block reduction across the k dimension, we are left with this
960   // many tiles
961   //  int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp);
962   int32_t postKernelKTiles = 1; // we loop
963 
964   auto grid = dim3(postKernelKTiles, nTiles, mTiles);
965   auto block = dim3(kWarpSize, Warps);
966 
967   auto func =
968       tinygemm_m16n8k16_chunk_kernel<ALayout, BLayout, CLayout, Warps, KTilesPerWarp>;
969 
970   func<<<grid, block, 0, stream>>>(
971       A.data_ptr(),
972       B.data_ptr(),
973       qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr,
974       C_final.data_ptr(),
975       m,
976       n,
977       k,
978       mTiles,
979       nTiles,
980       kTiles);
981   C10_CUDA_KERNEL_LAUNCH_CHECK();
982 
983   cudaFuncAttributes funcAttr;
984 #if defined(USE_ROCM)
985   C10_CUDA_CHECK(cudaFuncGetAttributes(
986       &funcAttr,
987       (void *)func
988   ));
989 #else
990   C10_CUDA_CHECK(cudaFuncGetAttributes(
991       &funcAttr,
992       func
993   ));
994 #endif
995 }
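
// Launch-shape example (hypothetical sizes): for m = 1, n = 4096, k = 4096 on
// NV with Warps = 8 and KTilesPerWarp = 8, we get mTiles = 1, nTiles = 512,
// kTiles = 256, so grid = (1, 512, 1) and block = (32, 8); each warp then
// walks the k dimension in chunks of 8 k-tiles, striding by 64 k-tiles per
// iteration of the main loop in the kernel above.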
996 
997 // FIXME: parallelize better, smem staging etc?
998 template <int InnerKTiles>
999 __global__ void matrix_to_m16n8k16_Bint4_layout(
1000     // size [n][k / 2]
1001     const at::PackedTensorAccessor32<uint8_t, 2, at::RestrictPtrTraits> in,
1002     // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2]
1003     at::PackedTensorAccessor32<int32_t, 4, at::RestrictPtrTraits> out) {
1004   // int4 values are packed into int32 values, which require at least 8. Given
1005   // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of
1006   // innermost k-tiles that we can use is 2.
1007   static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");
1008 
1009 #if defined(USE_ROCM)
1010   constexpr int32_t kNTileSize = 16;
1011 #else
1012   constexpr int32_t kNTileSize = 8;
1013 #endif
1014   constexpr int32_t kKTileSize = 16;
1015 
1016   // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
1017   auto kOuterTile = blockIdx.x;
1018   auto nTile = blockIdx.y;
1019   auto t = threadIdx.x;
1020 
1021   // Two k-tiles are packed into an int32 at a time
1022 #pragma unroll
1023   for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
1024     // n dimension that this lane loads from
1025 #if defined(USE_ROCM)
1026     auto n0 = nTile * kNTileSize + (t % kNTileSize);
1027 #else
1028     auto n0 = nTile * kNTileSize + (t / 4);
1029 #endif
1030 
1031     bool n0Valid = n0 < in.size(0);
1032 
1033     // Four uint8 are packed into an int32
1034     int32_t ks[4];
1035 
1036     auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize / 2;
1037 
1038 #if defined(USE_ROCM)
1039     ks[0] = kBase0 + (t / kNTileSize) * 2;
1040     ks[1] = ks[0] + 1;
1041 
1042     auto kBase1 = kBase0 + kKTileSize / 2;
1043     ks[2] = kBase1 + (t / kNTileSize) * 2;
1044     ks[3] = ks[2] + 1;
1045 #else
1046     ks[0] = kBase0 + t % 4;
1047     ks[1] = ks[0] + 4;
1048 
1049     auto kBase1 = kBase0 + kKTileSize / 2;
1050     ks[2] = kBase1 + t % 4;
1051     ks[3] = ks[2] + 4;
1052 #endif
1053 
1054     auto pIn = &in[n0][0];
1055 
1056     uint8_t v[4];
1057 #pragma unroll
1058     for (int i = 0; i < 4; ++i) {
1059       v[i] = (n0Valid && ks[i] < in.size(1)) ? pIn[ks[i]] : uint8_t(0);
1060     }
1061 
1062     // To clearly explain how 8 int4 values (4 uint8) are packed
1063     // into one int32, we use the following figure:
1064     // [n][k]     int32: v[0] v[1] v[2] v[3] v[4] v[5] v[6] v[7]
1065     // [n][k / 2] uint8:    v[0]     v[1]      v[2]      v[3]
1066     // When using int32 weight as input, the packed result consists of
1067     // v[7] | v[5] | v[3] | v[1] | v[6] | v[4] | v[2] | v[0],
1068     // which equals
1069     // v[3]L | v[2]L | v[1]L | v[0]L | v[3]H | v[2]H | v[1]H | v[0]H
1070     // when using uint8 weight as input.
1071     int32_t pack = ((uint32_t)(v[3] & 0xF) << 28) |
1072         ((uint32_t)(v[2] & 0xF) << 24) | ((uint32_t)(v[1] & 0xF) << 20) |
1073         ((uint32_t)(v[0] & 0xF) << 16) | ((uint32_t)(v[3] & 0xF0) << 8) |
1074         ((uint32_t)(v[2] & 0xF0) << 4) | ((uint32_t)(v[1] & 0xF0)) |
1075         ((uint32_t)(v[0] & 0xF0) >> 4);
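    // Worked example (illustrative): for input bytes v = {0x21, 0x43, 0x65, 0x87}
    // (low/high nibbles 1/2, 3/4, 5/6, 7/8), the expression above yields
    // pack == 0x75318642, i.e. low nibbles 7,5,3,1 followed by high nibbles
    // 8,6,4,2, matching the figure in the comment above.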
1076 
1077     // inner k-tiles pack two at a time
1078 #if defined(USE_ROCM)
1079     // The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia
1080     // But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2]
1081     // So construct the pointer accordingly
1082     auto bPtr = out.data() +
1083       ((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) +
1084         (kOuterTile * kWarpSize * (InnerKTiles / 2)) +
1085           (t * (InnerKTiles / 2)) +
1086             (innerKTile / 2));
1087     *bPtr = pack;
1088 #else
1089     out[nTile][kOuterTile][t][innerKTile / 2] = pack;
1090 #endif
1091   }
1092 }
1093 
1094 #endif
1095 
1096 
1097 at::Tensor _weight_int4pack_mm_cuda(
1098     const at::Tensor& A,
1099     const at::Tensor& B,
1100     int64_t qGroupSize,
1101     const at::Tensor& qScaleAndZeros) {
1102   c10::cuda::CUDAGuard g(A.device());
1103 
1104   TORCH_CHECK(
1105       A.device() == B.device() && A.device() == qScaleAndZeros.device());
1106 
1107 #if defined(USE_ROCM)
1108   if (!isCDNA2orLater(A.device().index())) {
1109     TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
1110   }
1111 #endif
1112 
1113   constexpr int32_t kMTileSize = 16;
1114 #if defined(USE_ROCM)
1115   constexpr int32_t kNTileSize = 16;
1116 #else
1117   constexpr int32_t kNTileSize = 8;
1118 #endif
1119   constexpr int32_t kKTileSize = 16;
1120 
1121   // row major layout
1122   auto m = A.size(0);
1123   auto mTiles = divUp(m, kMTileSize);
1124 
1125   // To convert the nTiles from tensor storage layout to the actual matrix core layout
1126   constexpr int32_t kNTileSizeTensor = 8;
1127   auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor);
1128 
1129   // tensor core layout
1130   auto nTiles = (B.size(0) / nTileScaleFactor);
1131   auto n = nTiles * kNTileSize;
1132 
1133   // row major layout
1134   auto k = A.size(1);
1135   auto kTiles = divUp(k, kKTileSize);
1136 
1137   // The number of inner k tiles is the innermost dimension of B times 2
1138   // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4
1139   // packed into 1 int32 for int4 B
1140   auto B_innerKTiles = B.size(3) * 2;
1141   TORCH_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8);
1142 
1143   // A is standard row major
1144   TORCH_CHECK(A.dtype() == at::kBFloat16);
1145   TORCH_CHECK(A.is_contiguous());
1146   TORCH_CHECK(A.dim() == 2);
1147 
1148   // B has B_innerKTiles k-tiles in the innermost dimension
1149   TORCH_CHECK(B.dtype() == at::kInt);
1150   TORCH_CHECK(B.is_contiguous());
1151   TORCH_CHECK(B.dim() == 4);
1152   TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize));
1153   TORCH_CHECK(B.size(2) == 32);
1154 
1155   // Validate the scale and zero point tensor for dequantization
1156   // These are the only versions handled at the moment
1157   TORCH_CHECK(
1158       qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
1159       qGroupSize == 256);
1160 
1161   TORCH_CHECK(qScaleAndZeros.dim() == 3);
1162   auto numQGroups = qScaleAndZeros.size(0);
1163   TORCH_CHECK(
1164       kTiles * kKTileSize >= qGroupSize &&
1165       isEvenDivisor(kTiles * kKTileSize, qGroupSize));
1166   TORCH_CHECK(qScaleAndZeros.size(1) == n);
1167   TORCH_CHECK(qScaleAndZeros.size(2) == 2);
1168 
1169   // Output is a standard row-major matrix
1170   auto C_final = at::empty(
1171       {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
1172 
1173 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
1174   auto stream = at::cuda::getCurrentCUDAStream();
1175 #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
1176   do {                                                               \
1177     using ACLayout = ALayout_RM<REDUCE_TYPE>;                        \
1178                                                                      \
1179     TORCH_CHECK(                                                     \
1180         K_TILES_PER_WARP >= B_innerKTiles &&                         \
1181         isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles));             \
1182                                                                      \
1183     switch (B_innerKTiles) {                                         \
1184       case 2:                                                        \
1185         if constexpr (K_TILES_PER_WARP >= 2) {                       \
1186           using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>;          \
1187           launch_tinygemm_kernel<                                    \
1188               ACLayout,                                              \
1189               BLayout,                                               \
1190               ACLayout,                                              \
1191               WARPS,                                                 \
1192               K_TILES_PER_WARP>(                                     \
1193               A,                                                     \
1194               B,                                                     \
1195               &qScaleAndZeros,                                       \
1196               C_final,                                               \
1197               mTiles,                                                \
1198               nTiles,                                                \
1199               kTiles,                                                \
1200               m,                                                     \
1201               n,                                                     \
1202               k,                                                     \
1203               stream);                                               \
1204         }                                                            \
1205         break;                                                       \
1206       case 4:                                                        \
1207         if constexpr (K_TILES_PER_WARP >= 4) {                       \
1208           using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>;          \
1209           launch_tinygemm_kernel<                                    \
1210               ACLayout,                                              \
1211               BLayout,                                               \
1212               ACLayout,                                              \
1213               WARPS,                                                 \
1214               K_TILES_PER_WARP>(                                     \
1215               A,                                                     \
1216               B,                                                     \
1217               &qScaleAndZeros,                                       \
1218               C_final,                                               \
1219               mTiles,                                                \
1220               nTiles,                                                \
1221               kTiles,                                                \
1222               m,                                                     \
1223               n,                                                     \
1224               k,                                                     \
1225               stream);                                               \
1226         }                                                            \
1227         break;                                                       \
1228       case 8:                                                        \
1229         if constexpr (K_TILES_PER_WARP >= 8) {                       \
1230           using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>;          \
1231           launch_tinygemm_kernel<                                    \
1232               ACLayout,                                              \
1233               BLayout,                                               \
1234               ACLayout,                                              \
1235               WARPS,                                                 \
1236               K_TILES_PER_WARP>(                                     \
1237               A,                                                     \
1238               B,                                                     \
1239               &qScaleAndZeros,                                       \
1240               C_final,                                               \
1241               mTiles,                                                \
1242               nTiles,                                                \
1243               kTiles,                                                \
1244               m,                                                     \
1245               n,                                                     \
1246               k,                                                     \
1247               stream);                                               \
1248         }                                                            \
1249         break;                                                       \
1250       default:                                                       \
1251         break;                                                       \
1252     }                                                                \
1253   } while (false)
1254 
1255 #define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \
1256   do {                                                       \
1257     switch (qGroupSize) {                                    \
1258       case 32:                                               \
1259         RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE);  \
1260         break;                                               \
1261       case 64:                                               \
1262         RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE);  \
1263         break;                                               \
1264       case 128:                                              \
1265         RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \
1266         break;                                               \
1267       case 256:                                              \
1268         RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \
1269         break;                                               \
1270     }                                                        \
1271   } while (false)
1272 
1273   HANDLE_Q_GROUP(8, 8, KReductionType::None);
1274 
1275 #undef HANDLE_Q_GROUP
1276 #undef RUN_GEMM
1277 
1278   return C_final;
1279 #endif
1280   TORCH_CHECK(false, "_weight_int4pack_mm_cuda is not available for build.")
1281   return C_final;
1282 }
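
// Shape summary (illustrative, hypothetical sizes): with m = 8, k = 4096,
// n = 4096, qGroupSize = 128 and B_innerKTiles = 8 on NV, the expected inputs
// are A [8, 4096] (bf16), B [512, 32, 32, 4] (int32) and
// qScaleAndZeros [32, 4096, 2] (bf16), producing C [8, 4096] (bf16).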
1283 
1284 // input is [n][k / 2] (uint8 dtype)
1285 // output is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] (int32 dtype)
1286 at::Tensor _convert_weight_to_int4pack_cuda(
1287     const at::Tensor& in,
1288     int64_t innerKTiles) {
1289   c10::cuda::CUDAGuard g(in.device());
1290 
1291   TORCH_CHECK(in.dim() == 2);
1292   TORCH_CHECK(in.dtype() == at::kByte);
1293   TORCH_CHECK(in.is_contiguous());
1294 
1295   // At least 2 k-tiles need to be packed back to back in the innermost
1296   // dimension, as the m16n8k16 tensor core tile presents 4 scalar values for
1297   // the B matrix, but the minimum word size for the packed format is 4 bytes
1298   // (int32). 4 inner K-tiles = 8 byte load, 8 inner k-tiles = 16 byte load
1299   // which is the maximum vectorized load/store size
1300   TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);
1301 
1302 #if defined(USE_ROCM)
1303   if (!isCDNA2orLater(in.device().index())) {
1304     TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
1305   }
1306 #endif
1307 
1308 #if defined(USE_ROCM)
1309   constexpr int32_t kNTileSize = 16;
1310 #else
1311   constexpr int32_t kNTileSize = 8;
1312 #endif
1313   constexpr int32_t kKTileSize = 16;
1314 
1315   // GPT-FAST assumes nTileSize of 8 for quantized weight tensor.
1316   // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510
1317   // Torch dynamo also requires the torch ops has the same output shape for each device.
1318   // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263
1319   constexpr int32_t kNTileSizeTensor = 8;
1320 
1321   auto nTiles = divUp(in.size(0), kNTileSize);
1322   auto nTilesTensor = divUp(in.size(0), kNTileSizeTensor);
1323 
1324   // k-tiles are packed back to back in the innermost dimension in order to
1325   // allow for 4/8/16 byte loads
1326   TORCH_CHECK(isEvenDivisor(in.size(1) * 2, innerKTiles * kKTileSize));
1327   // kSuperTiles is the number of k-tiles if each tile spanned innerKTiles * kKTileSize values of k
1328   auto kSuperTiles = divUp(in.size(1) * 2, innerKTiles * kKTileSize);
1329 
1330   // each block handles `innerKTiles` k-tiles.
1331   // 2 k-tiles are a single int32
1332   //
1333   // We use the same shape for AMD gpus also to match the GPT-FAST spec.
1334   // Will index it correctly when dereferencing the quantized weight tensor pointer.
1335   auto out = at::empty(
1336       {nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
1337       at::TensorOptions().dtype(at::kInt).device(in.device()));
1338 
1339 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
1340   auto stream = at::cuda::getCurrentCUDAStream();
1341   dim3 grid(kSuperTiles, nTiles);
1342 
1343   if (innerKTiles == 2) {
1344     matrix_to_m16n8k16_Bint4_layout<2><<<grid, kWarpSize, 0, stream>>>(
1345         in.packed_accessor32<uint8_t, 2, at::RestrictPtrTraits>(),
1346         out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
1347   } else if (innerKTiles == 4) {
1348     matrix_to_m16n8k16_Bint4_layout<4><<<grid, kWarpSize, 0, stream>>>(
1349         in.packed_accessor32<uint8_t, 2, at::RestrictPtrTraits>(),
1350         out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
1351   } else if (innerKTiles == 8) {
1352     matrix_to_m16n8k16_Bint4_layout<8><<<grid, kWarpSize, 0, stream>>>(
1353         in.packed_accessor32<uint8_t, 2, at::RestrictPtrTraits>(),
1354         out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
1355   }
1356 
1357   return out;
1358 #endif
1359   TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is not available for build.")
1360   return out;
1361 }
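
// End-to-end sketch (illustrative, host side; tensor names are hypothetical):
// given `weight_uint8`, an [n][k / 2] kByte CUDA tensor holding packed int4
// pairs, and `scales_and_zeros`, the [k / qGroupSize][n][2] bf16 tensor
// described above, the two ops compose as:
//   auto B_packed = _convert_weight_to_int4pack_cuda(weight_uint8, /*innerKTiles=*/8);
//   auto C = _weight_int4pack_mm_cuda(A_bf16, B_packed, /*qGroupSize=*/128, scales_and_zeros);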
1362 
1363 
1364 } // namespace at::native
1365