# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.fusion import fuse_linear_bn_weights


__all__ = [
    "LinearBn1d",
]


class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with a FakeQuantize module for the weight, used in quantization aware
    training.

    It combines the interfaces of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with the FakeQuantize module
    initialized to default.

    Attributes:
        freeze_bn: whether the BatchNorm running statistics are frozen, i.e.
            not updated during training
        weight_fake_quant: fake quant module for weight

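    A minimal usage sketch (the qconfig below is just one reasonable choice;
    any QAT qconfig whose ``weight()`` returns a fake-quant module works)::

        >>> import torch
        >>> from torch.ao.quantization import get_default_qat_qconfig
        >>> qconfig = get_default_qat_qconfig("fbgemm")
        >>> m = LinearBn1d(20, 30, qconfig=qconfig)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])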
    """

    def __init__(
        self,
        # Linear args
        in_features,
        out_features,
        bias=True,
        # BatchNorm1d args
        # num_features: out_features
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        super().reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # We have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)
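        #
        # In the code below, `fq` corresponds to self.weight_fake_quant,
        # `bn.w` to self.bn.weight, and `bn.running_std_from_prev_batch` to
        # the running std derived from self.bn.running_var.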

        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
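        # for a 2D linear weight of shape (out_features, in_features), this
        # yields weight_shape == [-1, 1] and bias_shape == [1, -1]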
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        BatchNorm's training behavior is controlled by the self.training flag.
        Prevent changing it if BN is frozen. This makes sure that calling
        `model.train()` on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args: ``mod``: a float module, either produced by torch.ao.quantization
        utilities or directly from the user
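
        A hedged sketch of typical usage (the module sizes and the qconfig
        choice below are illustrative only)::

            >>> import torch
            >>> import torch.ao.nn.intrinsic as nni
            >>> from torch.ao.quantization import get_default_qat_qconfig
            >>> float_mod = nni.LinearBn1d(torch.nn.Linear(20, 30),
            ...                            torch.nn.BatchNorm1d(30))
            >>> float_mod.qconfig = get_default_qat_qconfig("fbgemm")
            >>> qat_mod = LinearBn1d.from_float(float_mod)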
        """
        assert type(mod) == nni.LinearBn1d, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + nni.LinearBn1d.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        qconfig = mod.qconfig
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_linearbn.weight = linear.weight
        qat_linearbn.bias = linear.bias
        qat_linearbn.bn.weight = bn.weight
        qat_linearbn.bn.bias = bn.bias
        qat_linearbn.bn.running_mean = bn.running_mean
        qat_linearbn.bn.running_var = bn.running_var
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_linearbn

    def to_float(self):
        linear = torch.nn.Linear(self.in_features, self.out_features)
        assert self.bn.running_var is not None and self.bn.running_mean is not None
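        # fold the BatchNorm running statistics and affine parameters into the
        # linear weight and bias so the returned float module matches the
        # fused numerics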
        linear.weight, linear.bias = fuse_linear_bn_weights(
            self.weight,
            self.bias,
            self.bn.running_mean,
            self.bn.running_var,
            self.bn.eps,
            self.bn.weight,
            self.bn.bias,
        )
        return linear