# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

try:
    from .fairseq2 import convert_to_llama_checkpoint

except ImportError:

    def convert_to_llama_checkpoint(**kwargs):
        raise NotImplementedError(
            "Please install fairseq2 with `pip install fairseq2`."
        )


from ..model_base import EagerModelBase
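
# A minimal usage sketch (illustrative only; the paths are hypothetical):
#
#   model = Llama2Model(
#       checkpoint="/path/to/consolidated.00.pth",
#       params="/path/to/params.json",
#       use_kv_cache=True,
#   )
#   m = model.get_eager_model()
#   logits = m(*model.get_example_inputs())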

class Llama2Model(EagerModelBase):
    def __init__(self, **kwargs):
        resource_dir = get_default_model_resource_dir(__file__)

        # Use a single checkpoint file by default.
        checkpoint_path = kwargs.get(
            "checkpoint", resource_dir / "demo_rand_params.pth"
        )
        params_path = kwargs.get("params", resource_dir / "demo_config.json")

        # Check if checkpoint_dir was provided for a sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.max_seq_len = kwargs.get("max_seq_len", 128)
        self.args = kwargs.get("args", None)

        # The example uses a small dummy model with random weights for demo purposes only.
        # Follow the instructions at https://github.com/facebookresearch/llama to download
        # the real model.
        device = "cpu"
        # flake8: noqa: TOR102
        cps = []
        # Load a sharded checkpoint.
        if checkpoint_dir is not None:
            # Load multiple checkpoint shards; ignore the single-file path.
            checkpoint_path = None
            for i in range(4):
                cp_name = f"consolidated.{i}.pth"
                print(f"Loading {cp_name}")
                cps.append(
                    torch.load(
                        os.path.join(checkpoint_dir, cp_name),
                        map_location=device,
                        mmap=True,
                    )
                )
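            # Merge the shards back into one state dict. In Llama's tensor
            # parallelism (an assumption based on the upstream llama code, not
            # verified here), the attention output projection ("wo") and the
            # FFN down projection ("w2") are row-parallel, so their shards are
            # split along the input dimension (dim=1); everything else that is
            # sharded splits along dim=0. Tensors that are identical across
            # shards (e.g. norms) are replicated, so a single copy is kept.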
            checkpoint = {}
            for key in cps[0].keys():
                if not torch.allclose(cps[0][key], cps[1][key]):
                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
                    if "wo" in key or "w2" in key:
                        # Concat on dim=1 for "wo" and "w2".
                        checkpoint[key] = torch.cat(values, dim=1)
                    else:
                        # Concat on dim=0 for everything else.
                        checkpoint[key] = torch.cat(values, dim=0)
                else:
                    # Do not duplicate layers shared between the checkpoint shards.
                    checkpoint[key] = cps[0][key]
        # Load a single checkpoint.
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

        # If the given checkpoint is fairseq, convert it to a llama checkpoint.
        fairseq2_checkpoint = kwargs.get("fairseq2", False)
        if fairseq2_checkpoint:
            print("Using fairseq2 checkpoint")
            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
        if "model" in checkpoint:
            # NB: some checkpoints contain a "model" field, which is the actual weights dict.
            checkpoint = checkpoint["model"]

        # Check if the user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
        if (not fairseq2_checkpoint) and checkpoint.get(
            "final_proj.weight", None
        ) is not None:
            raise ValueError(
                """
************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`).

You can import Fairseq2 checkpoints using the --fairseq2
option, but --fairseq2 was not specified. Please verify
the checkpoint format to avoid generating faulty models.
************************************************************
"""
            )

        # Get the checkpoint dtype.
        self.dtype = get_checkpoint_dtype(checkpoint)
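        # params_path points at a JSON file whose fields feed ModelArgs below.
        # Illustrative contents (example values only, not a requirement):
        #   {"dim": 4096, "n_layers": 32, "n_heads": 32, "vocab_size": 32000}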
        with open(params_path, "r") as f:
            params = json.loads(f.read())
        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
        input_prune_map = None
        if self.input_prune_map_path is not None:
            with open(self.input_prune_map_path, "r") as f:
                input_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

        model_args: ModelArgs = ModelArgs(
            max_seq_len=self.max_seq_len,
            max_batch_size=1,
            use_kv_cache=self.use_kv_cache,
            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
            generate_full_logits=self.generate_full_logits,
            input_prune_map=input_prune_map,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
        if kwargs.get("verbose", False):
            print("============= weights ================")
            # Header line describing the columns printed below.
            print("{key} : {weights.numel()} : {weights.size()}")
            for key, weights in checkpoint.items():
                print(f"{key} : {weights.numel()} : {weights.size()}")
            print("============= /weights ================")

        # Within the device="meta" context, tensors that are created do not carry data.
        # They possess all other metadata a tensor carries, such as size, stride and
        # requires_grad.
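        # Building the model on "meta" keeps construction cheap: the
        # quantization/LoRA transforms below can rewrite modules without
        # allocating real weight storage, and the checkpoint tensors are
        # attached in one step later via load_state_dict(..., assign=True).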
        with torch.device("meta"):
            self.model_ = Transformer(model_args)

        if "int8" in str(checkpoint_path):
            print("Using int8 weight-only quantization!")
            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.examples.models.source_transformation.quantize`
            from ..source_transformation.quantize import WeightOnlyInt8QuantHandler

            simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
            self.model_ = simple_quantizer.convert_for_runtime()
        elif "8da4w" in str(checkpoint_path):
            print("Using int4 weight and int8 dynamic activation quantization!")
            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

            self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                self.model_
            )
        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
            print("Using SPIN quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)
        elif hasattr(self.args, "use_qat") and self.args.use_qat:
            print("Using QAT quantization.")
            self._transform_for_pre_quantization(checkpoint, model_args)
            if hasattr(self.args, "use_lora") and self.args.use_lora:
                assert model_args.lora_args["rank"] == self.args.use_lora
                from .source_transformation.lora import (
                    transform_linear_for_lora_after_quantization,
                )

                self.model_ = transform_linear_for_lora_after_quantization(
                    self.model_,
                    checkpoint,
                    self.args.use_lora,
                )

            from .source_transformation.pre_quantization import (
                sanitize_checkpoint_from_pre_quantization,
            )

            sanitize_checkpoint_from_pre_quantization(checkpoint)

        # assign=True: load params/buffers by assignment instead of performing
        # an in-place copy. Because we are using device="meta", tensors do not
        # have memory associated with them and an in-place copy would be a
        # no-op, so assign=True is required in load_state_dict here.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the input layer if input_prune_map is provided.
        if input_prune_map is not None:
            from .source_transformation.prune_vocab import prune_input_vocab

            self.model_ = prune_input_vocab(self.model_, input_prune_map)

        # Prune the output layer if output_prune_map is provided.
        if output_prune_map is not None:
            from .source_transformation.prune_vocab import prune_output_vocab

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            # Convert to the dtype of the provided checkpoint.
            # Input and output are torch.long, so the signature is unchanged.
            return self.model_.to(self.dtype)
        else:
            # The int8 quantization code has some bf16;
            # switch everything to fp32.
            return self.model_.to(torch.float32)
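
    # Example-input helpers: without a KV cache the model takes a single
    # [batch, seq_len] tensor of token ids; with a KV cache it takes
    # (tokens, start_pos), where start_pos is the position of the current
    # token in the generated sequence.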
    def get_example_inputs(self):
        if self.use_kv_cache:
            return self.get_example_inputs_kvcache_sdpa()
        else:
            return (
                torch.tensor(
                    [[1, 2, 3]], dtype=torch.long
                ),  # tokens: a short dummy prompt of token ids.
            )

    # Assumption: the custom SDPA op doesn't support dynamic shape right now.
    # It might, but it's untested, so let's get static shape working first.
    def get_example_inputs_kvcache_sdpa(self):
        if self.enable_dynamic_shape:
            return (
                torch.tensor([[2, 3, 4]], dtype=torch.long),
                torch.tensor([0], dtype=torch.long),
            )
        else:
            return (
                torch.tensor(
                    [[1]], dtype=torch.long
                ),  # tokens: with a KV cache, the input is always a single token.
                torch.tensor(
                    [0], dtype=torch.long
                ),  # start_pos: which position of the output we are on.
            )

    def _transform_for_pre_quantization(self, checkpoint, model_args):
        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
        assert self.args.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",
        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
        assert hasattr(
            self.args, "preq_group_size"
        ), "preq_group_size must be specified"
        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
        from .source_transformation.pre_quantization import (
            transform_linear_for_pre_quantization,
        )

        assert self.args.preq_group_size == model_args.quantization_args["group_size"]

        mapping = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
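        # Illustrative argument combination satisfying the asserts above
        # (hypothetical values): preq_mode="8da4w", preq_group_size=32
        # (matching the checkpoint's quantization_args["group_size"]),
        # dtype_override="fp32".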

        # Transform the output layer first if needed.
        if self.args.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )

            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
                dtype=mapping[self.args.dtype_override],
            )

        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
            self.args.preq_group_size,
            mapping[self.args.dtype_override],
        )

        embedding_bit_width, embedding_group_size = None, None
        if hasattr(self.args, "preq_embedding_quantize"):
            # preq_embedding_quantize is a "bitwidth,group_size" string; a
            # group size of "none", "None" or "0" is mapped to None below.
            embedding_bit_width, embedding_group_size = (
                self.args.preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
            )

            if (
                embedding_group_size == "none"
                or embedding_group_size == "None"
                or embedding_group_size == "0"
            ):
                embedding_group_size = None
            else:
                embedding_group_size = int(embedding_group_size)

            self.model_ = transform_embedding_for_pre_quantization(
                self.model_,
                checkpoint,
                mapping[self.args.dtype_override],
                int(embedding_bit_width),
                embedding_group_size,
            )