# mypy: allow-untyped-defs
import math
from typing import TypeVar

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules.utils import _pair, _single, _triple
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights


__all__ = [
    "ConvBn1d",
    "ConvBnReLU1d",
    "ConvReLU1d",
    "ConvBn2d",
    "ConvBnReLU2d",
    "ConvReLU2d",
    "ConvBn3d",
    "ConvBnReLU3d",
    "ConvReLU3d",
    "update_bn_stats",
    "freeze_bn_stats",
]
_BN_CLASS_MAP = {
    1: nn.BatchNorm1d,
    2: nn.BatchNorm2d,
    3: nn.BatchNorm3d,
}


MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)


class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
    _version = 2
    _FLOAT_MODULE = MOD

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        # BatchNormNd args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
        dim=2,
    ):
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            False,
            padding_mode,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

        self._enable_slow_path_for_better_numerical_stability = False

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super().reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        if self._enable_slow_path_for_better_numerical_stability:
            return self._forward_slow(input)
        return self._forward_approximate(input)

    def _forward_approximate(self, input):
        """Approximated method to fuse conv and bn. It requires only one forward pass.

        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
        """
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(
                self.out_channels, device=scaled_weight.device, dtype=input.dtype
            )
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv
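
    # NOTE: a brief sketch of the algebra behind the approximation above, in the
    # notation of the `_forward_slow` docstring below (gamma = bn.weight,
    # running_std = sqrt(running_var + eps), scale_factor = gamma / running_std):
    #
    #   conv(X, fake_quant(W * scale_factor)) / scale_factor  ~=  conv(X, W)
    #
    # so dividing the scaled conv output by scale_factor and re-adding the conv
    # bias recovers, up to fake-quantization error, the original conv output,
    # which is then passed through `self.bn` so that batch statistics are applied
    # and the running estimates keep updating, while the weight fake-quant
    # observes exactly the BN-scaled weights that fused inference will use.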

    def _forward_slow(self, input):
        """
        A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
        It requires two forward passes but handles the case bn.weight == 0

        Conv: Y = WX + B_c
        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c

        Batch statistics:
          mean_Y = Y.mean()
                 = Y0.mean() + B_c
          var_Y = (Y - mean_Y)^2.mean()
                = (Y0 - Y0.mean())^2.mean()
        BN (r: bn.weight, beta: bn.bias):
          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta

        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta

        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta

        QAT with fused conv bn:
          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
        """

        assert self.bn.running_var is not None
        assert self.bn.running_mean is not None

        # using zero bias here since the bias for original conv
        # will be added later
        zero_bias = torch.zeros(
            self.out_channels, device=self.weight.device, dtype=input.dtype
        )

        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1

        if self.bn.training:
            # needed to compute batch mean/std
            conv_out = self._conv_forward(input, self.weight, zero_bias)
            # update bn statistics
            with torch.no_grad():
                conv_out_bias = (
                    conv_out
                    if self.bias is None
                    else conv_out + self.bias.reshape(bias_shape)
                )
                self.bn(conv_out_bias)

        # fused conv + bn without bias using bn running statistics
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        # fused conv without bias for inference: (r * W / running_std) * X
        conv_bn = self._conv_forward(input, scaled_weight, zero_bias)

        if self.bn.training:
            avg_dims = [0] + list(range(2, len(self.weight.shape)))
            batch_mean = conv_out.mean(avg_dims)  # type: ignore[possibly-undefined]
            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
                avg_dims
            )
            batch_std = torch.sqrt(batch_var + self.bn.eps)

            # scale to use batch std in training mode
            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
            unscale_factor = running_std / batch_std
            conv_bn *= unscale_factor.reshape(bias_shape)

            fused_mean = batch_mean
            fused_std = batch_std
        else:
            fused_mean = self.bn.running_mean - (
                self.bias if self.bias is not None else 0
            )
            fused_std = running_std

        # fused bias = beta - r * mean / std
        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
        conv_bn += fused_bias.reshape(bias_shape)

        # HACK to let conv bias participate in loss to avoid DDP error (parameters
        # were not used in producing loss)
        if self.bias is not None:
            conv_bn += (self.bias - self.bias).reshape(bias_shape)

        return conv_bn

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super().extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        BatchNorm's training behavior is controlled by the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                "bn.weight": "gamma",
                "bn.bias": "beta",
                "bn.running_mean": "running_mean",
                "bn.running_var": "running_var",
                "bn.num_batches_tracked": "num_batches_tracked",
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization utilities
        or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            assert self.bn.running_var is not None and self.bn.running_mean is not None
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
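

# Illustrative sketch (not part of this module's API): in eager-mode QAT these
# fused modules are normally reached through `from_float` during preparation,
# rather than constructed directly. Assuming a float model whose submodules
# include adjacent `conv`, `bn` and `relu` children, the typical flow is:
#
#     float_model.train()
#     float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
#     fused = torch.ao.quantization.fuse_modules_qat(
#         float_model, [["conv", "bn", "relu"]]
#     )
#     qat_model = torch.ao.quantization.prepare_qat(fused)
#     # nni.ConvBnReLU2d submodules are swapped for the QAT classes defined below
#
# `float_model`, `conv`, `bn` and `relu` above are placeholder names.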


class ConvBn1d(_ConvBnNd, nn.Conv1d):
    r"""
    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv1d` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether BatchNorm statistics are frozen
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE: None = None
    _FLOAT_MODULE = nni.ConvBn1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(
        self,
        # Conv1d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm1d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _single(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=1,
        )


class ConvBnReLU1d(ConvBn1d):
    r"""
    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv1d` and
    :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    # base class defines _FLOAT_MODULE as "ConvBn1d"
    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv1d
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = nni.ConvReLU1d

    def __init__(
        self,
        # Conv1d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm1d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(ConvBn1d._forward(self, input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)


class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv1d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU1d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv1d
    _FLOAT_BN_MODULE: None = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class ConvBn2d(_ConvBnNd, nn.Conv2d):
    r"""
    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d`.

    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether BatchNorm statistics are frozen
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBn2d
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE = nn.BatchNorm2d
    _FLOAT_RELU_MODULE: None = None

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm2d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=2,
        )


class ConvBnReLU2d(ConvBn2d):
    r"""
    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    # base class defines _FLOAT_MODULE as "ConvBn2d"
    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE = nn.BatchNorm2d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = nni.ConvReLU2d

    def __init__(
        self,
        # Conv2d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm2d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(ConvBn2d._forward(self, input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)


class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv2d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE: None = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class ConvBn3d(_ConvBnNd, nn.Conv3d):
    r"""
    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d`.

    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether BatchNorm statistics are frozen
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBn3d
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = nn.BatchNorm3d
    _FLOAT_RELU_MODULE: None = None

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=3,
        )


class ConvBnReLU3d(ConvBn3d):
    r"""
    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = nn.BatchNorm3d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE = nni.ConvReLU3d

    def __init__(
        self,
        # Conv3d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(ConvBn3d._forward(self, input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv3d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE: None = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


def update_bn_stats(mod):
    if type(mod) in {
        ConvBnReLU1d,
        ConvBnReLU2d,
        ConvBnReLU3d,
        ConvBn1d,
        ConvBn2d,
        ConvBn3d,
    }:
        mod.update_bn_stats()


def freeze_bn_stats(mod):
    if type(mod) in {
        ConvBnReLU1d,
        ConvBnReLU2d,
        ConvBnReLU3d,
        ConvBn1d,
        ConvBn2d,
        ConvBn3d,
    }:
        mod.freeze_bn_stats()
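

# Illustrative usage sketch (not part of this module): during the later epochs
# of quantization-aware training it is common to freeze the BatchNorm running
# statistics of all fused QAT modules at once via `nn.Module.apply`.
# `qat_model` below is a placeholder for a model prepared with
# `torch.ao.quantization.prepare_qat`:
#
#     import torch.ao.nn.intrinsic.qat as nniqat
#
#     qat_model.apply(nniqat.freeze_bn_stats)  # stop updating BN statistics
#     # ... a few more fine-tuning epochs ...
#     quantized = torch.ao.quantization.convert(qat_model.eval())
#
# `update_bn_stats` can be applied the same way to re-enable the updates.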