# mypy: allow-untyped-defs
import torch


# Pack pairs of int4 values into int8, in row major order; the first int4
# value goes into the lower order bits, and the second int4 value into the
# higher order bits of the resulting int8 value.
def pack_int4_to_int8(weight):
    assert weight.dim() == 2
    assert weight.shape[1] % 2 == 0
    assert weight.dtype == torch.int8
    return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF)


# Unpack quadruples of bits in int8 values into int4 values, in row
# major order; the lower 4 bits go into the first int4 value, and the
# upper 4 bits go into the second int4 value.
def unpack_int8_to_int4(weight):
    assert weight.dim() == 2
    assert weight.dtype == torch.int8
    return torch.stack((weight & 0xF, (weight >> 4) & 0xF), dim=2).view(
        weight.shape[0], 2 * weight.shape[1]
    )


# Transpose the weight matrix, and then reorder its elements according
# to the underlying requirements of the CUTLASS library, so that it can
# be used for the CUTLASS-based mixed datatypes linear operation.
def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
    weight, dtypeq, transpose=False
):
    assert weight.dim() == 2
    assert weight.dtype == torch.int8
    assert dtypeq == torch.int8 or dtypeq == torch.quint4x2
    assert weight.device.type == "cuda"

    device = weight.device

    # subbyte_transpose
    if not transpose:
        if dtypeq == torch.int8:
            outp = weight.T
        elif dtypeq == torch.quint4x2:
            outp = pack_int4_to_int8(unpack_int8_to_int4(weight.view(torch.int8)).T)
    else:
        outp = weight

    ncols, nrows = outp.shape  # type: ignore[possibly-undefined]
    assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0
    assert ncols % 64 == 0

    # permute_B_rows_for_mixed_gemm
    # (permute cols actually, as transpose is applied first here)
    if dtypeq == torch.quint4x2:
        cols_permuted = (
            torch.tensor(
                [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
                device=device,
            )
            + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand(
                nrows // 16, 16
            )
        ).view(-1)
    else:
        cols_permuted = (
            torch.tensor(
                [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15],
                device=device,
            )
            + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand(
                nrows // 16, 16
            )
        ).view(-1)
    outp = outp.index_copy(1, cols_permuted, outp)

    # interleave_column_major_tensor
    magic0 = 4 if dtypeq == torch.quint4x2 else 2
    magic1 = 32 // magic0

    tmp0 = (
        (torch.arange(0, ncols // magic0, device=device) * (nrows // 4 * magic0))
        .view(-1, 1)
        .repeat(1, nrows // 4 * magic0)
        .view(-1)
    )
    tmp1 = (
        (torch.arange(0, nrows // 4 // magic1, device=device) * (magic0 * magic1))
        .view(-1, 1)
        .repeat(1, magic1)
        .view(-1)
        .repeat(ncols)
    )
    tmp2 = (
        (torch.arange(0, magic0, device=device) * magic1)
        .view(-1, 1)
        .repeat(1, nrows // 4)
        .view(-1)
        .repeat(ncols // magic0)
    )
    tmp3 = torch.arange(0, magic1, device=device).repeat(nrows // 4 * ncols // magic1)

    outp_offsets = tmp0 + tmp1 + tmp2 + tmp3

    tmp = outp.view(-1).view(torch.int32)
    outp = torch.zeros_like(tmp)
    outp.scatter_(0, outp_offsets, tmp)
    outp = outp.view(weight.dtype)

    # add_bias_and_interleave_quantized_tensor_inplace
    tmp = outp.view(-1)

    outp = torch.empty_like(tmp)
    if dtypeq == torch.int8:
        tmp = (tmp.to(torch.int) + 128).to(tmp.dtype)
        outp[0::4] = tmp[0::4]
        outp[1::4] = tmp[2::4]
        outp[2::4] = tmp[1::4]
        outp[3::4] = tmp[3::4]
    elif dtypeq == torch.quint4x2:
        tmp0 = ((tmp & 0xF) + 8) & 0xF
        tmp0 = (tmp0[1::2] << 4) | tmp0[0::2]
        tmp1 = (((tmp >> 4) & 0xF) + 8) & 0xF
        tmp1 = (tmp1[1::2] << 4) | tmp1[0::2]
        outp[0::4] = tmp0[0::2]
        outp[1::4] = tmp0[1::2]
        outp[2::4] = tmp1[0::2]
        outp[3::4] = tmp1[1::2]

    if dtypeq == torch.quint4x2:
        nrows *= 2
        ncols //= 2

    return outp.view(nrows, ncols).view(torch.uint8)
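

# Minimal usage sketch (illustration only, not part of the library API): it
# assumes a CUDA device is available and uses arbitrary example shapes chosen
# to satisfy the divisibility asserts in the reorder function above.
if __name__ == "__main__":
    if torch.cuda.is_available():
        # int4 path: values in [-8, 7], initially one int4 value per int8 element.
        w_int4 = torch.randint(-8, 8, (64, 64), dtype=torch.int8, device="cuda")
        w_packed = pack_int4_to_int8(w_int4)  # shape (64, 32), two int4 values per byte
        # Unpacking recovers the low nibbles of the original values.
        assert torch.equal(unpack_int8_to_int4(w_packed), w_int4 & 0xF)
        w_reordered = quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
            w_packed, torch.quint4x2, transpose=False
        )
        print(w_reordered.shape, w_reordered.dtype)

        # int8 path: the weight is transposed and reordered without repacking.
        w_int8 = torch.randint(-128, 128, (64, 64), dtype=torch.int8, device="cuda")
        w8_reordered = quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
            w_int8, torch.int8, transpose=False
        )
        print(w8_reordered.shape, w8_reordered.dtype)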