# mypy: allow-untyped-defs
from typing import List

import torch
from torch.nn.parameter import Parameter


__all__: List[str] = []


class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
    r"""Generalized extension of the FakeQuantize module in fake_quantize.py.

    This extends the FakeQuantize module in fake_quantize.py to support more
    generalized lower-bit quantization and learning of the scale and zero point
    parameters through backpropagation.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module includes the following attributes to support quantization parameter learning.

    * :attr:`channel_len` defines the length of the channel when initializing scale and zero point
      for the per channel case.

    * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
      normalized by a constant proportional to the square root of the number of
      elements in the tensor. The related literature justifying the use of this particular constant
      can be found here: https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr:`static_enabled` defines the flag for using the observer's static estimation for
      scale and zero point.

    * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """

    def __init__(
        self,
        observer,
        quant_min=0,
        quant_max=255,
        scale=1.0,
        zero_point=0.0,
        channel_len=-1,
        use_grad_scaling=False,
        **observer_kwargs,
    ):
        super().__init__()
        assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
        self.quant_min = quant_min
        self.quant_max = quant_max
        # also pass quant_min and quant_max to observer
        observer_kwargs["quant_min"] = quant_min
        observer_kwargs["quant_max"] = quant_max
        self.use_grad_scaling = use_grad_scaling
        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert (
                isinstance(channel_len, int) and channel_len > 0
            ), "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert (
            torch.iinfo(self.activation_post_process.dtype).min <= quant_min
        ), "quant_min out of bound"
        assert (
            quant_max <= torch.iinfo(self.activation_post_process.dtype).max
        ), "quant_max out of bound"
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = (
            self.activation_post_process.ch_axis
            if hasattr(self.activation_post_process, "ch_axis")
            else -1
        )
        # Buffers that control forward() behavior; toggled through the
        # enable_*/toggle_* helpers below.
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8))

        # Infer the effective bit width from the quantization range.
        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enable parameter learning over static observer estimates.

        Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True).toggle_fake_quant(
            enabled=True
        ).toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        """Enable static estimates of quantization parameters.

        Enables static observer estimates and disables learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
            enabled=True
        ).toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        """Enable accumulation of data without updating quantization parameters.

        Enables the static observer to accumulate data from the input without
        updating the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
            enabled=False
        ).toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)  # type: ignore[operator]
        return self

    @torch.jit.export
    def enable_observer(self, enabled=True):
        self.toggle_observer_update(enabled)

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)  # type: ignore[operator]
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}")
        print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}")

    @torch.jit.export
    def calculate_qparams(self):
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
        scale = self.scale.detach()
        zero_point = (
            self.zero_point.detach()
            .round()
            .clamp(self.quant_min, self.quant_max)
            .long()
        )
        return scale, zero_point

    def forward(self, X):
        if self.static_enabled[0] == 1:  # type: ignore[index]
            # Static estimation: refresh scale/zero_point from the observer's statistics.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)
        else:
            self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

        if self.fake_quant_enabled[0] == 1:
            if self.qscheme in (
                torch.per_channel_symmetric,
                torch.per_tensor_symmetric,
            ):
                # Symmetric quantization pins the zero point at 0.
                self.zero_point.data.zero_()

            if self.use_grad_scaling:
                # Gradient scale factor from the paper linked in the class docstring.
                grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
            else:
                grad_factor = 1.0
            if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
                X = torch._fake_quantize_learnable_per_channel_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.ch_axis,
                    self.quant_min,
                    self.quant_max,
                    grad_factor,
                )
            else:
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.quant_min,
                    self.quant_max,
                    grad_factor,
                )

        return X
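

# Minimal usage sketch: drives a _LearnableFakeQuantize instance through a
# static calibration phase and then a gradient-based learning phase. The
# observer choice, dtype, tensor shapes, and this __main__ guard are
# illustrative assumptions, not part of the module's public API;
# dtype=torch.uint8 assumes an observer version that accepts plain integer
# dtypes.
if __name__ == "__main__":
    from torch.ao.quantization.observer import MovingAverageMinMaxObserver

    fq = _LearnableFakeQuantize(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        use_grad_scaling=True,
        dtype=torch.uint8,
    )

    # Phase 1: seed scale/zero_point from the observer's static estimates.
    fq.enable_static_estimate()
    fq(torch.randn(4, 8))

    # Phase 2: learn scale/zero_point through backpropagation.
    fq.enable_param_learning()
    out = fq(torch.randn(4, 8))
    out.sum().backward()
    print(fq.scale.grad, fq.zero_point.grad)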