1 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
2 #include <cuda_bf16.h>
3 #include <cuda_fp16.h>
4 #include <cuda_runtime.h>
5 #if !defined(USE_ROCM)
6 #include <mma.h>
7 #endif
8 #endif
9 #include <ATen/ATen.h>
10 #include <ATen/core/Tensor.h>
11 #include <ATen/cuda/CUDAContext.h>
12 #include <ATen/DeviceGuard.h>
13 #include <c10/cuda/CUDAGuard.h>
14
15
16 namespace at::native {
17
18 template <typename U, typename V>
19 constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
20 static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
21 return (a / b);
22 }
23
24 template <typename U, typename V>
25 constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
26 static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
27 // Overflow safe variant of (a + b - 1) / b
28 const uint64_t blocks = a / b + (a % b != 0);
29 return blocks;
30 }
31
32 template <typename U, typename V>
33 constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
34 static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
35 return divDown(a, b) * b;
36 }
37
38 template <typename U, typename V>
39 constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
40 static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
41 return divUp(a, b) * b;
42 }
43
44 template <typename U, typename V>
45 constexpr __host__ __device__ bool isEvenDivisor(U a, V b) {
46 static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
47 return (a % V(b) == 0) && ((a / V(b)) >= 1);
48 }
49
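// Illustrative compile-time sanity checks for the helpers above (added for
// exposition; they mirror the pow2/log2 checks further down and compile away
// to nothing).
static_assert(divDown(7, 2) == 3, "divDown");
static_assert(divUp(7, 2) == 4, "divUp");
static_assert(roundDown(10, 4) == 8, "roundDown");
static_assert(roundUp(10, 4) == 12, "roundUp");
static_assert(isEvenDivisor(12, 4), "isEvenDivisor");
static_assert(!isEvenDivisor(12, 5), "isEvenDivisor");
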
50 template <class T>
51 constexpr __host__ __device__ T pow(T n, int power) {
52 return (power > 0 ? n * pow(n, power - 1) : 1);
53 }
54
55 template <class T>
56 constexpr __host__ __device__ T pow2(int power) {
57 return pow(2, power);
58 }
59
60 static_assert(pow2<int>(8) == 256, "pow2");
61
62 template <typename T>
63 constexpr __host__ __device__ int log2(T n, int p = 0) {
64 return (n <= 1) ? p : log2(n / 2, p + 1);
65 }
66
67 static_assert(log2(2) == 1, "log2");
68 static_assert(log2(3) == 1, "log2");
69 static_assert(log2(4) == 2, "log2");
70
71 template <typename T>
72 constexpr __host__ __device__ bool isPowerOf2(T v) {
73 static_assert(std::is_integral<T>::value, "");
74 return (v && !(v & (v - 1)));
75 }
76
77 static_assert(isPowerOf2(2048), "isPowerOf2");
78 static_assert(!isPowerOf2(3333), "isPowerOf2");
79
80 template <typename T>
81 constexpr __host__ __device__ T nextHighestPowerOf2(T v) {
82 static_assert(std::is_integral<T>::value, "");
83 return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1)));
84 }
85
86 static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2");
87 static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2");
88 static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2");
89 static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2");
90
91 static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2");
92 static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2");
93 static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2");
94
95 static_assert(
96 nextHighestPowerOf2(1536000000u) == 2147483648u,
97 "nextHighestPowerOf2");
98 static_assert(
99 nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL,
100 "nextHighestPowerOf2");
101
102 template <typename T>
103 constexpr __host__ __device__ T nextLowestPowerOf2(T v) {
104 static_assert(std::is_integral<T>::value, "");
105 return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v))));
106 }
107
108 static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2");
109 static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2");
110 static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2");
111 static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2");
112
113 static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2");
114 static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2");
115 static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2");
116
117 inline __host__ __device__ bool isPointerAligned(const void* p, int align) {
118 return reinterpret_cast<uintptr_t>(p) % align == 0;
119 }
120
121 // Returns the increment needed to align the pointer to the next highest
122 // aligned address
123 template <int Align>
124 inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
125 static_assert(isPowerOf2(Align), "");
126 const uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1));
127 return diff == 0 ? 0 : uint32_t(Align) - diff;
128 }
129
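// Worked example (for exposition): with Align = 16, a pointer whose address
// ends in 0x4 has diff = 4, so getAlignmentRoundUp<16> returns 12; advancing
// the pointer by 12 bytes reaches the next 16-byte boundary, at which point
// isPointerAligned(p + 12, 16) holds.
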
130 #if defined(USE_ROCM)
131 // TODO: Support RDNA
132 constexpr int32_t kWarpSize = 64;
133
134 template<typename T, uint32_t Rank>
135 using VecT = T __attribute__((ext_vector_type(Rank)));
136
137 static bool isCDNA2orLater(int index) {
138 hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
139 std::string device_arch = prop->gcnArchName;
140 static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
141 for (std::string arch : archs) {
142 size_t substring = device_arch.find(arch);
143 if (substring != std::string::npos) {
144 return true;
145 }
146 }
147 return false;
148 }
149
150 #else
151 constexpr int32_t kWarpSize = 32;
152 #endif
153
154 #if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
155 #define CDNA2_OR_LATER 1
156 #else
157 #define CDNA2_OR_LATER 0
158 #endif
159
160 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
161
162 // f16 vector types
163 struct __align__(2) f16x1 {
164 __half vals[1];
165 };
166
167 struct __align__(4) f16x2 {
168 __half vals[2];
169 };
170
171 struct __align__(8) f16x4 {
172 __half vals[4];
173 };
174
175 struct __align__(16) f16x8 {
176 __half vals[8];
177 };
178
179 // bf16 vector types
180 struct __align__(2) bf16x1 {
181 __nv_bfloat16 vals[1];
182 };
183
184 struct __align__(4) bf16x2 {
185 __nv_bfloat16 vals[2];
186 };
187
188 struct __align__(8) bf16x4 {
189 __nv_bfloat16 vals[4];
190 };
191
192 struct __align__(16) bf16x8 {
193 __nv_bfloat16 vals[8];
194 };
195
196 // bf162 vector types
197 struct __align__(4) bf16x2x1 {
198 __nv_bfloat162 vals[1];
199 };
200
201 struct __align__(8) bf16x2x2 {
202 __nv_bfloat162 vals[2];
203 };
204
205 struct __align__(16) bf16x2x4 {
206 __nv_bfloat162 vals[4];
207 };
208
209 struct __align__(16) bf16x2x4_u32 {
210 #if defined(USE_ROCM)
211 VecT<short, 4> val[2];
212 #else
213 uint32_t vals[4];
214 #endif
215 };
216
217 struct __align__(8) bf16x2x2_u32 {
218 #if defined(USE_ROCM)
219 VecT<short, 4> val;
220 #else
221 uint32_t vals[2];
222 #endif
223 };
224
225 struct __align__(4) bf16x2x1_u32 {
226 uint32_t vals[1];
227 };
228
229 template <typename T, int N>
230 struct __align__(sizeof(T) * N) VectorType {
231 T vals[N];
232 };
233
234 // from
235 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
236 inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
237 bf16x2x4 result;
238 constexpr int kElements = 8;
239
240 uint32_t* h = reinterpret_cast<uint32_t*>(&result);
241 uint32_t const source_i4s = source;
242
243 // First, we extract the i4s and construct an intermediate bf16 number.
244 #if !defined(USE_ROCM)
245 static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
246 #endif
247 static constexpr uint32_t MASK = 0x000f000f;
248 static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
249
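// How the magic numbers work (for exposition): 0x4300 is bf16 128.0. OR-ing a
// 4-bit value v into its low mantissa bits yields the bf16 value 128 + v, so
// (i4s & 0x000f000f) | 0x43004300 builds two such values at once. The fma at
// the end multiplies by 1.0 (0x3F80) and adds -136.0 (0xC308), mapping the
// unsigned nibble v in [0, 15] to v - 8, i.e. the signed int4 range [-8, 7].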
250 // We don't have enough mantissa to remove as much shift overhead as FP16, so
251 // we must loop. No shift needed for first item.
252 uint32_t i4s = source_i4s;
253
254 #if defined(USE_ROCM)
255 asm volatile("v_and_or_b32 %0, %1, %2, %3"
256 : "=v"(h[0])
257 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
258 #else
259 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
260 : "=r"(h[0])
261 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
262 #endif
263
264 #pragma unroll
265 for (int ii = 1; ii < kElements / 2; ++ii) {
266 i4s >>= 4; // shift the next pair of nibbles into the extraction positions
267 // (i4s & 0x000f000f) | 0x43004300
268 #if defined(USE_ROCM)
269 asm volatile("v_and_or_b32 %0, %1, %2, %3"
270 : "=v"(h[ii])
271 : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
272 #else
273 asm volatile(
274 "lop3.b32 %0, %1, %2, %3, %4;\n"
275 : "=r"(h[ii])
276 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
277 #endif
278 }
279
280 // This is the BF16 {-136, -136} represented as an integer.
281 #if defined(USE_ROCM)
282 #if ROCM_VERSION >= 60200
283 auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
284 auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
285 #else
286 auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
287 auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
288 #endif
289 #else
290 static constexpr uint32_t BF16_BIAS = 0xC308C308;
291 static constexpr uint32_t BF16_ONE = 0x3F803F80;
292 #endif
293
294 // Finally, we construct the output numbers.
295 #pragma unroll
296 for (int ii = 0; ii < kElements / 2; ++ii) {
297 // Since this section is for Ampere+, we use bf16 fma to do the bias
298 // subtraction
299 #if defined(USE_ROCM)
300 result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
301 #else
302 asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
303 : "=r"(h[ii])
304 : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
305 #endif
306 }
307
308 return result;
309 }
310
311
312
313 enum class KReductionType {
314 // No k-reduction is needed between blocks as the number of k-tiles processed
315 // per block is exact and we can directly write the output
316 None,
317 };
318
319 // Loads the A matrix in 16-bit standard m x k row major layout, and writes
320 // the C matrix in 16-bit standard m x n row major layout:
321 //
322 // size [m][k]
323 template <KReductionType ReduceType>
324 struct ALayout_RM {
325 static constexpr int32_t kMTileSize = 16;
326 #if defined(USE_ROCM)
327 static constexpr int32_t kNTileSize = 16;
328 #else
329 static constexpr int32_t kNTileSize = 8;
330 #endif
331 static constexpr int32_t kKTileSize = 16;
332
333 template <int KTilesToLoad>
334 static __device__ void load(
335 const void* A,
336 int32_t m,
337 int32_t k,
338 int32_t mTiles,
339 int32_t mTile,
340 int32_t kTiles,
341 int32_t kTileStart,
342 int32_t laneId,
343 #if defined(USE_ROCM)
344 bf16x2x2_u32 out[KTilesToLoad]
345 #else
346 bf16x2x4_u32 out[KTilesToLoad]
347 #endif
348 ) {
349 #if defined(USE_ROCM)
350 const auto mLane = mTile * kMTileSize + (laneId % kMTileSize);
351 const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4;
352 #else
353 const auto mLane = mTile * kMTileSize + (laneId / 4);
354 const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
355 #endif
356
357 // access
358 // [mTile * kMTileSize + (laneId / 4)]
359 // [kTileStart * kKTileSize + (laneId % 4) * 2]
360 auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) + mLane * k + kLane;
361 bool m0InBounds = mLane < m;
362
363 #if !defined(USE_ROCM)
364 auto aPtrPlus8Rows = aPtr + 8 * k;
365
366 bool m1InBounds = (mLane + 8) < m;
367 #endif
368
369 #pragma unroll
370 for (int i = 0; i < KTilesToLoad; ++i) {
371 #if defined(USE_ROCM)
372 out[i].val = m0InBounds ? *((VecT<short, 4> *)(aPtr + i * kKTileSize)) : VecT<short, 4>{0, 0, 0, 0};
373 #else
374 out[i].vals[0] = m0InBounds
375 ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
376 : uint32_t(0);
377 out[i].vals[1] = m1InBounds
378 ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize)
379 : uint32_t(0);
380
381 out[i].vals[2] = m0InBounds
382 ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 8)
383 : uint32_t(0);
384 out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
385 aPtrPlus8Rows + i * kKTileSize + 8)
386 : uint32_t(0);
387 #endif
388 }
389 }
390
391 static __device__ void store(
392 void* C,
393 int32_t m,
394 int32_t n,
395 int32_t mOutTiles,
396 int32_t mTile,
397 int32_t nOutTiles,
398 int32_t nTile,
399 int32_t laneId,
400 const float4& out) {
401 static_assert(ReduceType == KReductionType::None, "");
402
403 if constexpr (ReduceType == KReductionType::None) {
404 #if defined(USE_ROCM)
405 const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4;
406 const int outCol = nTile * kNTileSize + (laneId % kNTileSize);
407 #else
408 // sum.x / sum.y are written at
409 // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
410 // sum.z / sum.w are written at
411 // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
412 // i.e., same columns, different row.
413 const int outRow = mTile * kMTileSize + (laneId / 4);
414 const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
415 #endif
416
417 // Pointer where sum.x / sum.y is written
418 auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol;
419
420 #if defined(USE_ROCM)
421 if (outRow < m)
422 cPtr[0] = __float2bfloat16(out.x);
423 if ((outRow + 1) < m)
424 cPtr[n] = __float2bfloat16(out.y);
425 if ((outRow + 2) < m)
426 cPtr[2*n] = __float2bfloat16(out.z);
427 if ((outRow + 3) < m)
428 cPtr[3*n] = __float2bfloat16(out.w);
429 #else
430 auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
431 auto v23 = __float22bfloat162_rn(float2{out.z, out.w});
432
433 if (outRow < m) {
434 *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01;
435 }
436
437 // sum.z, sum.w at +8 rows from cPtr
438 if (outRow + 8 < m) {
439 *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23;
440 }
441 #endif
442 }
443 }
444 };
445
446 template <int InnerKTiles, int QGroupSize>
447 struct BLayout_TC_int4 {
448 static constexpr int32_t kInnerKTiles = InnerKTiles;
449 static constexpr int32_t kMTileSize = 16;
450 #if defined(USE_ROCM)
451 static constexpr int32_t kNTileSize = 16;
452 #else
453 static constexpr int32_t kNTileSize = 8;
454 #endif
455 static constexpr int32_t kKTileSize = 16;
456
457 template <int KTilesToLoad>
458 static __device__ void load(
459 // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
460 // n-tiles: n / 8 for NV, n / 16 for AMD
461 // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD)
462 // 32: lanes per warp (32 for NV, 64 for AMD)
463 // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile;
464 // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest
465 // value), 4 k-tiles packed is a uint32x2 (64 bits), and 8 k-tiles packed
466 // is a uint32x4 (128 bits)
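// Concrete example (for exposition, NV sizes): n = 4096, k = 4096,
// InnerKTiles = 8 gives a B tensor of shape [512][32][32][4] int32.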
467 const void* __restrict__ B,
468 // size [k / qGroupSize][n][2]
469 // Contains the scale and zero point of each of the quantized int4 values
470 // within B
471 // v_reconstructed = (bf16(B_int4_val) * scale) + zero
472 const void* __restrict__ quantizationInfo,
473 int32_t n,
474 int32_t k,
475 int32_t nTiles,
476 int32_t nTile,
477 int32_t kTiles,
478 int32_t kTileStart,
479 int32_t laneId,
480 bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) {
481 // offset [nTile][kTileStart / InnerKTiles][laneId][0]
482 auto bPtr = reinterpret_cast<const int32_t*>(B) +
483 (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) *
484 kWarpSize) +
485 laneId) *
486 (InnerKTiles / 2);
487
488 int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2];
489
490 #pragma unroll
491 for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) {
492 auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2);
493
494 if constexpr (InnerKTiles == 2) {
495 b_int4[i][0] = bPtrCur[0];
496 }
497
498 if constexpr (InnerKTiles == 4) {
499 // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
500 // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1])
501 // : "l"(bPtrCur));
502
503 int2 load8 = reinterpret_cast<const int2*>(bPtrCur)[0];
504 b_int4[i][0] = load8.x;
505 b_int4[i][1] = load8.y;
506 }
507
508 if constexpr (InnerKTiles == 8) {
509 // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
510 // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]),
511 // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur));
512
513 int4 load16 = reinterpret_cast<const int4*>(bPtrCur)[0];
514 b_int4[i][0] = load16.x;
515 b_int4[i][1] = load16.y;
516 b_int4[i][2] = load16.z;
517 b_int4[i][3] = load16.w;
518 }
519 }
520
521 // Load needed info for dequantization
522
523 static_assert(isPowerOf2(QGroupSize), "");
524 static_assert(isEvenDivisor(QGroupSize, kKTileSize), "");
525 // smallest quantization group size is 32 (2 k-tiles are packed in an int32)
526 static_assert(QGroupSize >= kKTileSize * 2, "");
527 constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize);
528 // a q-group could be larger than what we are handling in a single warp
529 constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1
530 ? 1
531 : (KTilesToLoad / kKTilesPerQGroup);
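// For exposition: with QGroupSize = 128 (kKTilesPerQGroup = 8) and
// KTilesToLoad = 8, kNumQGroups == 1; with QGroupSize = 32 it would be 4.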
532
533 __nv_bfloat162 qScaleAndZero[kNumQGroups];
534 {
535 #if defined(USE_ROCM)
536 int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize);
537 #else
538 int32_t laneN = nTile * kNTileSize + (laneId / 4);
539 #endif
540 int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;
541
542 int32_t n = nTiles * kNTileSize;
543
544 // offset [qScale_kGroup][qScale_n][0]
545 auto qInfoPtr = reinterpret_cast<const __nv_bfloat16*>(quantizationInfo) +
546 (groupStart * n + laneN) * 2;
547
548 #pragma unroll
549 for (int i = 0; i < kNumQGroups; ++i) {
550 qScaleAndZero[i] =
551 *reinterpret_cast<const __nv_bfloat162*>(qInfoPtr + i * n * 2);
552 }
553 }
554
555 //
556 // De-quantize int4 values to bf16. Values are dequantized as truly int4
557 // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
558 //
559 {
560 // FIXME: does this negatively affect register counts, or will nvcc
561 // move this expansion (and data loads above) closer to the point of use?
562 __nv_bfloat162 qScale[kNumQGroups];
563 __nv_bfloat162 qZero[kNumQGroups];
564
565 #pragma unroll
566 for (int i = 0; i < kNumQGroups; ++i) {
567 qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x);
568 qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y);
569 }
570
571 #pragma unroll
572 for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) {
573 #pragma unroll
574 for (int j = 0; j < InnerKTiles / 2; ++j) {
575 bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]);
576
577 int curKTile = i * InnerKTiles + j * 2;
578 int curQGroup = (curKTile * kKTileSize) / QGroupSize;
579
580 // The dequantized values in `v` for a given lane share the same n
581 // dimension (the B tensor core layout keeps all of a thread's values
582 // along the same n) but differ in k; all of them are guaranteed to fall
583 // within the same quantization group, so we only need a single scale +
584 // zero to cover what this lane holds
585 #pragma unroll
586 for (int k = 0; k < 4; ++k) {
587 v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]);
588 }
589
590 // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and
591 // can't be used as a 32-bit asm register argument for `mma`
592 static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
593 std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
594 }
595 }
596 }
597 }
598 };
599
600 template <
601 typename ALayout,
602 typename BLayout,
603 typename CLayout,
604 int Warps,
605 int KTilesPerIteration>
606 __global__
607 __launch_bounds__(Warps * kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
608 // Data for the A matrix, loaded as per ALayout
609 const void* const __restrict__ A,
610
611 // Data for the B matrix, loaded as per BLayout
612 const void* const __restrict__ B,
613
614 // Optional quantization data for dequantizing B, loaded as per BLayout
615 const void* const __restrict__ B_quantizationInfo,
616
617 // Output data for the C matrix, stored as per CLayout
618 void* __restrict__ C,
619
620 // The size of the matrix multiplication
621 int32_t m,
622 int32_t n,
623 int32_t k,
624
625 // The size of the matrix multiplication, in multiples of our TC tile size
626 int32_t mTiles,
627 int32_t nTiles,
628 int32_t kTiles) {
629 constexpr int32_t kMTileSize = 16;
630 #if defined(USE_ROCM)
631 constexpr int32_t kNTileSize = 16;
632 #else
633 constexpr int32_t kNTileSize = 8;
634 #endif
635 constexpr int32_t kKTileSize = 16;
636
637 #if !defined(USE_ROCM) || CDNA2_OR_LATER
638
639 static_assert(
640 ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
641 ALayout::kKTileSize == kKTileSize,
642 "");
643
644 static_assert(
645 BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize &&
646 BLayout::kKTileSize == kKTileSize,
647 "");
648
649 static_assert(
650 CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize &&
651 CLayout::kKTileSize == kKTileSize,
652 "");
653
654 constexpr int kInnerKTiles = BLayout::kInnerKTiles;
655
656 // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads
657 static_assert(
658 kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, "");
659
660 // We always process at least kInnerKTiles k-tiles back to back in a warp
661 static_assert(
662 KTilesPerIteration >= kInnerKTiles &&
663 isEvenDivisor(KTilesPerIteration, kInnerKTiles),
664 "");
665
666 auto warpId = threadIdx.y;
667 auto laneId = threadIdx.x;
668
669 int32_t mTile = blockIdx.z;
670 int32_t nTile = blockIdx.y;
671
672 #if defined(USE_ROCM)
673 VecT<float, 4> c{0.0f, 0.0f, 0.0f, 0.0f};
674 #else
675 float4 c{0.0f, 0.0f, 0.0f, 0.0f};
676 #endif
677
678 // First, handle whole multiples of KTilesPerIteration
679 auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
680
681 // Each warp handles a set of KTilesPerIteration under the above limit
682 for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration;
683 kTileBase < kTilesLimit;
684 kTileBase += Warps * KTilesPerIteration) {
685 //
686 // Load data from A
687 //
688 #if defined(USE_ROCM)
689 bf16x2x2_u32 a[KTilesPerIteration];
690 #else
691 bf16x2x4_u32 a[KTilesPerIteration];
692 #endif
693 ALayout::template load<KTilesPerIteration>(
694 A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
695
696 //
697 // Load data from B and de-quantize as needed
698 // Each k-tile is bf16x2x2
699 //
700 bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2];
701 BLayout::template load<KTilesPerIteration>(
702 B,
703 B_quantizationInfo,
704 n,
705 k,
706 nTiles,
707 nTile,
708 kTiles,
709 kTileBase,
710 laneId,
711 b);
712
713 //
714 // Now, perform the matrix multiplication
715 //
716
717 // We accumulate across k-tiles here
718 #pragma unroll
719 for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) {
720 static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, "");
721 #pragma unroll
722 for (int j = 0; j < kInnerKTiles / 2; ++j) {
723 // We don't simply accumulate into `c` as this creates a too-strong
724 // execution dependency. Instead, we only periodically accumulate into
725 // `c`
726 #if defined(USE_ROCM)
727 VecT<float, 4> cTmp[2];
728 #else
729 float4 cTmp[2];
730 #endif
731
732 #pragma unroll
733 for (int k = 0; k < 2; ++k) {
734 #if defined(USE_ROCM)
735 cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
736 #else
737 cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
738 #endif
739 }
740
741 #pragma unroll
742 for (int k = 0; k < 2; ++k) {
743 #if defined(USE_ROCM)
744 cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
745 a[i * kInnerKTiles + j * 2 + k].val,
746 b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
747 cTmp[k], 0, 0, 0);
748 #else
749 asm volatile(
750 "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
751 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
752 : "=f"(cTmp[k].x),
753 "=f"(cTmp[k].y),
754 "=f"(cTmp[k].z),
755 "=f"(cTmp[k].w)
756 : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]),
757 "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]),
758 "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]),
759 "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]),
760 "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
761 "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
762 "f"(cTmp[k].x),
763 "f"(cTmp[k].y),
764 "f"(cTmp[k].z),
765 "f"(cTmp[k].w));
766 #endif
767 }
768
769 #pragma unroll
770 for (int k = 0; k < 2; ++k) {
771 #if defined(USE_ROCM)
772 c[0] += cTmp[k][0];
773 c[1] += cTmp[k][1];
774 c[2] += cTmp[k][2];
775 c[3] += cTmp[k][3];
776 #else
777 c.x += cTmp[k].x;
778 c.y += cTmp[k].y;
779 c.z += cTmp[k].z;
780 c.w += cTmp[k].w;
781 #endif
782 }
783 }
784 }
785 } // for all tiles under kTilesLimit
786
787 // Now, anywhere from 1 to KTilesPerIteration - 1 k-tiles may remain. We
788 // guarantee that the number of warps is >= KTilesPerIteration /
789 // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its
790 // thing without needing more warps
791 static_assert(Warps >= KTilesPerIteration / kInnerKTiles, "");
792
793 auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles;
794
795 // If we have any remainder k-tiles, some warps will handle them, processing
796 // kInnerKTiles k-tiles at a time
797 if (kTileBaseRemaining < kTiles) {
798 #if defined(USE_ROCM)
799 bf16x2x2_u32 a[kInnerKTiles];
800 #else
801 bf16x2x4_u32 a[kInnerKTiles];
802 #endif
803 ALayout::template load<kInnerKTiles>(
804 A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a);
805
806 bf16x2x4_u32 b[1][kInnerKTiles / 2];
807 BLayout::template load<kInnerKTiles>(
808 B,
809 B_quantizationInfo,
810 n,
811 k,
812 nTiles,
813 nTile,
814 kTiles,
815 kTileBaseRemaining,
816 laneId,
817 b);
818
819 #pragma unroll
820 for (int j = 0; j < kInnerKTiles / 2; ++j) {
821 // We don't simply accumulate into `c` as this creates a too-strong
822 // execution dependency. Instead, we only periodically accumulate into
823 // `c`
824 #if defined(USE_ROCM)
825 VecT<float, 4> cTmp[2];
826 #else
827 float4 cTmp[2];
828 #endif
829
830 #pragma unroll
831 for (int k = 0; k < 2; ++k) {
832 #if defined(USE_ROCM)
833 cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
834 #else
835 cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
836 #endif
837 }
838
839 #pragma unroll
840 for (int k = 0; k < 2; ++k) {
841 #if defined(USE_ROCM)
842 cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
843 a[j * 2 + k].val,
844 b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
845 cTmp[k], 0, 0, 0);
846 #else
847 asm volatile(
848 "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
849 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
850 : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
851 : "r"(a[j * 2 + k].vals[0]),
852 "r"(a[j * 2 + k].vals[1]),
853 "r"(a[j * 2 + k].vals[2]),
854 "r"(a[j * 2 + k].vals[3]),
855 "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
856 "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
857 "f"(cTmp[k].x),
858 "f"(cTmp[k].y),
859 "f"(cTmp[k].z),
860 "f"(cTmp[k].w));
861 #endif
862 }
863
864 #pragma unroll
865 for (int k = 0; k < 2; ++k) {
866 #if defined(USE_ROCM)
867 c[0] += cTmp[k][0];
868 c[1] += cTmp[k][1];
869 c[2] += cTmp[k][2];
870 c[3] += cTmp[k][3];
871 #else
872 c.x += cTmp[k].x;
873 c.y += cTmp[k].y;
874 c.z += cTmp[k].z;
875 c.w += cTmp[k].w;
876 #endif
877 }
878 }
879 }
880
881 //
882 // Reduce independent k-tiles (same m/n) across warps
883 //
884 __shared__ float4 smem_sum[Warps][kWarpSize];
885
886 // FIXME: this likely doesn't need to be a true reduction tree, can just be a
887 // serial sum, maybe (unless nvcc/ptxas goes back to its old ways)
888 // smem_sum[warpId][laneId] = TreeReduce4<KTilesPerIteration>::reduce(c);
889 #if defined(USE_ROCM)
890 smem_sum[warpId][laneId].x = c[0];
891 smem_sum[warpId][laneId].y = c[1];
892 smem_sum[warpId][laneId].z = c[2];
893 smem_sum[warpId][laneId].w = c[3];
894 #else
895 smem_sum[warpId][laneId] = c;
896 #endif
897
898 __syncthreads();
899
900 if (warpId == 0) {
901 float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
902
903 // Reduce across the block in the first warp
904 for (int i = 0; i < Warps; ++i) {
905 float4 v = smem_sum[i][laneId];
906 sum_f32.x += v.x;
907 sum_f32.y += v.y;
908 sum_f32.z += v.z;
909 sum_f32.w += v.w;
910 }
911
912 // Write the reduced result (in the first warp) into the output
913 CLayout::store(
914 C,
915 m,
916 n,
917 mTiles,
918 mTile,
919 // n for C output becomes k for A input, so for m16n8k16,
920 // we need to halve the tiles
921 nTiles / 2,
922 nTile,
923 laneId,
924 sum_f32);
925 }
926 #else
927 printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n");
928 #endif
929 }
930
931
932 template <
933 typename ALayout,
934 typename BLayout,
935 typename CLayout,
936 int Warps,
937 int KTilesPerWarp>
938 void launch_tinygemm_kernel(
939 const at::Tensor& A,
940 const at::Tensor& B,
941 const at::Tensor* qScaleAndZeros, /* optional */
942 at::Tensor& C_final,
943 int32_t mTiles,
944 int32_t nTiles,
945 int32_t kTiles,
946 int32_t m,
947 int32_t n,
948 int32_t k,
949 cudaStream_t stream) {
950 // The chunking kernel requires that kTiles is a multiple of kInnerKTiles
951 TORCH_CHECK(
952 kTiles >= BLayout::kInnerKTiles &&
953 isEvenDivisor(kTiles, BLayout::kInnerKTiles));
954
955 TORCH_CHECK(
956 KTilesPerWarp >= BLayout::kInnerKTiles &&
957 isEvenDivisor(KTilesPerWarp, BLayout::kInnerKTiles));
958
959 // After intra-block reduction across the k dimension, we are left with this
960 // many tiles
961 // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp);
962 int32_t postKernelKTiles = 1; // we loop
963
964 auto grid = dim3(postKernelKTiles, nTiles, mTiles);
965 auto block = dim3(kWarpSize, Warps);
966
967 auto func =
968 tinygemm_m16n8k16_chunk_kernel<ALayout, BLayout, CLayout, Warps, KTilesPerWarp>;
969
970 func<<<grid, block, 0, stream>>>(
971 A.data_ptr(),
972 B.data_ptr(),
973 qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr,
974 C_final.data_ptr(),
975 m,
976 n,
977 k,
978 mTiles,
979 nTiles,
980 kTiles);
981 C10_CUDA_KERNEL_LAUNCH_CHECK();
982
983 cudaFuncAttributes funcAttr;
984 #if defined(USE_ROCM)
985 C10_CUDA_CHECK(cudaFuncGetAttributes(
986 &funcAttr,
987 (void *)func
988 ));
989 #else
990 C10_CUDA_CHECK(cudaFuncGetAttributes(
991 &funcAttr,
992 func
993 ));
994 #endif
995 }
996
997 // FIXME: parallelize better, smem staging etc?
998 template <int InnerKTiles>
999 __global__ void matrix_to_m16n8k16_Bint4_layout(
1000 // size [n][k / 2]
1001 const at::PackedTensorAccessor32<uint8_t, 2, at::RestrictPtrTraits> in,
1002 // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2]
1003 at::PackedTensorAccessor32<int32_t, 4, at::RestrictPtrTraits> out) {
1004 // int4 values are packed into int32 words, which requires at least 8 of them
1005 // per word. Given that the m16n8k16 B layout requires 4 scalar values/lane
1006 // per k-tile, the minimum number of innermost k-tiles that we can use is 2.
1007 static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");
1008
1009 #if defined(USE_ROCM)
1010 constexpr int32_t kNTileSize = 16;
1011 #else
1012 constexpr int32_t kNTileSize = 8;
1013 #endif
1014 constexpr int32_t kKTileSize = 16;
1015
1016 // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
1017 auto kOuterTile = blockIdx.x;
1018 auto nTile = blockIdx.y;
1019 auto t = threadIdx.x;
1020
1021 // Two k-tiles are packed into an int32 at a time
1022 #pragma unroll
1023 for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
1024 // n dimension that this lane loads from
1025 #if defined(USE_ROCM)
1026 auto n0 = nTile * kNTileSize + (t % kNTileSize);
1027 #else
1028 auto n0 = nTile * kNTileSize + (t / 4);
1029 #endif
1030
1031 bool n0Valid = n0 < in.size(0);
1032
1033 // Four uint8 are packed into an int32
1034 int32_t ks[4];
1035
1036 auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize / 2;
1037
1038 #if defined(USE_ROCM)
1039 ks[0] = kBase0 + (t / kNTileSize) * 2;
1040 ks[1] = ks[0] + 1;
1041
1042 auto kBase1 = kBase0 + kKTileSize / 2;
1043 ks[2] = kBase1 + (t / kNTileSize) * 2;
1044 ks[3] = ks[2] + 1;
1045 #else
1046 ks[0] = kBase0 + t % 4;
1047 ks[1] = ks[0] + 4;
1048
1049 auto kBase1 = kBase0 + kKTileSize / 2;
1050 ks[2] = kBase1 + t % 4;
1051 ks[3] = ks[2] + 4;
1052 #endif
1053
1054 auto pIn = &in[n0][0];
1055
1056 uint8_t v[4];
1057 #pragma unroll
1058 for (int i = 0; i < 4; ++i) {
1059 v[i] = (n0Valid && ks[i] < in.size(1)) ? pIn[ks[i]] : uint8_t(0);
1060 }
1061
1062 // To clearly explain how 8 int4 values (4 uint8 values) are packed
1063 // into one int32, we use the following figure:
1064 // [n][k] int32: v[0] v[1] v[2] v[3] v[4] v[5] v[6] v[7]
1065 // [n][k / 2] uint8: v[0] v[1] v[2] v[3]
1066 // When using int32 weight as input, the packed result consists of
1067 // v[7] | v[5] | v[3] | v[1] | v[6] | v[4] | v[2] | v[0],
1068 // which equals
1069 // v[3]L | v[2]L | v[1]L | v[0]L | v[3]H | v[2]H | v[1]H | v[0]H
1070 // when using uint8 weight as input (L = low nibble, H = high nibble).
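// Worked example (for exposition): v[] = {0x21, 0x43, 0x65, 0x87}, i.e. the
// int4 sequence 1, 2, 3, 4, 5, 6, 7, 8 along k, packs to 0x75318642, whose
// nibbles from high to low are 7|5|3|1|8|6|4|2, matching the figure above.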
1071 int32_t pack = ((uint32_t)(v[3] & 0xF) << 28) |
1072 ((uint32_t)(v[2] & 0xF) << 24) | ((uint32_t)(v[1] & 0xF) << 20) |
1073 ((uint32_t)(v[0] & 0xF) << 16) | ((uint32_t)(v[3] & 0xF0) << 8) |
1074 ((uint32_t)(v[2] & 0xF0) << 4) | ((uint32_t)(v[1] & 0xF0)) |
1075 ((uint32_t)(v[0] & 0xF0) >> 4);
1076
1077 // inner k-tiles pack two at a time
1078 #if defined(USE_ROCM)
1079 // The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia
1080 // But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2]
1081 // So construct the pointer accordingly
1082 auto bPtr = out.data() +
1083 ((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) +
1084 (kOuterTile * kWarpSize * (InnerKTiles / 2)) +
1085 (t * (InnerKTiles / 2)) +
1086 (innerKTile / 2));
1087 *bPtr = pack;
1088 #else
1089 out[nTile][kOuterTile][t][innerKTile / 2] = pack;
1090 #endif
1091 }
1092 }
1093
1094 #endif
1095
1096
1097 at::Tensor _weight_int4pack_mm_cuda(
1098 const at::Tensor& A,
1099 const at::Tensor& B,
1100 int64_t qGroupSize,
1101 const at::Tensor& qScaleAndZeros) {
1102 c10::cuda::CUDAGuard g(A.device());
1103
1104 TORCH_CHECK(
1105 A.device() == B.device() && A.device() == qScaleAndZeros.device());
1106
1107 #if defined(USE_ROCM)
1108 if (!isCDNA2orLater(A.device().index())) {
1109 TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
1110 }
1111 #endif
1112
1113 constexpr int32_t kMTileSize = 16;
1114 #if defined(USE_ROCM)
1115 constexpr int32_t kNTileSize = 16;
1116 #else
1117 constexpr int32_t kNTileSize = 8;
1118 #endif
1119 constexpr int32_t kKTileSize = 16;
1120
1121 // row major layout
1122 auto m = A.size(0);
1123 auto mTiles = divUp(m, kMTileSize);
1124
1125 // To convert the nTiles from tensor storage layout to the actual matrix core layout
1126 constexpr int32_t kNTileSizeTensor = 8;
1127 auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor);
1128
1129 // tensor core layout
1130 auto nTiles = (B.size(0) / nTileScaleFactor);
1131 auto n = nTiles * kNTileSize;
1132
1133 // row major layout
1134 auto k = A.size(1);
1135 auto kTiles = divUp(k, kKTileSize);
1136
1137 // The number of inner k-tiles is the size of the innermost dimension times 2:
1138 // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4
1139 // are packed into 1 int32 for int4 B
1140 auto B_innerKTiles = B.size(3) * 2;
1141 TORCH_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8);
1142
1143 // A is standard row major
1144 TORCH_CHECK(A.dtype() == at::kBFloat16);
1145 TORCH_CHECK(A.is_contiguous());
1146 TORCH_CHECK(A.dim() == 2);
1147
1148 // B has B_innerKTiles k-tiles in the innermost dimension
1149 TORCH_CHECK(B.dtype() == at::kInt);
1150 TORCH_CHECK(B.is_contiguous());
1151 TORCH_CHECK(B.dim() == 4);
1152 TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize));
1153 TORCH_CHECK(B.size(2) == 32);
1154
1155 // Validate the scale and zero point tensor for dequantization
1156 // These are the only versions handled at the moment
1157 TORCH_CHECK(
1158 qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
1159 qGroupSize == 256);
1160
1161 TORCH_CHECK(qScaleAndZeros.dim() == 3);
1162 auto numQGroups = qScaleAndZeros.size(0);
1163 TORCH_CHECK(
1164 kTiles * kKTileSize >= qGroupSize &&
1165 isEvenDivisor(kTiles * kKTileSize, qGroupSize));
1166 TORCH_CHECK(qScaleAndZeros.size(1) == n);
1167 TORCH_CHECK(qScaleAndZeros.size(2) == 2);
1168
1169 // Output is a standard row-major matrix
1170 auto C_final = at::empty(
1171 {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
1172
1173 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
1174 auto stream = at::cuda::getCurrentCUDAStream();
1175 #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
1176 do { \
1177 using ACLayout = ALayout_RM<REDUCE_TYPE>; \
1178 \
1179 TORCH_CHECK( \
1180 K_TILES_PER_WARP >= B_innerKTiles && \
1181 isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles)); \
1182 \
1183 switch (B_innerKTiles) { \
1184 case 2: \
1185 if constexpr (K_TILES_PER_WARP >= 2) { \
1186 using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>; \
1187 launch_tinygemm_kernel< \
1188 ACLayout, \
1189 BLayout, \
1190 ACLayout, \
1191 WARPS, \
1192 K_TILES_PER_WARP>( \
1193 A, \
1194 B, \
1195 &qScaleAndZeros, \
1196 C_final, \
1197 mTiles, \
1198 nTiles, \
1199 kTiles, \
1200 m, \
1201 n, \
1202 k, \
1203 stream); \
1204 } \
1205 break; \
1206 case 4: \
1207 if constexpr (K_TILES_PER_WARP >= 4) { \
1208 using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>; \
1209 launch_tinygemm_kernel< \
1210 ACLayout, \
1211 BLayout, \
1212 ACLayout, \
1213 WARPS, \
1214 K_TILES_PER_WARP>( \
1215 A, \
1216 B, \
1217 &qScaleAndZeros, \
1218 C_final, \
1219 mTiles, \
1220 nTiles, \
1221 kTiles, \
1222 m, \
1223 n, \
1224 k, \
1225 stream); \
1226 } \
1227 break; \
1228 case 8: \
1229 if constexpr (K_TILES_PER_WARP >= 8) { \
1230 using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>; \
1231 launch_tinygemm_kernel< \
1232 ACLayout, \
1233 BLayout, \
1234 ACLayout, \
1235 WARPS, \
1236 K_TILES_PER_WARP>( \
1237 A, \
1238 B, \
1239 &qScaleAndZeros, \
1240 C_final, \
1241 mTiles, \
1242 nTiles, \
1243 kTiles, \
1244 m, \
1245 n, \
1246 k, \
1247 stream); \
1248 } \
1249 break; \
1250 default: \
1251 break; \
1252 } \
1253 } while (false)
1254
1255 #define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \
1256 do { \
1257 switch (qGroupSize) { \
1258 case 32: \
1259 RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE); \
1260 break; \
1261 case 64: \
1262 RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE); \
1263 break; \
1264 case 128: \
1265 RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \
1266 break; \
1267 case 256: \
1268 RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \
1269 break; \
1270 } \
1271 } while (false)
1272
1273 HANDLE_Q_GROUP(8, 8, KReductionType::None);
1274
1275 #undef HANDLE_Q_GROUP
1276 #undef RUN_GEMM
1277
1278 return C_final;
1279 #endif
1280 TORCH_CHECK(false, "_weight_int4pack_mm_cuda is not available for build.");
1281 return C_final;
1282 }
1283
1284 // input is [n][k / 2] (uint8 dtype)
1285 // output is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] (int32 dtype)
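// Example shapes (for exposition): in = [4096][2048] uint8 (n = 4096, with
// k = 4096 int4 values per row) and innerKTiles = 8 produce an output of
// shape [512][32][32][4] int32.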
1286 at::Tensor _convert_weight_to_int4pack_cuda(
1287 const at::Tensor& in,
1288 int64_t innerKTiles) {
1289 c10::cuda::CUDAGuard g(in.device());
1290
1291 TORCH_CHECK(in.dim() == 2);
1292 TORCH_CHECK(in.dtype() == at::kByte);
1293 TORCH_CHECK(in.is_contiguous());
1294
1295 // At least 2 k-tiles need to be packed back to back in the innermost
1296 // dimension, as the m16n8k16 tensor core tile presents 4 scalar values for
1297 // the B matrix, but the minimum word size for the packed format is 4 bytes
1298 // (int32). 4 inner K-tiles = 8 byte load, 8 inner k-tiles = 16 byte load
1299 // which is the maximum vectorized load/store size
1300 TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);
1301
1302 #if defined(USE_ROCM)
1303 if (!isCDNA2orLater(in.device().index())) {
1304 TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
1305 }
1306 #endif
1307
1308 #if defined(USE_ROCM)
1309 constexpr int32_t kNTileSize = 16;
1310 #else
1311 constexpr int32_t kNTileSize = 8;
1312 #endif
1313 constexpr int32_t kKTileSize = 16;
1314
1315 // GPT-FAST assumes nTileSize of 8 for quantized weight tensor.
1316 // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510
1317 // Torch dynamo also requires the torch ops has the same output shape for each device.
1318 // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263
1319 constexpr int32_t kNTileSizeTensor = 8;
1320
1321 auto nTiles = divUp(in.size(0), kNTileSize);
1322 auto nTilesTensor = divUp(in.size(0), kNTileSizeTensor);
1323
1324 // k-tiles are packed back to back in the innermost dimension in order to
1325 // allow for 4/8/16 byte loads
1326 TORCH_CHECK(isEvenDivisor(in.size(1) * 2, innerKTiles * kKTileSize));
1327 // kSuperTiles is the number of groups of innerKTiles k-tiles, i.e. k / (innerKTiles * kKTileSize)
1328 auto kSuperTiles = divUp(in.size(1) * 2, innerKTiles * kKTileSize);
1329
1330 // each block handles `innerKTiles` k-tiles.
1331 // 2 k-tiles are a single int32
1332 //
1333 // We use the same shape for AMD gpus also to match the GPT-FAST spec.
1334 // Will index it correctly when dereferencing the quantized weight tensor pointer.
1335 auto out = at::empty(
1336 {nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
1337 at::TensorOptions().dtype(at::kInt).device(in.device()));
1338
1339 #if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
1340 auto stream = at::cuda::getCurrentCUDAStream();
1341 dim3 grid(kSuperTiles, nTiles);
1342
1343 if (innerKTiles == 2) {
1344 matrix_to_m16n8k16_Bint4_layout<2><<<grid, kWarpSize, 0, stream>>>(
1345 in.packed_accessor32<uint8_t, 2, at::RestrictPtrTraits>(),
1346 out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
1347 } else if (innerKTiles == 4) {
1348 matrix_to_m16n8k16_Bint4_layout<4><<<grid, kWarpSize, 0, stream>>>(
1349 in.packed_accessor32<uint8_t, 2, at::RestrictPtrTraits>(),
1350 out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
1351 } else if (innerKTiles == 8) {
1352 matrix_to_m16n8k16_Bint4_layout<8><<<grid, kWarpSize, 0, stream>>>(
1353 in.packed_accessor32<uint8_t, 2, at::RestrictPtrTraits>(),
1354 out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
1355 }
1356
1357 return out;
1358 #endif
1359 TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is not available for build.");
1360 return out;
1361 }
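
// Putting the two ops together (an illustrative sketch based on the checks
// above, not a prescriptive usage example; B_uint8 and qScaleAndZeros are
// placeholder tensors): starting from a bf16 activation A of shape [m][k]
// and int4 weights stored as uint8 of shape [n][k / 2],
//
//   auto B_packed = _convert_weight_to_int4pack_cuda(B_uint8, /*innerKTiles=*/8);
//   auto C = _weight_int4pack_mm_cuda(A, B_packed, /*qGroupSize=*/128,
//                                     qScaleAndZeros /* bf16, [k / 128][n][2] */);
//
// yields a bf16 result C of shape [m][n].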
1362
1363
1364 } // namespace at::native
1365