# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

try:
    from .fairseq2 import convert_to_llama_checkpoint

except ImportError:

    def convert_to_llama_checkpoint(**kwargs):
        raise NotImplementedError(
            "Please install fairseq2 with `pip install fairseq2`."
        )


from ..model_base import EagerModelBase


class Llama2Model(EagerModelBase):
    def __init__(self, **kwargs):
        resource_dir = get_default_model_resource_dir(__file__)

        # Use single checkpoint file.
        checkpoint_path = kwargs.get(
            "checkpoint", resource_dir / "demo_rand_params.pth"
        )
        params_path = kwargs.get("params", resource_dir / "demo_config.json")

        # Check if checkpoint_dir was provided for a sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.max_seq_len = kwargs.get("max_seq_len", 128)
        self.args = kwargs.get("args", None)

        # The default checkpoint is a small dummy model with random weights, for demo purposes only.
        # Follow the instructions at https://github.com/facebookresearch/llama to download a real model.
        device = "cpu"
        # flake8: noqa: TOR102
        cps = []
        # Load sharded checkpoint.
        if checkpoint_dir is not None:
            # Load multiple checkpoint shards; ignore the single checkpoint path.
            checkpoint_path = None
            for i in range(4):
                cp_name = f"consolidated.{i}.pth"
                print(f"Loading {cp_name}")
                cps.append(
                    torch.load(
                        os.path.join(checkpoint_dir, cp_name),
                        map_location=device,
                        mmap=True,
                    )
                )
            checkpoint = {}
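            # Merge the shards into one state dict. Tensors that differ across
            # shards were split for tensor parallelism and are concatenated back
            # together; the loop below assumes the usual Llama consolidated-shard
            # layout, where the output projections ("wo", "w2") are split along
            # dim=1 (row-parallel) and everything else along dim=0, while
            # replicated tensors (identical in every shard) are taken from shard 0.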
            for key in cps[0].keys():
                if not torch.allclose(cps[0][key], cps[1][key]):
                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
                    if "wo" in key or "w2" in key:
                        # Concat on dim=1 for "wo" and "w2".
                        checkpoint[key] = torch.cat(values, dim=1)
                    else:
                        # Concat on dim=0 for everything else.
                        checkpoint[key] = torch.cat(values, dim=0)
                else:
                    # Do not duplicate tensors that are shared across checkpoints.
                    checkpoint[key] = cps[0][key]
        # Load single checkpoint.
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

        # If the given checkpoint is a fairseq2 checkpoint, convert it to a llama checkpoint.
        fairseq2_checkpoint = kwargs.get("fairseq2", False)
        if fairseq2_checkpoint:
            print("Using fairseq2 checkpoint")
            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
        if "model" in checkpoint:
            # NB: some checkpoints contain a "model" field, which is the actual weights dict.
            checkpoint = checkpoint["model"]

        # Check if the user unknowingly passed a fairseq2 checkpoint without specifying --fairseq2.
        if (not fairseq2_checkpoint) and checkpoint.get(
            "final_proj.weight", None
        ) is not None:
            raise ValueError(
                """
************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`).

You can import Fairseq2 checkpoints using the --fairseq2
option, but --fairseq2 was not specified. Please verify
the checkpoint format to avoid generating faulty models.
************************************************************
"""
            )

        # Get checkpoint dtype.
        self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.loads(f.read())
        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
        input_prune_map = None
        if self.input_prune_map_path is not None:
            with open(self.input_prune_map_path, "r") as f:
                input_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

        model_args: ModelArgs = ModelArgs(
            max_seq_len=self.max_seq_len,
            max_batch_size=1,
            use_kv_cache=self.use_kv_cache,
            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
            generate_full_logits=self.generate_full_logits,
            input_prune_map=input_prune_map,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
        if kwargs.get("verbose", False):
            print("============= weights ================")
            print("{key} : {weights.numel()} : {weights.size()}")  # column header for the rows below
            for key, weights in checkpoint.items():
                print(f"{key} : {weights.numel()} : {weights.size()}")
            print("============= /weights ================")

        # Within the device="meta" context, tensors that are created do not carry data.
        # They possess all other metadata a tensor carries, such as size, stride, and requires_grad.
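        # For example, parameters created inside the block below live on the "meta"
        # device (param.is_meta is True) and allocate no real storage; the actual
        # values are attached later by load_state_dict(..., assign=True).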
        with torch.device("meta"):
            self.model_ = Transformer(model_args)

        if "int8" in str(checkpoint_path):
            print("Using int8 weight-only quantization!")
            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.source_transformation.quantize`
            from ..source_transformation.quantize import WeightOnlyInt8QuantHandler

            simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
            self.model_ = simple_quantizer.convert_for_runtime()
        elif "8da4w" in str(checkpoint_path):
            print("Using int4 weight and int8 dynamic activation quantization!")
            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

            self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                self.model_
            )
        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
            print("Using SPIN quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)
        elif hasattr(self.args, "use_qat") and self.args.use_qat:
            print("Using QAT quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)
            if hasattr(self.args, "use_lora") and self.args.use_lora:
                assert model_args.lora_args["rank"] == self.args.use_lora
                from .source_transformation.lora import (
                    transform_linear_for_lora_after_quantization,
                )

                self.model_ = transform_linear_for_lora_after_quantization(
                    self.model_,
                    checkpoint,
                    self.args.use_lora,
                )

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)

        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
        # Because we are using device="meta", tensors do not have memory associated with them
        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
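        # strict=False tolerates checkpoints whose keys do not match the module
        # one-to-one (for example, buffers created by the module or extra entries left
        # behind by the pre-quantization transforms above); any mismatches are printed
        # below when verbose is set.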
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the input layer if input_prune_map is provided
        if input_prune_map is not None:
            from .source_transformation.prune_vocab import prune_input_vocab

            self.model_ = prune_input_vocab(self.model_, input_prune_map)

        # Prune the output layer if output_prune_map is provided
        if output_prune_map is not None:
            from .source_transformation.prune_vocab import prune_output_vocab

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            # convert to the type of the provided checkpoint
            # input and output are torch.long, so signature unchanged
            return self.model_.to(self.dtype)
        else:
            # int8 quantization code has some bf16,
            # switch all to FP32
            return self.model_.to(torch.float32)

    def get_example_inputs(self):
        if self.use_kv_cache:
            return self.get_example_inputs_kvcache_sdpa()
        else:
            return (
                torch.tensor(
                    [[1, 2, 3]], dtype=torch.long
                ),  # tokens; without a KV cache the example input can be a multi-token prompt.
            )

    # Assumption: the custom SDPA op doesn't support dynamic shape right now. It might,
    # but it's untested, so let's get static shape working first.
    def get_example_inputs_kvcache_sdpa(self):
        if self.enable_dynamic_shape:
            return (
                torch.tensor([[2, 3, 4]], dtype=torch.long),
                torch.tensor([0], dtype=torch.long),
            )
        else:
            return (
                torch.tensor(
                    [[1]], dtype=torch.long
                ),  # tokens; with a KV cache the input is always a single token.
                torch.tensor(
                    [0], dtype=torch.long
                ),  # start_pos: which output token we are on.
            )

    def _transform_for_pre_quantization(self, checkpoint, model_args):
        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
        assert self.args.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",
        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
        assert hasattr(
            self.args, "preq_group_size"
        ), "preq_group_size must be specified"
        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
        from .source_transformation.pre_quantization import (
            transform_linear_for_pre_quantization,
        )

        assert self.args.preq_group_size == model_args.quantization_args["group_size"]

        mapping = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        # Transform the output layer first if needed.
        if self.args.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )

            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
                dtype=mapping[self.args.dtype_override],
            )

        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
            self.args.preq_group_size,
            mapping[self.args.dtype_override],
        )

        embedding_bit_width, embedding_group_size = None, None
        if hasattr(self.args, "preq_embedding_quantize"):
            embedding_bit_width, embedding_group_size = (
                self.args.preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
            )

            if (
                embedding_group_size == "none"
                or embedding_group_size == "None"
                or embedding_group_size == "0"
            ):
                embedding_group_size = None
            else:
                embedding_group_size = int(embedding_group_size)

            self.model_ = transform_embedding_for_pre_quantization(
                self.model_,
                checkpoint,
                mapping[self.args.dtype_override],
                int(embedding_bit_width),
                embedding_group_size,
            )
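

# Illustrative usage (a minimal sketch, not the supported export path; real exports are
# typically driven by the llama example's export scripts). Constructing the model with
# no kwargs falls back to the bundled demo_rand_params.pth / demo_config.json artifacts,
# so running this file directly only exercises the tiny random-weight demo model.
if __name__ == "__main__":
    demo = Llama2Model()
    eager_model = demo.get_eager_model()
    example_inputs = demo.get_example_inputs()

    num_params = sum(p.numel() for p in eager_model.parameters())
    print(f"Loaded demo model with {num_params} parameters")
    print(f"Example input shapes: {[tuple(t.shape) for t in example_inputs]}")
    # From here one could, for example, trace the eager model with
    # torch.export.export(eager_model, example_inputs) before lowering to ExecuTorch.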