# mypy: allow-untyped-defs
from warnings import warn

import torch


__all__ = [
    "ReLU6",
    "Hardswish",
    "ELU",
    "LeakyReLU",
    "Sigmoid",
    "Softmax",
    "MultiheadAttention",
    "PReLU",
]


class ReLU6(torch.nn.ReLU):
    r"""Applies the element-wise function:

    :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
    zero_point, and :math:`q(6)` is the quantized representation of the number 6.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.quantized.ReLU6()
        >>> input = torch.randn(2)
        >>> # xdoctest: +SKIP
        >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
        >>> output = m(input)
    """

    def __init__(self, inplace=False):
        super().__init__(inplace)
        self.inplace = inplace

    def forward(self, input):
        return torch.ops.quantized.relu6(input, self.inplace)

    def _get_name(self):
        return "QuantizedReLU6"

    @staticmethod
    def from_float(mod, use_precomputed_fake_quant=False):
        return ReLU6(mod.inplace)


class Hardswish(torch.nn.Hardswish):
    r"""This is the quantized version of :class:`~torch.nn.Hardswish`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """

    def __init__(self, scale, zero_point, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)

    def _get_name(self):
        return "QuantizedHardswish"

    @staticmethod
    def from_float(mod, use_precomputed_fake_quant=False):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Hardswish(float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point))


class ELU(torch.nn.ELU):
    r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        alpha: the alpha constant
    """

    def __init__(self, scale, zero_point, alpha=1.0):
        super().__init__(alpha)
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        return torch.ao.nn.quantized.functional.elu(
            input, self.scale, self.zero_point, self.alpha
        )

    def _get_name(self):
        return "QuantizedELU"

    @staticmethod
    def from_float(mod, use_precomputed_fake_quant=False):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return ELU(float(scale), int(zero_point), mod.alpha)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point), mod.alpha)
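
# A minimal usage sketch for the output-requantizing modules above (e.g. Hardswish,
# ELU): they consume an already quantized tensor and requantize the result with the
# scale/zero_point given at construction. The qparams and tensor values below are
# illustrative assumptions, not values required by these modules.
#
#     >>> # xdoctest: +SKIP
#     >>> x_q = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=64, dtype=torch.quint8)
#     >>> m = Hardswish(scale=0.1, zero_point=64)
#     >>> y_q = m(x_q)  # output is quantized with the module's scale/zero_point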


class LeakyReLU(torch.nn.LeakyReLU):
    r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
    """

    def __init__(
        self,
        scale: float,
        zero_point: int,
        negative_slope: float = 1e-2,
        inplace: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(negative_slope, inplace)
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.leaky_relu(
            input, self.negative_slope, self.inplace, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedLeakyReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)


class Sigmoid(torch.nn.Sigmoid):
    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.

    Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
    """

    def __init__(self, output_scale: float, output_zero_point: int):
        super().__init__()
        self.output_scale = output_scale
        self.output_zero_point = output_zero_point

    def forward(self, input):
        return torch.ops.quantized.sigmoid(
            input, self.output_scale, self.output_zero_point
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        (
            output_scale,
            output_zero_point,
        ) = mod.activation_post_process.calculate_qparams()
        return cls(float(output_scale), int(output_zero_point))
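
# The `from_float` classmethods above are the hook used by eager-mode conversion
# (torch.ao.quantization.convert): the float module is expected to carry a
# calibrated observer as `activation_post_process`, whose qparams become the
# output scale/zero_point. A rough sketch of that contract, with the observer
# attached by hand purely for illustration:
#
#     >>> # xdoctest: +SKIP
#     >>> float_mod = torch.nn.LeakyReLU(negative_slope=0.01)
#     >>> float_mod.activation_post_process = torch.ao.quantization.MinMaxObserver()
#     >>> float_mod.activation_post_process(float_mod(torch.randn(16)))  # calibrate
#     >>> qmod = LeakyReLU.from_float(float_mod)
#     >>> qmod._get_name()
#     'QuantizedLeakyReLU'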


class Softmax(torch.nn.Softmax):
    r"""This is the quantized version of :class:`~torch.nn.Softmax`.

    Args:
        dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """

    def __init__(self, dim=None, scale=1.0, zero_point=0):
        super().__init__()
        self.dim = dim
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        dim = self.dim
        if dim is None:
            stacklevel = 3
            # Note: adding the mypy ignore on _get_softmax_dim seems less bad
            # than making `_get_softmax_dim` an official API.
            dim = torch.nn.functional._get_softmax_dim(  # type: ignore[attr-defined]
                "softmax", input.dim(), stacklevel
            )
        return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point)

    def _get_name(self):
        return "QuantizedSoftmax"

    @staticmethod
    def from_float(mod, use_precomputed_fake_quant=False):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Softmax(mod.dim, float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(mod.dim, float(scale), int(zero_point))


class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
    r"""This is the quantized version of :class:`~torch.ao.nn.quantizable.MultiheadAttention`."""

    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

    def _get_name(self):
        return "QuantizedMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does observed -> quantized only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-observed MHA module. Please, see "
            "the examples on quantizable MHAs."
        )

    @classmethod
    def from_observed(cls, other):
        converted = torch.ao.quantization.convert(
            other,
            mapping=None,
            inplace=False,
            remove_qconfig=True,
            convert_custom_config_dict=None,
        )
        converted.__class__ = cls
        # Remove the parameters for bias_k and bias_v to quantize them
        # TODO: This is a potential source of accuracy drop.
        #       quantized cat takes the scale and zp of the first
        #       element, which might lose the precision in the bias_k
        #       and the bias_v (which are cat'ed with k/v being first).
        if converted.bias_k is not None:
            bias_k = converted._parameters.pop("bias_k")
            sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False)
            bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
            setattr(converted, "bias_k", bias_k)  # noqa: B010

        if converted.bias_v is not None:
            bias_v = converted._parameters.pop("bias_v")
            sc, zp = torch._choose_qparams_per_tensor(bias_v, reduce_range=False)
            bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
            setattr(converted, "bias_v", bias_v)  # noqa: B010

        del converted.in_proj_weight
        del converted.in_proj_bias

        return converted
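
# Conversion for MultiheadAttention is a two-step flow: float -> observed
# (torch.ao.nn.quantizable.MultiheadAttention) -> quantized (this class).
# `from_float` above deliberately raises, because only the observed -> quantized
# step is handled here, via `from_observed`. A rough sketch of that last step,
# assuming `observed_mha` is an already calibrated quantizable MHA with a
# qconfig attached (the variable name is an illustrative assumption):
#
#     >>> # xdoctest: +SKIP
#     >>> qmha = MultiheadAttention.from_observed(observed_mha)
#     >>> qmha._get_name()
#     'QuantizedMultiheadAttention'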


class PReLU(torch.nn.Module):
    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.

    Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
        num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
    """

    def __init__(
        self, output_scale: float, output_zero_point: int, num_parameters: int = 1
    ) -> None:
        super().__init__()
        self.num_parameters = num_parameters
        self.scale = output_scale
        self.zero_point = output_zero_point
        w = torch.randn(num_parameters, dtype=torch.float)
        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
        self.set_weight(qw)

    def set_weight(self, w: torch.Tensor) -> None:
        self.weight = w

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.prelu(
            input, self.weight, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedPReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(
                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
            )
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8
        )
        qprelu.set_weight(qweight)
        return qprelu

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(
                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
            )
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8
        )
        qprelu.set_weight(qweight)
        return qprelu
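
# A minimal standalone sketch of the PReLU module above, constructed by hand
# rather than through eager-mode convert. The qparams, weight value, and input
# below are illustrative assumptions.
#
#     >>> # xdoctest: +SKIP
#     >>> m = PReLU(output_scale=0.05, output_zero_point=0, num_parameters=1)
#     >>> w_q = torch.quantize_per_tensor(torch.tensor([0.25]), 0.01, 0, torch.quint8)
#     >>> m.set_weight(w_q)
#     >>> x_q = torch.quantize_per_tensor(torch.randn(4), 0.1, 64, torch.quint8)
#     >>> y_q = m(x_q)  # requantized with output_scale / output_zero_point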