# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Define a Tensor subclass to wrap around ggml q4_0 tensor layout.
# The layout is the following:
# ┌─────────────────────┬───────────────────────────┐
# │                     │                           │
# │                     │                           │
# │  2 bytes (1xfp16)   │    16 bytes (32xint4)     │
# │  group-wise scale   │    group-wise weights     │
# │                     │                           │
# │                     │                           │
# └─────────────────────┴───────────────────────────┘
#
# Notice that the 16 bytes (32 int4) are interleaved:
# [0th value, 16th value, 1st value, 17th value, ..., 15th, 31st]
#
# This layout is handled internally in the tensor subclass.
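#
# Concretely, for one packed 18-byte group (see from_float and pack_uint4 below):
# bytes 0-1 hold the fp16 scale in native byte order, and payload byte j
# (for j in 0..15) holds quantized value j in its low nibble and quantized
# value j + 16 in its high nibble.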
import torch
from torchao.quantization.subclass import QuantizedLinearWeightBase


def down_size(size):
    assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
    return (*size[:-1], size[-1] // 2)


def up_size(size):
    return (*size[:-1], size[-1] * 2)


def pack_uint4(uint8_data) -> torch.Tensor:
    # converting to uint8 for operations
    shape = uint8_data.shape
    assert shape[-1] % 2 == 0
    uint8_data = uint8_data.contiguous().view(-1)
    return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape))

def unpack_uint4(uint8_data) -> torch.Tensor:
    """Unpack each uint8 byte into two uint4 values (low nibble first, then high nibble)."""
    # since we are using uint8 we will decode 2 entries per byte
    # Shift elements down 4 and select out the bottom 4 bits
    shape = uint8_data.shape
    first_elements = (uint8_data & 0b1111).to(torch.uint8)
    second_elements = (uint8_data >> 4).to(torch.uint8)
    return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape))


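# A minimal sketch (not part of the original module) of the nibble-packing round
# trip used by the layout above; the example values are arbitrary.
def _example_uint4_roundtrip() -> None:
    nibbles = torch.tensor([[7, 2, 15, 0]], dtype=torch.uint8)
    packed = pack_uint4(nibbles)  # shape (1, 2); byte 0 == (2 << 4) | 7
    assert torch.equal(unpack_uint4(packed), nibbles)

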
def _pack_to_two_uint8(scale: torch.Tensor) -> torch.Tensor:
    raw_bytes = scale.numpy().tobytes()
    scale_uint8 = torch.frombuffer(raw_bytes, dtype=torch.uint8)
    scale_uint8 = scale_uint8.view(-1, 2)
    return scale_uint8


def _unpack_two_uint8(
    tensor: torch.Tensor,
) -> torch.Tensor:
    assert (
        tensor.dtype == torch.uint8
    ), f"Expecting to have a uint8 tensor but got {tensor.dtype}"
    raw_bytes = tensor.numpy().tobytes()
    scale = torch.frombuffer(raw_bytes, dtype=torch.float16)
    return scale


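# A small sketch (not part of the original module): a float16 scale survives the
# round trip through its two raw bytes unchanged.
def _example_scale_bytes_roundtrip() -> None:
    scales = torch.tensor([1.5, -0.25], dtype=torch.float16)
    two_bytes = _pack_to_two_uint8(scales)  # shape (2, 2), dtype uint8
    assert torch.equal(_unpack_two_uint8(two_bytes), scales)

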
def _interleave(
    input: torch.Tensor,
    group_size,
) -> torch.Tensor:
    half1 = input[:, : group_size // 2]
    half2 = input[:, group_size // 2 :]
    interleaved_tensor = torch.stack((half1, half2), dim=2)
    return interleaved_tensor.view(input.size(0), -1)


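# A quick sketch (not part of the original module) showing that _interleave produces
# the [0th, 16th, 1st, 17th, ...] ordering described in the header comment.
def _example_interleave_order() -> None:
    group = torch.arange(32, dtype=torch.uint8).unsqueeze(0)  # one group of 32 values
    interleaved = _interleave(group, group_size=32)
    assert interleaved[0, :6].tolist() == [0, 16, 1, 17, 2, 18]

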
def from_float(
    input: torch.Tensor,
) -> torch.Tensor:
    """
    Quantize similarly to GGUF's Q4_0 quantization. Group values into groups of
    32 and generate a uint8 tensor. Each group results in 18 uint8s
    consisting of:
      - 1 scale (float16 represented as 2 uint8 elements)
      - 32 4-bit elements (represented as 16 uint8 elements)
    """
    group_size = 32
    zero_point = 8.5
    # pyre-fixme[16]: Callable input has no attribute dtype.
    assert input.dtype == torch.float16, f"Expecting float16 input, got {input.dtype}"
    assert (
        input.numel() % group_size
        == 0
        # pyre-fixme[16]: Callable input has no attribute numel.
    ), f"The number of input values has to be a multiple of {group_size} but got {input.numel()}"
    input = input.reshape(-1, group_size)
    abs_max_id = torch.argmax(torch.abs(input), dim=1)
    scales = input[torch.arange(input.size(0)), abs_max_id] / -8
    inv_scales = torch.div(1.0, scales.to(torch.float32))

    clamped = torch.clamp(
        input=torch.floor(inv_scales.unsqueeze(1) * input + zero_point),
        min=0,
        max=15,
    ).to(torch.uint8)
    alternate = _interleave(clamped, group_size)
    return torch.cat([_pack_to_two_uint8(scales), pack_uint4(alternate)], dim=1)


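# An illustrative sketch (not part of the original module): 64 float16 values form
# two groups of 32, so from_float returns a (2, 18) uint8 tensor -- 2 scale bytes
# followed by 16 packed payload bytes per group.
def _example_from_float_shapes() -> None:
    values = torch.randn(64, dtype=torch.float16)
    packed = from_float(values)
    assert packed.dtype == torch.uint8
    assert packed.shape == (2, 18)

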
def to_float(
    input: torch.Tensor,
) -> torch.Tensor:
    """
    Dequantize GGUF's Q4_0 tensor. Expecting input to be a uint8 tensor
    of shape [num_groups, 18], where the first 2 values of each
    row represent the scale of that group.
    """
    zero_point = 8
    data_uint8 = input[:, 2:]
    data = unpack_uint4(data_uint8)
    assert data.dtype == torch.uint8
    interleave = torch.cat([data[:, ::2], data[:, 1::2]], dim=1)
    scale = _unpack_two_uint8(input[:, :2])
    a = interleave.to(torch.float16) - zero_point
    return a * scale.unsqueeze(1)


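# An illustrative sketch (not part of the original module): dequantizing the packed
# tensor recovers the input up to the group-wise 4-bit quantization error; the loose
# atol below is just a sanity bound for this input.
def _example_quantize_dequantize_roundtrip() -> None:
    values = torch.linspace(-4, 4, 64).to(torch.float16).reshape(2, 32)
    recovered = to_float(from_float(values))
    assert recovered.shape == (2, 32)
    assert torch.allclose(recovered, values, atol=0.5)

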
class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
    """
    A Tensor subclass that, when applied to a weight used in a linear op/module,
    changes that linear op to a weight-only int4 quantized linear op with groupwise
    affine quantization on the weight.
    """

    @staticmethod
    def __new__(
        cls,
        int_data,
        scales,
        shape,
        **kwargs,
    ):
        kwargs["dtype"] = kwargs.get("dtype", scales.dtype)
        return super().__new__(cls, int_data, transposed=False, shape=shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        int_data,
        scales,
        shape,
        **kwargs,
    ):
        # the transposed flag tracks whether the tensor subclass has been transposed relative
        # to how a weight is normally stored in a linear, i.e. [out_features, in_features].
        # tracking both transposed and shape is slightly redundant but corner cases like
        # square matrices can cause issues otherwise
        self.scales = scales
        self.groupsize = 32
        self.zero_point = torch.tensor(8.5, dtype=torch.float)
        super().__init__(int_data, transposed=False)

    def int_repr(self):
        return self.int_data

    def q_params(self):
        return {"q_scales": self.scales, "q_zero_points": self.zero_point}

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.int_data.to(kwargs["device"]),
            self.scales.to(kwargs["device"]),
            self.shape,
            **kwargs,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scales),
            self.shape,
            dtype=self.dtype,
        )

    def __tensor_flatten__(self):
        return ["int_data", "scales"], (
            self.dtype,
            self.shape,
        )

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None
    ):
        int_data, scales = (
            tensor_data_dict["int_data"],
            tensor_data_dict["scales"],
        )
        dtype, shape = attributes
        return cls(
            int_data,
            scales,
            shape if outer_size is None else outer_size,
            dtype=dtype,
        )

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        """
        The quantized linear op used to implement the weight-only int4 quantized
        linear: the weight is dequantized to float16 and passed to F.linear.
        """
        assert isinstance(
            w_qtensor, GGMLInt4LinearWeight
        ), f"Expect {w_qtensor} to be an instance of GGMLInt4LinearWeight but got {type(w_qtensor)}"
        fp_weight = to_float(w_qtensor.int_data).view(w_qtensor.shape)
        return torch.nn.functional.linear(act_mat, fp_weight, bias)

    @classmethod
    def from_float(cls, input_float):
        """
        Method used to convert a linear weight tensor to an instance of the
        GGMLInt4LinearWeight subclass.

        Example usage::

            model.lin_mod.weight = (
                GGMLInt4LinearWeight.from_float(model.lin_mod.weight)
            )
        """
        packed = from_float(input_float)
        scale = _unpack_two_uint8(packed[:, :2]).clone()
        return cls(
            packed,
            scale,
            input_float.shape,
            dtype=torch.float16,
        )
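

# A minimal end-to-end sketch (not part of the original file): quantize a float16
# linear weight with the subclass and compare the weight-only quantized linear op
# against the full-precision reference. Shapes are arbitrary, but in_features must
# be a multiple of the 32-element group size.
def _example_weight_only_linear() -> None:
    weight = torch.randn(8, 64, dtype=torch.float16)  # [out_features, in_features]
    activations = torch.randn(2, 64, dtype=torch.float16)
    qweight = GGMLInt4LinearWeight.from_float(weight)
    out = GGMLInt4LinearWeight._quantized_op(activations, qweight, None)
    reference = torch.nn.functional.linear(activations, weight)
    assert out.shape == reference.shape == (2, 8)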
243