# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

try:
    from .fairseq2 import convert_to_llama_checkpoint

except ImportError:

    def convert_to_llama_checkpoint(**kwargs):
        raise NotImplementedError(
            "Please install fairseq2 with `pip install fairseq2`."
        )


from ..model_base import EagerModelBase


class Llama2Model(EagerModelBase):
    def __init__(self, **kwargs):
        resource_dir = get_default_model_resource_dir(__file__)

        # Use a single checkpoint file.
        checkpoint_path = kwargs.get(
            "checkpoint", resource_dir / "demo_rand_params.pth"
        )
        params_path = kwargs.get("params", resource_dir / "demo_config.json")

        # Check if checkpoint_dir was provided for a sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.max_seq_len = kwargs.get("max_seq_len", 128)
        self.args = kwargs.get("args", None)

        # This example uses a small dummy model with random weights for demo purposes only.
        # Follow the instructions at https://github.com/facebookresearch/llama to download the real model.
        device = "cpu"
        # flake8: noqa: TOR102
        cps = []
        # Load sharded checkpoint.
        if checkpoint_dir is not None:
            # Load multiple checkpoints; ignore the single path.
            checkpoint_path = None
            for i in range(4):
                cp_name = f"consolidated.{i}.pth"
                print(f"Loading {cp_name}")
                cps.append(
                    torch.load(
                        os.path.join(checkpoint_dir, cp_name),
                        map_location=device,
                        mmap=True,
                    )
                )
            checkpoint = {}
            for key in cps[0].keys():
                if not torch.allclose(cps[0][key], cps[1][key]):
                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
                    if "wo" in key or "w2" in key:
                        # Concat on dim=1 for "wo" and "w2".
                        checkpoint[key] = torch.cat(values, dim=1)
                    else:
                        # Concat on dim=0 for everything else.
                        checkpoint[key] = torch.cat(values, dim=0)
                else:
                    # Do not duplicate layers shared across checkpoints.
                    checkpoint[key] = cps[0][key]
        # Load single checkpoint.
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

        # If the given checkpoint is fairseq2, convert it to a llama checkpoint.
        fairseq2_checkpoint = kwargs.get("fairseq2", False)
        if fairseq2_checkpoint:
            print("Using fairseq2 checkpoint")
            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
        if "model" in checkpoint:
            # NB: some checkpoints contain a "model" field, which is the actual weights dict.
            checkpoint = checkpoint["model"]

        # Check if the user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
        if (not fairseq2_checkpoint) and checkpoint.get(
            "final_proj.weight", None
        ) is not None:
            raise ValueError(
                """
************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`).

You can import Fairseq2 checkpoints using the --fairseq2
option, but --fairseq2 was not specified. Please verify
the checkpoint format to avoid generating faulty models.
************************************************************
"""
            )

        # Get checkpoint dtype.
        self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.loads(f.read())
        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
        input_prune_map = None
        if self.input_prune_map_path is not None:
            with open(self.input_prune_map_path, "r") as f:
                input_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

        model_args: ModelArgs = ModelArgs(
            max_seq_len=self.max_seq_len,
            max_batch_size=1,
            use_kv_cache=self.use_kv_cache,
            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
            generate_full_logits=self.generate_full_logits,
            input_prune_map=input_prune_map,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
        if kwargs.get("verbose", False):
            print("============= weights ================")
            # Header line showing the format of the per-weight rows below.
            print("{key} : {weights.numel()} : {weights.size()}")
            for key, weights in checkpoint.items():
                print(f"{key} : {weights.numel()} : {weights.size()}")
            print("============= /weights ================")

        # Tensors created within the device="meta" context carry no data.
        # They keep all other tensor metadata, such as size, stride, and requires_grad.
        with torch.device("meta"):
            self.model_ = Transformer(model_args)

        if "int8" in str(checkpoint_path):
            print("Using int8 weight-only quantization!")
            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.source_transformation.quantize`
            from ..source_transformation.quantize import WeightOnlyInt8QuantHandler

            simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
            self.model_ = simple_quantizer.convert_for_runtime()
        elif "8da4w" in str(checkpoint_path):
            print("Using int4 weight and int8 dynamic activation quantization!")
            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

            self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                self.model_
            )
        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
            print("Using SPIN quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)
        elif hasattr(self.args, "use_qat") and self.args.use_qat:
            print("Using QAT quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)
            if hasattr(self.args, "use_lora") and self.args.use_lora:
                assert model_args.lora_args["rank"] == self.args.use_lora
                from .source_transformation.lora import (
                    transform_linear_for_lora_after_quantization,
                )

                self.model_ = transform_linear_for_lora_after_quantization(
                    self.model_,
                    checkpoint,
                    self.args.use_lora,
                )

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)

        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
        # Because we are using device="meta", tensors do not have memory associated with them
        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the input layer if input_prune_map is provided.
        if input_prune_map is not None:
            from .source_transformation.prune_vocab import prune_input_vocab

            self.model_ = prune_input_vocab(self.model_, input_prune_map)

        # Prune the output layer if output_prune_map is provided.
        if output_prune_map is not None:
            from .source_transformation.prune_vocab import prune_output_vocab

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            # Convert to the dtype of the provided checkpoint;
            # inputs and outputs are torch.long, so the signature is unchanged.
            return self.model_.to(self.dtype)
        else:
            # The int8 quantization code has some bf16;
            # switch everything to fp32.
            return self.model_.to(torch.float32)

    def get_example_inputs(self):
        if self.use_kv_cache:
            return self.get_example_inputs_kvcache_sdpa()
        else:
            return (
                torch.tensor(
                    [[1, 2, 3]], dtype=torch.long
                ),  # tokens: without a KV cache, the full example prompt is passed.
            )

    # Assumption: the custom SDPA op does not support dynamic shape right now.
    # It might, but it is untested, so get static shape working first.
    def get_example_inputs_kvcache_sdpa(self):
        if self.enable_dynamic_shape:
            return (
                torch.tensor([[2, 3, 4]], dtype=torch.long),
                torch.tensor([0], dtype=torch.long),
            )
        else:
            return (
                torch.tensor(
                    [[1]], dtype=torch.long
                ),  # tokens: with the KV cache, the input is always a single token.
                torch.tensor(
                    [0], dtype=torch.long
                ),  # start_pos: which output token position we are on.
            )

    def _transform_for_pre_quantization(self, checkpoint, model_args):
        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
        assert self.args.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",
        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
        assert hasattr(
            self.args, "preq_group_size"
        ), "preq_group_size must be specified"
        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
        from .source_transformation.pre_quantization import (
            transform_linear_for_pre_quantization,
        )

        assert self.args.preq_group_size == model_args.quantization_args["group_size"]

        mapping = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        # Transform the output layer first if needed.
        if self.args.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )

            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
                dtype=mapping[self.args.dtype_override],
            )

        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
            self.args.preq_group_size,
            mapping[self.args.dtype_override],
        )

        embedding_bit_width, embedding_group_size = None, None
        if hasattr(self.args, "preq_embedding_quantize"):
            embedding_bit_width, embedding_group_size = (
                self.args.preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
            )

            if embedding_group_size in ("none", "None", "0"):
                embedding_group_size = None
            else:
                embedding_group_size = int(embedding_group_size)

            self.model_ = transform_embedding_for_pre_quantization(
                self.model_,
                checkpoint,
                mapping[self.args.dtype_override],
                int(embedding_bit_width),
                embedding_group_size,
            )
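

# Minimal usage sketch: builds the model against the bundled demo resources
# (demo_rand_params.pth / demo_config.json, the defaults resolved in __init__)
# and runs one eager forward pass on the example inputs. This is illustrative
# only and assumes the file is executed as a module (python -m ...) so the
# relative imports resolve; the kwargs shown mirror those consumed in __init__.
if __name__ == "__main__":
    demo_model = Llama2Model(use_kv_cache=False, max_seq_len=128)
    eager_model = demo_model.get_eager_model()
    example_inputs = demo_model.get_example_inputs()
    with torch.no_grad():
        logits = eager_model(*example_inputs)
    print(f"Example logits shape: {tuple(logits.shape)}")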