# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can load pre-quantized checkpoints.

from typing import Any, Optional

import torch
from torch import nn

from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding

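# The three `_replace_*` helpers below all funnel through torchao's
# `_replace_with_custom_fn_if_matches_filter`. As a rough mental model (an
# illustrative sketch, not the actual torchao implementation), that helper walks
# the module tree and swaps every submodule for which `filter_fn(child, fqn)`
# returns True with the module built by `replacement_fn(child)`:
#
#   def _replace(model, replacement_fn, filter_fn, cur_fqn=""):
#       if filter_fn(model, cur_fqn[:-1]):
#           return replacement_fn(model)
#       for name, child in model.named_children():
#           new_child = _replace(child, replacement_fn, filter_fn, f"{cur_fqn}{name}.")
#           if new_child is not child:
#               setattr(model, name, new_child)
#       return model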

def _replace_linear_with_linear_8da4w_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    group_size: int,
    precision: torch.dtype,
    scales_precision: torch.dtype,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace linear layers where the checkpoint contains explicit scales
        scales_key = f"{cur_fqn}.scales"
        if isinstance(child, nn.Linear) and scales_key in checkpoint:
            assert _check_linear_int4_k(child.in_features, group_size)
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == scales_precision
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt4WeightLinear(
            # pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.in_features,
            # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.out_features,
            bias=False,
            device=child.weight.device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
        new_linear.zeros = torch.zeros_like(new_linear.zeros)
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    group_size: int,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Transform the model to be able to load pre-quantized checkpoints that
    are quantized with the given group size and quantization mode for
    linear layers.
    """

    if group_size not in [32, 64, 128, 256]:
        raise ValueError(
            f"Group size {group_size} is not supported for pre-quantized checkpoint."
        )
    _replace_linear_with_linear_8da4w_for_pre_quantization(
        module,
        checkpoint,
        group_size,
        dtype,
        dtype,
    )
    return module

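# Example usage (an illustrative sketch: the checkpoint path, group size, and dtype
# below are assumptions, not values prescribed by this module). Given an eager Llama
# `model` and a state dict produced by an 8da4w (int8 dynamic activation, int4
# weight) quantization flow:
#
#   checkpoint = torch.load("llama_8da4w.pth", map_location="cpu")
#   model = transform_linear_for_pre_quantization(model, checkpoint, 256, torch.float32)
#   sanitize_checkpoint_from_pre_quantization(checkpoint)
#   # strict=False because the replaced modules may register buffers (e.g. `zeros`)
#   # that have no counterpart in the checkpoint
#   model.load_state_dict(checkpoint, strict=False)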

def _replace_output_linear_with_linear_int8_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        scales_key = f"{cur_fqn}.scales"
        if (
            isinstance(child, nn.Linear)
            and scales_key in checkpoint
            and "output" in cur_fqn
        ):
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == dtype
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt8WeightLinear(
            device=child.weight.device,
            # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            in_features=child.in_features,
            # pyre-fixme[6]: For 3rd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            out_features=child.out_features,
            precision=dtype,
            bias=False,
        )
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_output_linear_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Transform the model to be able to load pre-quantized checkpoints that
    have the output layer quantized per-channel.
    """
    _replace_output_linear_with_linear_int8_for_pre_quantization(
        module,
        checkpoint,
        dtype,
    )
    return module

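# Example usage (illustrative sketch; the dtype is an example value): applied when
# the checkpoint carries per-channel int8 weights and matching scales for the
# `output` projection, typically alongside the linear transform above.
#
#   model = transform_output_linear_for_pre_quantization(model, checkpoint, torch.float32)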

def _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
    bit_width: int,
    group_size: Optional[int] = None,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace embedding layers where the checkpoint contains explicit scales
        scales_key = f"{cur_fqn}.scales"
        if isinstance(child, nn.Embedding) and scales_key in checkpoint:
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == torch.float32
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_embedding = QuantizedGroupEmbedding(
            device=child.weight.device,
            vocab_size=child.weight.shape[0],
            embedding_dim=child.weight.shape[1],
            group_size=group_size,
            dtype=dtype,
            packed=False,  # TODO(lunwenh): support packed embedding for pre-quantized
        )
        return new_embedding

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_embedding_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
    bit_width: int,
    group_size: Optional[int] = None,
) -> torch.nn.Module:
    """
    Transform the model to be able to load pre-quantized checkpoints that
    are quantized with the given bit_width and group size for embedding.
    """
    if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
        raise ValueError(
            f"Group size {group_size} is not supported for pre-quantized checkpoint."
        )
    _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
        module,
        checkpoint,
        dtype,
        bit_width,
        group_size,
    )
    return module

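# Example usage (illustrative sketch; the bit width and group size are example
# values): applied when the checkpoint carries int8 embedding weights with
# per-group fp32 scales.
#
#   model = transform_embedding_for_pre_quantization(
#       model, checkpoint, dtype=torch.float32, bit_width=4, group_size=32
#   )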

def sanitize_checkpoint_from_pre_quantization(
    checkpoint: Any,
):
    """
    Sanitize the pre-quantized checkpoint.
        - Converts all tensors to contiguous format
        - Squeezes all tensors
    """
    for k, v in checkpoint.items():
        checkpoint[k] = torch.squeeze(v.contiguous())
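# Illustrative sketch of the effect on a single entry (the key and shape are
# hypothetical): a scale tensor stored with a trailing singleton dimension becomes
# a contiguous 1-D tensor.
#
#   ckpt = {"layers.0.attention.wq.scales": torch.ones(4096, 1)}
#   sanitize_checkpoint_from_pre_quantization(ckpt)
#   assert ckpt["layers.0.attention.wq.scales"].shape == (4096,)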
202