# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.dynamic as nnqd


__all__ = ["LinearReLU"]


class LinearReLU(nnqd.Linear):
    r"""
    A LinearReLU module fused from Linear and ReLU modules that can be used
    for dynamic quantization.
    Supports both FP16 and INT8 quantization.

    We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.dynamic.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = torch.ao.nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._packed_params.dtype == torch.qint8:
            # TODO check if we should set reduce_range = True by default here
            Y = torch.ops.quantized.linear_relu_dynamic(
                x, self._packed_params._packed_params, reduce_range=True
            )
        elif self._packed_params.dtype == torch.float16:
            # FP16 weight-only path: weights are stored in fp16, inputs stay float
            Y = torch.ops.quantized.linear_relu_dynamic_fp16(
                x, self._packed_params._packed_params
            )
        else:
            raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!")
        return Y.to(x.dtype)

    def _get_name(self):
        return "DynamicQuantizedLinearReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qlinear_relu):
        # the reference module is a fused (Linear, ReLU) sequence; the packed
        # weights come from its Linear submodule
        return super().from_reference(ref_qlinear_relu[0])
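

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API): it
# exercises the two dtypes this fused module accepts, mirroring the docstring
# example above. It assumes a backend with dynamic quantized linear kernels
# (e.g. fbgemm) is available; otherwise the ops below raise at runtime.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(128, 20)

    # INT8 dynamic quantization: activations are quantized on the fly,
    # weights are stored as qint8, and ReLU is fused into the linear kernel.
    m_int8 = LinearReLU(20, 30, bias=True, dtype=torch.qint8)
    print(m_int8(x).size())  # torch.Size([128, 30])

    # FP16 weight-only variant: weights are stored in fp16, inputs stay float.
    m_fp16 = LinearReLU(20, 30, bias=True, dtype=torch.float16)
    print(m_fp16(x).size())  # torch.Size([128, 30])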