# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.fusion import fuse_linear_bn_weights


__all__ = [
    "LinearBn1d",
]


class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with FakeQuantize modules for weight, used in quantization aware training.

    We combine the interface of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether the BatchNorm running statistics are frozen
            (i.e. not updated during training)
        weight_fake_quant: fake quant module for weight

    """

    def __init__(
        self,
        # Linear args
        in_features,
        out_features,
        bias=True,
        # BatchNorm1d args
        # num_features: out_features
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # in eval mode, BN statistics are always frozen
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        super().reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # We have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)

        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        # broadcast the per-channel scale over the weight's output dimension
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        # broadcast the per-channel scale over the feature dimension of the output
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        BatchNorm's training behavior relies on the self.training flag. Prevent
        changing it if BN is frozen, so that calling `model.train()` on a model
        with a frozen BN behaves properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
                utilities or directly from the user
        """
        assert type(mod) == nni.LinearBn1d, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + nni.LinearBn1d.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        qconfig = mod.qconfig
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_linearbn.weight = linear.weight
        qat_linearbn.bias = linear.bias
        qat_linearbn.bn.weight = bn.weight
        qat_linearbn.bn.bias = bn.bias
        qat_linearbn.bn.running_mean = bn.running_mean
        qat_linearbn.bn.running_var = bn.running_var
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_linearbn

    def to_float(self):
        linear = torch.nn.Linear(self.in_features, self.out_features)
        assert self.bn.running_var is not None and self.bn.running_mean is not None
        # fold the BN statistics and affine parameters back into the linear weights
        linear.weight, linear.bias = fuse_linear_bn_weights(
            self.weight,
            self.bias,
            self.bn.running_mean,
            self.bn.running_var,
            self.bn.eps,
            self.bn.weight,
            self.bn.bias,
        )
        return linear
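
# The following is an illustrative sketch, not part of this module's public API:
# it shows the typical conversion path from a fused float ``nni.LinearBn1d``
# (an nn.Linear followed by an nn.BatchNorm1d) into the QAT module defined
# above via ``from_float``. The helper name ``_example_from_float_usage`` is
# hypothetical, and the qconfig choice assumes
# ``torch.ao.quantization.get_default_qat_qconfig``; any QConfig that provides
# a weight fake-quant factory would work the same way.
def _example_from_float_usage():
    # deferred import so the example does not create an import cycle with
    # torch.ao.quantization at module import time
    from torch.ao.quantization import get_default_qat_qconfig

    # a float Linear + BatchNorm1d pair wrapped in the fused container
    float_linear = nn.Linear(16, 32)
    float_bn = nn.BatchNorm1d(32)
    fused = nni.LinearBn1d(float_linear, float_bn)

    # attach a QAT qconfig, then convert to the QAT module defined above
    fused.qconfig = get_default_qat_qconfig("fbgemm")
    qat_module = LinearBn1d.from_float(fused)

    # the forward pass applies the fake-quantized, BN-scaled linear followed
    # by BatchNorm1d on a (batch, in_features) input
    return qat_module(torch.randn(8, 16))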