# mypy: allow-untyped-defs
import typing

import torch


__all__ = [
    "ReferenceQuantizedModule",
]


class ReferenceQuantizedModule(torch.nn.Module):
    def _init_weight_qparams(self, weight_qparams, device):
        if weight_qparams is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
        self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
        self.weight_dtype = weight_qparams["dtype"]
        assert self.weight_qscheme in [
            None,
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams,
        ], f"qscheme: {self.weight_qscheme} is not supported in reference quantized {self._get_name()}"
        if self.weight_dtype in [
            torch.quint8,
            torch.qint8,
            torch.quint4x2,
            torch.qint32,
        ]:
            zero_point_dtype = (
                weight_qparams["zero_point"].dtype
                if isinstance(weight_qparams["zero_point"], torch.Tensor)
                else torch.int
            )
            w_scale = weight_qparams["scale"]
            w_scale_tensor = (
                w_scale.clone().detach()
                if isinstance(w_scale, torch.Tensor)
                else torch.tensor(w_scale, dtype=torch.float, device=device)
            )
            self.register_buffer("weight_scale", w_scale_tensor)
            w_zp = weight_qparams["zero_point"]
            w_zp_tensor = (
                w_zp.clone().detach()
                if isinstance(w_zp, torch.Tensor)
                else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
            )
            self.register_buffer("weight_zero_point", w_zp_tensor)
            if self.weight_qscheme in [
                torch.per_channel_affine,
                torch.per_channel_affine_float_qparams,
            ]:
                w_axis = weight_qparams["axis"]
                w_axis_tensor = (
                    w_axis.clone().detach()
                    if isinstance(w_axis, torch.Tensor)
                    else torch.tensor(w_axis, dtype=torch.int, device=device)
                )
                self.register_buffer("weight_axis", w_axis_tensor)
            else:
                # added for TorchScriptability, not used
                self.register_buffer(
                    "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
                )
        else:
            # added for TorchScriptability, and for torch.float
            self.register_buffer(
                "weight_scale", torch.tensor(1.0, dtype=torch.float, device=device)
            )
            self.register_buffer(
                "weight_zero_point", torch.tensor(0, dtype=torch.int, device=device)
            )
            self.register_buffer(
                "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
            )
        self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
        # store weight_axis as weight_axis_int due to torchdynamo.export
        # constraints on capturing `.item()` calls
        self.weight_axis_int: int = self.weight_axis.item()  # type: ignore[operator, assignment]
        self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
            "quant_min", None
        )
        self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
            "quant_max", None
        )

    def get_weight(self):
        """
        Fake quantize (quantize and dequantize) the weight with
        the quantization parameters for the weight. This is used to
        simulate the numerics of the quantized weight in a quantized
        model.
        """
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        if self.is_decomposed:
            return _quantize_and_dequantize_weight_decomposed(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
                self.weight_quant_min,
                self.weight_quant_max,
            )
        else:
            return _quantize_and_dequantize_weight(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
            )

    def get_quantized_weight(self):
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        # assert isinstance(self.weight_axis, torch.Tensor)
        if self.is_decomposed:
            return _quantize_weight_decomposed(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
                self.weight_quant_min,
                self.weight_quant_max,
            )
        else:
            return _quantize_weight(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
            )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        _save_weight_qparams(
            destination,
            prefix,
            self.weight_qscheme,
            self.weight_dtype,
            self.weight_scale,
            self.weight_zero_point,
            self.weight_axis,
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        for key in _get_weight_qparam_keys(state_dict, prefix):
            setattr(self, key, state_dict[prefix + key])
            state_dict.pop(prefix + key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
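
# Illustrative shape of the `weight_qparams` dict that `_init_weight_qparams`
# consumes (a sketch based on the keys read above; `out_channels` is a
# hypothetical dimension, not a name used in this module):
#
#     weight_qparams = {
#         "qscheme": torch.per_channel_affine,
#         "dtype": torch.qint8,
#         "scale": torch.ones(out_channels),                        # per-channel scales
#         "zero_point": torch.zeros(out_channels, dtype=torch.int), # per-channel zero points
#         "axis": 0,                # only read for per-channel qschemes
#         "is_decomposed": False,   # optional
#         "quant_min": -128,        # optional
#         "quant_max": 127,         # optional
#     }
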

def _quantize_weight_decomposed(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis: int,
    weight_quant_min: typing.Optional[int],
    weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add a util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            if weight_quant_min is None or weight_quant_max is None:
                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[
                    weight_dtype_
                ]
            weight = torch.ops.quantized_decomposed.quantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )
            return weight
    elif weight_qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            if weight_quant_min is None or weight_quant_max is None:
                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[
                    weight_dtype_
                ]
            weight = torch.ops.quantized_decomposed.quantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
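
# For reference: the decomposed ops used above operate on plain integer dtypes
# and implement standard affine quantization, roughly
#
#     q = clamp(round(w / scale) + zero_point, quant_min, quant_max)
#
# while `_dequantize_weight_decomposed` below applies the inverse mapping
#
#     w_hat = (q - zero_point) * scale
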

def _dequantize_weight_decomposed(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis: int,
    weight_quant_min: typing.Optional[int],
    weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    # TODO: get the quant_min and quant_max from activation_post_process
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add a util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
    if weight_quant_min is None or weight_quant_max is None:
        weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )
            return weight
    elif weight_qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.ops.quantized_decomposed.dequantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")


def _quantize_weight(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis_int: int,
) -> torch.Tensor:
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight

    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(
                weight, weight_scale, weight_zero_point, weight_dtype
            )
            return weight
    elif weight_qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            weight = torch.quantize_per_channel(
                weight, weight_scale, weight_zero_point, weight_axis_int, weight_dtype
            )  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
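
# A minimal round-trip sketch for the non-decomposed path above (values are
# made up for illustration):
#
#     w = torch.randn(2, 3)
#     scale = torch.tensor(0.1)
#     zp = torch.tensor(0, dtype=torch.int)
#     wq = _quantize_weight(w, torch.per_tensor_affine, torch.qint8, scale, zp, 0)
#     w_hat = wq.dequantize()  # fp32 tensor carrying the int8 numerics
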

def _quantize_and_dequantize_weight_decomposed(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis_int: int,
    weight_quant_min: typing.Optional[int],
    weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    """Quantize and then dequantize the weight based on
    the quantization parameters.
    """
    if weight_qscheme in [
        torch.per_tensor_affine,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        weight_quant = _quantize_weight_decomposed(
            weight,
            weight_qscheme,
            weight_dtype,
            weight_scale,
            weight_zero_point,
            weight_axis_int,
            weight_quant_min,
            weight_quant_max,
        )
        weight_dequant = _dequantize_weight_decomposed(
            weight_quant,
            weight_qscheme,
            weight_dtype,
            weight_scale,
            weight_zero_point,
            weight_axis_int,
            weight_quant_min,
            weight_quant_max,
        )
    else:
        weight_dequant = weight
    return weight_dequant


def _quantize_and_dequantize_weight(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis_int: int,
) -> torch.Tensor:
    """Quantize and then dequantize the weight based on
    the quantization parameters.
    """
    if weight_qscheme in [
        torch.per_tensor_affine,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        weight_quant = _quantize_weight(
            weight,
            weight_qscheme,
            weight_dtype,
            weight_scale,
            weight_zero_point,
            weight_axis_int,
        )
        weight_dequant = weight_quant.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant


def _save_weight_qparams(
    destination,
    prefix,
    weight_qscheme,
    weight_dtype,
    weight_scale,
    weight_zero_point,
    weight_axis,
):
    destination[prefix + "weight_qscheme"] = weight_qscheme
    destination[prefix + "weight_dtype"] = weight_dtype
    if weight_qscheme is not None:
        destination[prefix + "weight_scale"] = weight_scale
        destination[prefix + "weight_zero_point"] = weight_zero_point
        if weight_qscheme == torch.per_channel_affine:
            destination[prefix + "weight_axis"] = weight_axis


def _get_weight_qparam_keys(state_dict: typing.Dict[str, typing.Any], prefix: str):
    keys = ["weight_qscheme", "weight_dtype"]
    weight_qscheme = state_dict[prefix + "weight_qscheme"]
    if weight_qscheme is not None:
        keys.append("weight_scale")
        keys.append("weight_zero_point")
        if weight_qscheme == torch.per_channel_affine:
            keys.append("weight_axis")
    return keys
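

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: a hypothetical
    # reference linear layer that mixes in ReferenceQuantizedModule so that
    # forward() runs against the fake-quantized (quantize -> dequantize)
    # weight. All names and parameter values below are made up for the example.
    class _DemoReferenceLinear(torch.nn.Linear, ReferenceQuantizedModule):
        def __init__(self, in_features, out_features, weight_qparams=None):
            super().__init__(in_features, out_features)
            self._init_weight_qparams(weight_qparams, device=None)

        def forward(self, x):
            # simulate quantized numerics by using the fake-quantized weight
            return torch.nn.functional.linear(x, self.get_weight(), self.bias)

    demo_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.qint8,
        "scale": 0.02,
        "zero_point": 0,
    }
    m = _DemoReferenceLinear(4, 2, weight_qparams=demo_qparams)
    y = m(torch.randn(1, 4))  # fp32 output with simulated int8 weight numerics
    print(y)
    print(m.get_quantized_weight())  # the underlying torch.qint8 weight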