# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can load pre-quantized checkpoints.

from typing import Any, Optional

import torch
from torch import nn

from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding


def _replace_linear_with_linear_8da4w_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    group_size: int,
    precision: torch.dtype,
    scales_precision: torch.dtype,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace linear layers for which the checkpoint contains explicit scales.
        scales_key = f"{cur_fqn}.scales"
        if isinstance(child, nn.Linear) and scales_key in checkpoint:
            assert _check_linear_int4_k(child.in_features, group_size)
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == scales_precision
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt4WeightLinear(
            # pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.in_features,
            # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.out_features,
            bias=False,
            device=child.weight.device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated
        # to include this PR.
        new_linear.zeros = torch.zeros_like(new_linear.zeros)
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    group_size: int,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Transform the model so that it can load pre-quantized checkpoints whose linear
    layers are quantized to 8-bit dynamic activations and 4-bit grouped weights
    (8da4w) with the given group size.
    """
    if group_size not in [32, 64, 128, 256]:
        raise ValueError(
            f"Group size {group_size} is not supported for pre-quantized checkpoint."
        )
    _replace_linear_with_linear_8da4w_for_pre_quantization(
        module,
        checkpoint,
        group_size,
        dtype,
        dtype,
    )
    return module


def _replace_output_linear_with_linear_int8_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace the output linear layer, identified by its FQN and by the
        # presence of explicit scales in the checkpoint.
        scales_key = f"{cur_fqn}.scales"
        if (
            isinstance(child, nn.Linear)
            and scales_key in checkpoint
            and "output" in cur_fqn
        ):
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == dtype
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt8WeightLinear(
            device=child.weight.device,
            # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            in_features=child.in_features,
            # pyre-fixme[6]: For 3rd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            out_features=child.out_features,
            precision=dtype,
            bias=False,
        )
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_output_linear_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Transform the model so that it can load pre-quantized checkpoints in which the
    output layer is quantized per-channel with int8 weights.
    """
    _replace_output_linear_with_linear_int8_for_pre_quantization(
        module,
        checkpoint,
        dtype,
    )
    return module


def _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
    bit_width: int,
    group_size: Optional[int] = None,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace embedding layers for which the checkpoint contains explicit scales.
        scales_key = f"{cur_fqn}.scales"
        if isinstance(child, nn.Embedding) and scales_key in checkpoint:
            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
            assert checkpoint[scales_key].dtype == torch.float32
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_embedding = QuantizedGroupEmbedding(
            device=child.weight.device,
            vocab_size=child.weight.shape[0],
            embedding_dim=child.weight.shape[1],
            group_size=group_size,
            dtype=dtype,
            packed=False,  # TODO(lunwenh): support packed embedding for pre-quantized
        )
        return new_embedding

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_embedding_for_pre_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    dtype: torch.dtype,
    bit_width: int,
    group_size: Optional[int] = None,
) -> torch.nn.Module:
    """
    Transform the model so that it can load pre-quantized checkpoints whose embedding
    layers are quantized with the given bit_width and group size.
    """
    if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
        raise ValueError(
            f"Group size {group_size} is not supported for pre-quantized checkpoint."
        )
    _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
        module,
        checkpoint,
        dtype,
        bit_width,
        group_size,
    )
    return module


def sanitize_checkpoint_from_pre_quantization(
    checkpoint: Any,
):
    """
    Sanitize the pre-quantized checkpoint.
    - Converts all tensors to contiguous format.
    - Squeezes all tensors.
    """
    for k, v in checkpoint.items():
        checkpoint[k] = torch.squeeze(v.contiguous())
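
# Illustrative usage sketch (an assumption, not part of the original module API):
# `model` stands for a torch.nn.Module whose nn.Linear / nn.Embedding submodule names
# match the pre-quantized checkpoint, and `checkpoint` for that checkpoint's state_dict
# already loaded with torch.load(...). The group sizes and dtypes below are example
# values, not requirements.
#
#     transform_linear_for_pre_quantization(model, checkpoint, group_size=256, dtype=torch.float32)
#     transform_output_linear_for_pre_quantization(model, checkpoint, torch.float32)
#     transform_embedding_for_pre_quantization(model, checkpoint, torch.float32, bit_width=4, group_size=32)
#     sanitize_checkpoint_from_pre_quantization(checkpoint)
#     model.load_state_dict(checkpoint, strict=False)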