# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F


class LinearReLU(nnqat.Linear, nni._FusedModule):
    r"""
    A LinearReLU module fused from Linear and ReLU modules, with a
    FakeQuantize module attached to the weight, for use in
    quantization aware training.

    We adopt the same interface as :class:`torch.nn.Linear`.

    Similar to :class:`torch.ao.nn.intrinsic.LinearReLU`, but with the
    FakeQuantize module initialized to its default settings.

    Attributes:
        weight_fake_quant: fake quantize module for the weight

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.qat.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, qconfig=None):
        super().__init__(in_features, out_features, bias, qconfig)

    def forward(self, input):
        # Apply linear with the fake-quantized weight, then ReLU.
        return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # `mod` is expected to be an `nni.LinearReLU` with a valid qconfig
        # attached; the conversion itself is handled by nnqat.Linear.from_float.
        return super().from_float(mod, use_precomputed_fake_quant)

    def to_float(self):
        # Rebuild a float Linear from detached copies of the learned
        # parameters, then re-wrap it with ReLU as a fused intrinsic module.
        linear = torch.nn.Linear(
            self.in_features, self.out_features, self.bias is not None
        )
        linear.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            linear.bias = torch.nn.Parameter(self.bias.detach())
        relu = torch.nn.ReLU()
        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
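

if __name__ == "__main__":
    # A minimal round-trip sketch, not part of the module itself: fuse a float
    # Linear + ReLU pair, attach a default QAT qconfig (assuming the "fbgemm"
    # backend is available), convert to this QAT module, run a forward pass
    # with fake-quantized weights, and convert back to float.
    from torch.ao.quantization import get_default_qat_qconfig

    float_mod = nni.LinearReLU(torch.nn.Linear(20, 30), torch.nn.ReLU())
    float_mod.qconfig = get_default_qat_qconfig("fbgemm")

    qat_mod = LinearReLU.from_float(float_mod)
    output = qat_mod(torch.randn(128, 20))
    print(output.size())  # torch.Size([128, 30])

    restored = qat_mod.to_float()  # plain nni.LinearReLU of float modules
    print(type(restored))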