1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# Define a Tensor subclass to wrap around ggml q4_0 tensor layout. 8# The layout is the following: 9# ┌─────────────────────┬───────────────────────────┐ 10# │ │ │ 11# │ │ │ 12# │ 2 bytes (1xfp16) │ 16 bytes (32xint4) │ 13# │ group-wise scale │ group-wise weights │ 14# │ │ │ 15# │ │ │ 16# └─────────────────────┴───────────────────────────┘ 17# 18# Notice that the 16 bytes (32 int4) are interleved: 19# [0th value, 16th value, 1st value, 17th value, ..., 15th, 31st] 20# 21# This layout is handled internally in the tensor subclass. 22import torch 23from torchao.quantization.subclass import QuantizedLinearWeightBase 24 25 26def down_size(size): 27 assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" 28 return (*size[:-1], size[-1] // 2) 29 30 31def up_size(size): 32 return (*size[:-1], size[-1] * 2) 33 34 35def pack_uint4(uint8_data) -> torch.Tensor: 36 # converting to uint8 for operations 37 shape = uint8_data.shape 38 assert shape[-1] % 2 == 0 39 uint8_data = uint8_data.contiguous().view(-1) 40 return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape)) 41 42 43def unpack_uint4(uint8_data) -> torch.Tensor: 44 """Get the original weight from the normalized float weight format""" 45 # since we are using uint8 we will decode 2 entries per byte 46 # Shift elements down 4 and select out the bottom 4 bits 47 shape = uint8_data.shape 48 first_elements = (uint8_data & 0b1111).to(torch.uint8) 49 second_elements = (uint8_data >> 4).to(torch.uint8) 50 return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape)) 51 52 53def _pack_to_two_uint8(scale: torch.Tensor) -> torch.Tensor: 54 raw_bytes = scale.numpy().tobytes() 55 scale_uint8 = torch.frombuffer(raw_bytes, dtype=torch.uint8) 56 scale_uint8 = scale_uint8.view(-1, 2) 57 return scale_uint8 58 59 60def _unpack_two_uint8( 61 tensor: torch.Tensor, 62) -> torch.Tensor: 63 assert ( 64 tensor.dtype == torch.uint8 65 ), f"Expecting to have a uint8 tensor but get {tensor.dtype}" 66 raw_bytes = tensor.numpy().tobytes() 67 scale = torch.frombuffer(raw_bytes, dtype=torch.float16) 68 return scale 69 70 71def _interleave( 72 input: torch.Tensor, 73 group_size, 74) -> torch.Tensor: 75 half1 = input[:, : group_size // 2] 76 half2 = input[:, group_size // 2 :] 77 interleaved_tensor = torch.stack((half1, half2), dim=2) 78 return interleaved_tensor.view(input.size(0), -1) 79 80 81def from_float( 82 input: torch.Tensor, 83) -> torch.Tensor: 84 """ 85 Quantize similar to GGUF's Q4_0 quantization. Group into size of 86 32 and generate a uint8 tensor. One group will result into 18 uint8s 87 consisting of: 88 - 1 scale (float16 represented as 2 uint8 elements) 89 - 32 4-bit elements (represented as 16 uint8 elements) 90 """ 91 group_size = 32 92 zero_point = 8.5 93 # pyre-fixme[16]: Callable input has no attribute dtype. 94 assert input.dtype == torch.float16, f"Expecting float16 input, got {input.dtype}" 95 assert ( 96 input.numel() % group_size 97 == 0 98 # pyre-fixme[16]: Callable input has no attribute numel. 99 ), f"The number of input values has to be a multiple of {group_size} but got {input.numel()}" 100 input = input.reshape(-1, group_size) 101 abs_max_id = torch.argmax(torch.abs(input), dim=1) 102 scales = input[torch.arange(input.size(0)), abs_max_id] / -8 103 inv_scales = torch.div(1.0, scales.to(torch.float32)) 104 105 clamped = torch.clamp( 106 input=torch.floor(inv_scales.unsqueeze(1) * input + zero_point), 107 min=0, 108 max=15, 109 ).to(torch.uint8) 110 alternate = _interleave(clamped, group_size) 111 return torch.cat([_pack_to_two_uint8(scales), pack_uint4(alternate)], dim=1) 112 113 114def to_float( 115 input: torch.Tensor, 116) -> torch.Tensor: 117 """ 118 Dequantize GGUF's Q4_0 tensor. Expecting input to be a uint8 tensor 119 with a dimension of [num_group // 2, 18], the first 2 values of each 120 row represents the scale of that group. 121 """ 122 zero_point = 8 123 data_unint8 = input[:, 2:] 124 data = unpack_uint4(data_unint8) 125 assert data.dtype == torch.uint8 126 interleave = torch.cat([data[:, ::2], data[:, 1::2]], dim=1) 127 scale = _unpack_two_uint8(input[:, :2]) 128 a = interleave.to(torch.float16) - zero_point 129 return a * scale.unsqueeze(1) 130 131 132class GGMLInt4LinearWeight(QuantizedLinearWeightBase): 133 """ 134 A Tensor subclass that when applied to a weight used in a linear op/module, 135 changes that linear op to a weight-only int4 quantized linear op with groupwise 136 affine quantization on the weight. 137 """ 138 139 @staticmethod 140 def __new__( 141 cls, 142 int_data, 143 scales, 144 shape, 145 **kwargs, 146 ): 147 kwargs["dtype"] = kwargs.get("dtype", scales.dtype) 148 return super().__new__(cls, int_data, transposed=False, shape=shape, **kwargs) # type: ignore[attr-defined] 149 150 def __init__( 151 self, 152 int_data, 153 scales, 154 shape, 155 **kwargs, 156 ): 157 # the transposed flag tracks whether the tensor subclass has been transposed relative 158 # to how a weight is normally stored in a linear i.e. [out_features, in_features]. 159 # tracking both transposed and shape is slightly redundant but corner cases like 160 # square matrices can cause issues otherwise 161 self.scales = scales 162 self.groupsize = 32 163 self.zero_point = torch.tensor(8.5, dtype=torch.float) 164 super().__init__(int_data, transposed=False) 165 166 def int_repr(self): 167 return self.int_data 168 169 def q_params(self): 170 return {"q_scales": self.scales, "q_zero_points": self.zero_point} 171 172 def to(self, *args, **kwargs): 173 kwargs = self._get_to_kwargs(*args, **kwargs) 174 return self.__class__( 175 self.int_data.to(kwargs["device"]), 176 self.scales.to(kwargs["device"]), 177 self.shape, 178 **kwargs, 179 ) 180 181 def _apply_fn_to_data(self, fn): 182 return self.__class__( 183 fn(self.int_data), 184 fn(self.scales), 185 self.shape, 186 dtype=self.dtype, 187 ) 188 189 def __tensor_flatten__(self): 190 return ["int_data", "scales"], ( 191 self.dtype, 192 self.shape, 193 ) 194 195 @classmethod 196 def __tensor_unflatten__( 197 cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None 198 ): 199 int_data, scales = ( 200 tensor_data_dict["int_data"], 201 tensor_data_dict["scales"], 202 ) 203 dtype, shape = attributes 204 return cls( 205 int_data, 206 scales, 207 shape if outer_size is None else outer_size, 208 dtype=dtype, 209 ) 210 211 @staticmethod 212 def _quantized_op(act_mat, w_qtensor, bias): 213 """ 214 This is the quantized linear op that is used to implement the weight-only 215 int4 quantized linear op. 216 """ 217 assert isinstance( 218 w_qtensor, GGMLInt4LinearWeight 219 ), f"Expect {w_qtensor} to be an instance of GGMLInt4LinearWeight but got {type(w_qtensor)}" 220 fp_weight = to_float(w_qtensor.int_data).view(w_qtensor.shape) 221 return torch.nn.functional.linear(act_mat, fp_weight, bias) 222 223 @classmethod 224 def from_float(cls, input_float): 225 """ 226 Method used to convert a linear weight tensor to an instance of the 227 GGMLInt4LinearWeight subclass. 228 229 Example usage:: 230 231 model.lin_mod.weight = ( 232 GGMLInt4LinearWeight.from_float(model.lin_mod.weight) 233 ) 234 """ 235 packed = from_float(input_float) 236 scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16) 237 return cls( 238 packed, 239 scale, 240 input_float.shape, 241 dtype=torch.float16, 242 ) 243