# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""Implements modules used to perform fake quantization."""

import re
from abc import ABC, abstractmethod
from typing import Any, Tuple

import torch
from torch.ao.quantization.observer import (
    _with_args,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    FixedQParamsObserver,
    HistogramObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)
from torch.nn import Module


__all__ = [
    "FakeQuantizeBase",
    "FakeQuantize",
    "FixedQParamsFakeQuantize",
    "FusedMovingAvgObsFakeQuantize",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "default_fake_quant",
    "default_weight_fake_quant",
    "default_dynamic_fake_quant",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_fake_quant",
    "default_per_channel_weight_fake_quant",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_histogram_fake_quant",
    "default_fused_act_fake_quant",
    "default_fused_wt_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
]


def _is_per_channel(qscheme: "torch.qscheme") -> bool:
    return qscheme in [
        torch.per_channel_symmetric,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]


def _is_per_tensor(qscheme: "torch.qscheme") -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]


def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]


def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
    return qscheme in [
        torch.per_channel_affine_float_qparams,
    ]


class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Any fake quantize implementation should derive from this class.

    Concrete fake quantize modules should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self) -> None:
        """Set fake_quant_enabled and observer_enabled."""
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
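        # A value of 1 means enabled and 0 means disabled; the flags are
        # toggled by the enable_*/disable_* methods below and consulted in
        # the forward of concrete subclasses.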
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    @classmethod
    def with_args(cls, **kwargs):
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public vs. private requirements
        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
        return fake_quant_constructor


class FakeQuantize(FakeQuantizeBase):
    r"""Simulate the quantize and dequantize operations at training time.

    The output of this module is given by::

        x_out = (
            clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale

    * :attr:`is_dynamic` indicates whether the fake quantize is a placeholder for dynamic quantization
      operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization;
      allowable values are torch.qint8 and torch.quint8.

    Args:

        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): User provided module that collects statistics on the input tensor and
            provides a method to calculate scale and zero-point.
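
    Example usage (an illustrative sketch; the observer and range shown here
    mirror ``default_fake_quant`` and are not the only valid choices)::

        fq = FakeQuantize(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
        )
        x = torch.randn(4, 4)
        out = fq(x)  # observe x, then fake-quantize it
        scale, zero_point = fq.calculate_qparams()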

    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        observer=MovingAverageMinMaxObserver,
        quant_min=None,
        quant_max=None,
        is_dynamic=False,
        **observer_kwargs,
    ):
        super().__init__()
        # Populate quant_min/quant_max to observer_kwargs if valid
        if quant_min is not None and quant_max is not None:
            assert (
                quant_min <= quant_max
            ), "quant_min must be less than or equal to quant_max"
            dtype = observer_kwargs.get("dtype", torch.quint8)
            if hasattr(observer, "p"):
                # In case observer is _PartialWrapper, dtype can be stored in
                # observer.p.keywords["dtype"]
                dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
                    "dtype", dtype
                )
            assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
            assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
            observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
        observer_kwargs["is_dynamic"] = is_dynamic
        self.activation_post_process = observer(**observer_kwargs)
        # TODO: keeping self.quant_min/max for BC; remove after a couple releases
        # Users should use self.activation_post_process.quant_min
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        self.is_dynamic = self.activation_post_process.is_dynamic
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float))
        self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = (
            self.activation_post_process.ch_axis
            if hasattr(self.activation_post_process, "ch_axis")
            else -1
        )
        assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), (
            "Only per channel and per tensor quantization are supported in fake quantize"
            + " got qscheme: "
            + str(self.qscheme)
        )
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
                self.zero_point.device
            )
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.ch_axis,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max,
                )
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max,
                )
        return X

    @torch.jit.export
    def extra_repr(self):
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}"
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "scale"] = self.scale
        destination[prefix + "zero_point"] = self.zero_point

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Removing this function throws an error that the size of the loaded tensor does not match the original size,
        # i.e., these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
        local_state = ["scale", "zero_point"]
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == "scale":
                    self.scale.resize_(val.shape)
                else:
                    assert name == "zero_point"
                    self.zero_point.resize_(val.shape)
                # For torchscript module we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
                if torch.jit.is_scripting():
                    if name == "scale":
                        self.scale.copy_(val)
                    else:
                        assert name == "zero_point"
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize at training time.

    Simulate quantize and dequantize with fixed quantization
    parameters at training time. Only per tensor quantization
    is supported.
    """

    # TODO: rename observer to observer_ctr
    def __init__(self, observer):
        super().__init__(observer=observer)
        assert (
            type(self.activation_post_process) == FixedQParamsObserver
        ), f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
        self._observer_ctr = observer
        self.scale = self.activation_post_process.scale
        self.zero_point = self.activation_post_process.zero_point
        assert _is_per_tensor(self.qscheme), (
            "Only per tensor quantization is supported in"
            + " FixedQParamsFakeQuantize module, got qscheme:"
            + str(self.qscheme)
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        """Define a string representation of the object's attributes."""
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"scale={self.scale}, zero_point={self.zero_point}, "
            f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, "
            f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}"
        )


class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Define a fused module to observe the tensor.

    Fused module that is used to observe the input tensor (compute min/max), compute
    scale/zero_point and fake_quantize the tensor.
    This module uses a calculation similar to MovingAverageMinMaxObserver for the inputs,
    to compute the min/max values in order to compute the scale/zero_point.
    The qscheme input in the observer is used to differentiate between symmetric/affine
    quantization schemes.

    The output of this module is given by::

        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    Similar to :class:`~torch.ao.quantization.FakeQuantize`; accepts the same attributes as the
    base class.
    """

    def __init__(
        self,
        observer: Any = MovingAverageMinMaxObserver,
        quant_min: int = 0,
        quant_max: int = 255,
        **observer_kwargs: Any,
    ) -> None:
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        assert isinstance(
            self.activation_post_process,
            (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
        ), "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
        self.is_symmetric_quant = _is_symmetric_quant(
            self.activation_post_process.qscheme
        )

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}"
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            self.activation_post_process.min_val,
            self.activation_post_process.max_val,
            self.scale,
            self.zero_point,
            self.activation_post_process.averaging_constant,
            self.activation_post_process.quant_min,
            self.activation_post_process.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )


default_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)
"""
Default fake_quant for activations.
"""

default_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)
"""
Default fake_quant for weights.
"""

default_dynamic_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    is_dynamic=True,
    dtype=torch.quint8,
    averaging_constant=1,
)
"""
Default dynamic fake_quant for activations.
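
Example usage (an illustrative sketch; assumes the standard QConfig workflow,
where the partial itself is passed and instantiated later by the prepare step)::

    qconfig = torch.ao.quantization.QConfig(
        activation=default_dynamic_fake_quant,
        weight=default_weight_fake_quant,
    )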
"""

default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args(
    observer=default_fixed_qparams_range_neg1to1_observer
)
default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args(
    observer=default_fixed_qparams_range_0to1_observer
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = (
    default_fixed_qparams_range_neg1to1_fake_quant
)
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant

default_per_channel_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0,
)
"""
Default fake_quant for per-channel weights.
"""
default_embedding_fake_quant = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    qscheme=torch.per_channel_affine_float_qparams,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
    ch_axis=0,
    averaging_constant=1,
)
"""
Default fake_quant for embeddings.
Observer is memoryless since averaging_constant is 1.
"""

default_embedding_fake_quant_4bit = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    qscheme=torch.per_channel_affine_float_qparams,
    ch_axis=0,
    dtype=torch.quint4x2,
    averaging_constant=1,
)

default_histogram_fake_quant = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)
"""
Fake_quant for activations using a histogram.
"""


default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
)
"""
Fused version of `default_fake_quant`, with improved performance.
"""


default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
)
"""
Fused version of `default_weight_fake_quant`, with improved performance.
"""

default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
)
"""
Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
"""

fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-127,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    eps=2**-12,
)
"""
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""

fused_per_channel_wt_fake_quant_range_neg_127_to_127 = (
    FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAveragePerChannelMinMaxObserver,
        quant_min=-127,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        eps=2**-12,
    )
)
"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""


def _is_fake_quant_script_module(mod):
    """Return true if given mod is an instance of FakeQuantize script module."""
    if isinstance(mod, torch.jit.RecursiveScriptModule):
        # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
        suffix = mod._c.qualified_name.split(".", 1)[1]
        name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
        return (
            name == "torch.ao.quantization.fake_quantize.FakeQuantize"
            or name
            == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize"
        )
    return False


def disable_fake_quant(mod):
    """Disable fake quantization for the module.

    Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    """Enable fake quantization for the module.

    Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_fake_quant()


def disable_observer(mod):
    """Disable observation for this module.

    Disable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_observer)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_observer()


def enable_observer(mod):
    """Enable observation for this module.

    Enable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_observer)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_observer()
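

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only; not part of the public API).
# It shows how the `default_fake_quant` partial above is instantiated, how a
# tensor is fake-quantized, and how the module-level enable/disable helpers
# are applied across a model; the toy Sequential model and the tensor shapes
# are arbitrary choices for this sketch.
if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    # Attach a fake-quant module by hand; in a real QAT flow this is done by
    # the prepare step based on a QConfig.
    model.fq = default_fake_quant()
    out = model.fq(model(torch.randn(2, 8)))  # observe, then fake-quantize
    print("scale/zero_point:", model.fq.calculate_qparams())
    # Toggle observation and fake quantization for every fake-quant submodule.
    model.apply(disable_observer)
    model.apply(enable_fake_quant)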