# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F


class LinearReLU(nnqat.Linear, nni._FusedModule):
    r"""
    A LinearReLU module fused from Linear and ReLU modules, attached with
    FakeQuantize modules for weight, used in quantization aware training.

    We adopt the same interface as :class:`torch.nn.Linear`.

    Similar to :class:`torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize
    modules initialized to default.

    Attributes:
        weight_fake_quant: fake quant module applied to ``weight`` in
            :meth:`forward`

    Examples::

        >>> # xdoctest: +SKIP
        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig()
        >>> m = torch.ao.nn.intrinsic.qat.LinearReLU(20, 30, qconfig=qconfig)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, qconfig=None):
        """Construct the fused QAT module.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: whether the underlying linear layer has an additive bias.
            qconfig: quantization config; forwarded unchanged to
                :class:`torch.ao.nn.qat.Linear`.
        """
        super().__init__(in_features, out_features, bias, qconfig)

    def forward(self, input):
        """Return ``relu(linear(input, weight_fake_quant(weight), bias))``.

        The weight goes through its fake-quant module before the matmul, so the
        forward pass sees quantization effects. The activation is then passed
        through ReLU.
        """
        return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Create a QAT ``LinearReLU`` from a float fused module.

        Args:
            mod: the float module to convert (the class declares
                ``nni.LinearReLU`` as its ``_FLOAT_MODULE``).
            use_precomputed_fake_quant: forwarded unchanged to the parent
                ``from_float``.

        Returns:
            A new instance of ``cls``, built by the parent implementation.
        """
        return super().from_float(mod, use_precomputed_fake_quant)

    def to_float(self):
        """Convert back to a float :class:`torch.ao.nn.intrinsic.LinearReLU`.

        The returned module's ``Linear`` holds parameters built from detached
        views of this module's ``weight`` and ``bias`` (no fake quantization is
        applied). The result is put in the same train/eval mode as ``self``.

        Returns:
            An ``nni.LinearReLU`` wrapping a ``torch.nn.Linear`` and a
            ``torch.nn.ReLU``.
        """
        linear = torch.nn.Linear(
            self.in_features, self.out_features, self.bias is not None
        )
        linear.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            linear.bias = torch.nn.Parameter(self.bias.detach())
        fused = nni.LinearReLU(linear, torch.nn.ReLU())
        # Freshly constructed modules default to training mode. Mirror ours so
        # converting an eval-mode QAT module does not silently flip it to train.
        fused.train(self.training)
        return fused