# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

try:
    from .fairseq2 import convert_to_llama_checkpoint

except ImportError:

    def convert_to_llama_checkpoint(**kwargs):
        raise NotImplementedError(
            "Please install fairseq2 with `pip install fairseq2`."
        )


from ..model_base import EagerModelBase


class Llama2Model(EagerModelBase):
    def __init__(self, **kwargs):
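        """Build an eager Llama model from a checkpoint and a params json file.

        Keyword args mirror what this constructor reads via kwargs.get():
        `checkpoint` or `checkpoint_dir` (for a sharded checkpoint), `params`,
        `use_kv_cache`, `use_sdpa_with_kv_cache`, `generate_full_logits`,
        `enable_dynamic_shape`, `input_prune_map_path`, `output_prune_map_path`,
        `max_seq_len`, `fairseq2`, `verbose`, and `args`. When `checkpoint` and
        `params` are not given, the bundled demo model with random weights is used.
        """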
        resource_dir = get_default_model_resource_dir(__file__)

        # Use single checkpoint file.
        checkpoint_path = kwargs.get(
            "checkpoint", resource_dir / "demo_rand_params.pth"
        )
        params_path = kwargs.get("params", resource_dir / "demo_config.json")

        # Check if checkpoint_dir was provided for a sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.max_seq_len = kwargs.get("max_seq_len", 128)
        self.args = kwargs.get("args", None)

        # The example uses a small dummy model with random weights for demo purposes only.
        # Follow the instructions in https://github.com/facebookresearch/llama to download the model.
        device = "cpu"
        # flake8: noqa: TOR102
        cps = []
        # Load sharded checkpoint.
        if checkpoint_dir is not None:
            # Load multiple checkpoints; ignore the single checkpoint path.
            checkpoint_path = None
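            # This path assumes a 4-way sharded checkpoint, i.e. files
            # consolidated.0.pth ... consolidated.3.pth (e.g. from model-parallel
            # training); the shards are merged into a single state dict below.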
            for i in range(4):
                cp_name = f"consolidated.{i}.pth"
                print(f"Loading {cp_name}")
                cps.append(
                    torch.load(
                        os.path.join(checkpoint_dir, cp_name),
                        map_location=device,
                        mmap=True,
                    )
                )
            checkpoint = {}
            for key in cps[0].keys():
                if not torch.allclose(cps[0][key], cps[1][key]):
                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
                    if "wo" in key or "w2" in key:
                        # Concat on dim=1 for "wo" and "w2".
                        checkpoint[key] = torch.cat(values, dim=1)
                    else:
                        # Concat on dim=0 for everything else.
                        checkpoint[key] = torch.cat(values, dim=0)
                else:
                    # Do not duplicate layers that are shared across checkpoints.
                    checkpoint[key] = cps[0][key]
        # Load single checkpoint.
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

        # If the given checkpoint is a fairseq2 checkpoint, convert it to a llama checkpoint.
        fairseq2_checkpoint = kwargs.get("fairseq2", False)
        if fairseq2_checkpoint:
            print("Using fairseq2 checkpoint")
            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
        if "model" in checkpoint:
            # NB: some checkpoints contain a "model" field, which is the actual weights dict.
            checkpoint = checkpoint["model"]

        # Check if the user unknowingly passed a fairseq2 checkpoint without specifying --fairseq2.
        if (not fairseq2_checkpoint) and checkpoint.get(
            "final_proj.weight", None
        ) is not None:
            raise ValueError(
                """
************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`).

You can import Fairseq2 checkpoints using the --fairseq2
option, but --fairseq2 was not specified. Please verify
the checkpoint format to avoid generating faulty models.
************************************************************
"""
            )

        # Get checkpoint dtype.
        self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.loads(f.read())
        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
        input_prune_map = None
        if self.input_prune_map_path is not None:
            with open(self.input_prune_map_path, "r") as f:
                input_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys).
            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

        model_args: ModelArgs = ModelArgs(
            max_seq_len=self.max_seq_len,
            max_batch_size=1,
            use_kv_cache=self.use_kv_cache,
            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
            generate_full_logits=self.generate_full_logits,
            input_prune_map=input_prune_map,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
        if kwargs.get("verbose", False):
            print("============= weights ================")
            # Header line describing the columns printed for each weight below.
            print("{key} : {weights.numel()} : {weights.size()}")
            for key, weights in checkpoint.items():
                print(f"{key} : {weights.numel()} : {weights.size()}")
            print("============= /weights ================")

        # Within the device="meta" context, tensors that are created do not carry data.
        # They possess all other metadata a tensor carries such as size, stride, requires_grad.
        with torch.device("meta"):
            self.model_ = Transformer(model_args)

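        # Choose a source transformation based on hints in the checkpoint filename
        # ("int8", "8da4w") or on pre-quantization flags in self.args (SpinQuant/QAT),
        # so the module structure matches the checkpoint before load_state_dict below.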
160        if "int8" in str(checkpoint_path):
161            print("Using int8 weight-only quantization!")
162            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.source_transformation.quantize`
163            from ..source_transformation.quantize import WeightOnlyInt8QuantHandler
164
165            simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
166            self.model_ = simple_quantizer.convert_for_runtime()
167        elif "8da4w" in str(checkpoint_path):
168            print("Using int4 weight and int8 dynamic activation quantization!")
169            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
170
171            self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
172                self.model_
173            )
174        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
175            print("Using SPIN quantization.")
176            self._transform_for_pre_quantization(checkpoint, model_args)
177
178            from .source_transformation.pre_quantization import (
179                sanitize_checkpoint_from_pre_quantization,
180            )
181
182            sanitize_checkpoint_from_pre_quantization(checkpoint)
183        elif hasattr(self.args, "use_qat") and self.args.use_qat:
184            print("Using QAT quantization.")
185            self._transform_for_pre_quantization(checkpoint, model_args)
186            if hasattr(self.args, "use_lora") and self.args.use_lora:
187                assert model_args.lora_args["rank"] == self.args.use_lora
188                from .source_transformation.lora import (
189                    transform_linear_for_lora_after_quantization,
190                )
191
192                self.model_ = transform_linear_for_lora_after_quantization(
193                    self.model_,
194                    checkpoint,
195                    self.args.use_lora,
196                )
197
198            from .source_transformation.pre_quantization import (
199                sanitize_checkpoint_from_pre_quantization,
200            )
201
202            sanitize_checkpoint_from_pre_quantization(checkpoint)
203
        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
        # Because we are using device="meta", tensors do not have memory associated with them
        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the input layer if input_prune_map is provided
        if input_prune_map is not None:
            from .source_transformation.prune_vocab import prune_input_vocab

            self.model_ = prune_input_vocab(self.model_, input_prune_map)

        # Prune the output layer if output_prune_map is provided
        if output_prune_map is not None:
            from .source_transformation.prune_vocab import prune_output_vocab

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            # Convert to the dtype of the provided checkpoint.
            # Inputs and outputs are torch.long, so the signature is unchanged.
            return self.model_.to(self.dtype)
        else:
            # The int8 quantization code leaves some tensors in bf16,
            # so switch everything to fp32.
            return self.model_.to(torch.float32)

    def get_example_inputs(self):
        if self.use_kv_cache:
            return self.get_example_inputs_kvcache_sdpa()
        else:
            return (
                torch.tensor(
                    [[1, 2, 3]], dtype=torch.long
                ),  # tokens; without a kv cache the whole prompt is passed in one call.
            )

    # Assumption: the custom op doesn't support dynamic shapes right now. It might,
    # but it's untested, so let's get static shapes working first.
    def get_example_inputs_kvcache_sdpa(self):
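        """Example inputs for the kv-cache path: a (tokens, start_pos) pair.

        With dynamic shapes enabled this is a short prompt plus a start position;
        otherwise it is a single token and start position 0, matching the comments below.
        """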
        if self.enable_dynamic_shape:
            return (
                torch.tensor([[2, 3, 4]], dtype=torch.long),
                torch.tensor([0], dtype=torch.long),
            )
        else:
            return (
                torch.tensor(
                    [[1]], dtype=torch.long
                ),  # tokens; with a kv cache the input is always a single token.
                torch.tensor(
                    [0], dtype=torch.long
                ),  # start_pos: which output token we are on.
            )

    def _transform_for_pre_quantization(self, checkpoint, model_args):
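        """Rewrite linear (and optionally output and embedding) layers to their
        pre-quantized forms so that a SpinQuant or QAT checkpoint can be loaded.

        Expects `preq_mode`, `preq_group_size`, and `dtype_override` on self.args;
        `preq_embedding_quantize` is optional.
        """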
        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
        assert self.args.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",
        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
        assert hasattr(
            self.args, "preq_group_size"
        ), "preq_group_size must be specified"
        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
        from .source_transformation.pre_quantization import (
            transform_linear_for_pre_quantization,
        )

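        # The group size given on the command line must match the group size the
        # checkpoint was quantized with, as recorded in quantization_args in the
        # params json.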
        assert self.args.preq_group_size == model_args.quantization_args["group_size"]

        mapping = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        # Transform the output layer first if needed.
        if self.args.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )

            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
                dtype=mapping[self.args.dtype_override],
            )

        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
            self.args.preq_group_size,
            mapping[self.args.dtype_override],
        )

        embedding_bit_width, embedding_group_size = None, None
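        # preq_embedding_quantize, when present, is parsed as "<bit_width>,<group_size>"
        # (e.g. "4,32"); a group size of "none"/"None"/"0" means no grouping.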
        if hasattr(self.args, "preq_embedding_quantize"):
            embedding_bit_width, embedding_group_size = (
                self.args.preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
            )

            if (
                embedding_group_size == "none"
                or embedding_group_size == "None"
                or embedding_group_size == "0"
            ):
                embedding_group_size = None
            else:
                embedding_group_size = int(embedding_group_size)

            self.model_ = transform_embedding_for_pre_quantization(
                self.model_,
                checkpoint,
                mapping[self.args.dtype_override],
                int(embedding_bit_width),
                embedding_group_size,
            )
335