xref: /aosp_15_r20/external/pytorch/torch/ao/nn/intrinsic/qat/modules/conv_fused.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3from typing import TypeVar
4
5import torch
6import torch.ao.nn.intrinsic as nni
7import torch.ao.nn.qat as nnqat
8import torch.nn as nn
9import torch.nn.functional as F
10from torch.nn import init
11from torch.nn.modules.utils import _pair, _single, _triple
12from torch.nn.parameter import Parameter
13from torch.nn.utils import fuse_conv_bn_weights
14
15
16__all__ = [
17    "ConvBn1d",
18    "ConvBnReLU1d",
19    "ConvReLU1d",
20    "ConvBn2d",
21    "ConvBnReLU2d",
22    "ConvReLU2d",
23    "ConvBn3d",
24    "ConvBnReLU3d",
25    "ConvReLU3d",
26    "update_bn_stats",
27    "freeze_bn_stats",
28]
29_BN_CLASS_MAP = {
30    1: nn.BatchNorm1d,
31    2: nn.BatchNorm2d,
32    3: nn.BatchNorm3d,
33}
34
35
36MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
37
38
39class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
40    _version = 2
41    _FLOAT_MODULE = MOD
42
43    def __init__(
44        self,
45        # ConvNd args
46        in_channels,
47        out_channels,
48        kernel_size,
49        stride,
50        padding,
51        dilation,
52        transposed,
53        output_padding,
54        groups,
55        bias,
56        padding_mode,
57        # BatchNormNd args
58        # num_features: out_channels
59        eps=1e-05,
60        momentum=0.1,
61        # affine: True
62        # track_running_stats: True
63        # Args for this module
64        freeze_bn=False,
65        qconfig=None,
66        dim=2,
67    ):
68        nn.modules.conv._ConvNd.__init__(
69            self,
70            in_channels,
71            out_channels,
72            kernel_size,
73            stride,
74            padding,
75            dilation,
76            transposed,
77            output_padding,
78            groups,
79            False,
80            padding_mode,
81        )
82        assert qconfig, "qconfig must be provided for QAT module"
83        self.qconfig = qconfig
84        self.freeze_bn = freeze_bn if self.training else True
85        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
86        self.weight_fake_quant = self.qconfig.weight()
87        if bias:
88            self.bias = Parameter(torch.empty(out_channels))
89        else:
90            self.register_parameter("bias", None)
91        self.reset_bn_parameters()
92
93        # this needs to be called after reset_bn_parameters,
94        # as they modify the same state
95        if self.training:
96            if freeze_bn:
97                self.freeze_bn_stats()
98            else:
99                self.update_bn_stats()
100        else:
101            self.freeze_bn_stats()
102
103        self._enable_slow_path_for_better_numerical_stability = False
104
105    def reset_running_stats(self):
106        self.bn.reset_running_stats()
107
108    def reset_bn_parameters(self):
109        self.bn.reset_running_stats()
110        init.uniform_(self.bn.weight)
111        init.zeros_(self.bn.bias)
112        # note: below is actually for conv, not BN
113        if self.bias is not None:
114            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
115            bound = 1 / math.sqrt(fan_in)
116            init.uniform_(self.bias, -bound, bound)
117
118    def reset_parameters(self):
119        super().reset_parameters()
120
121    def update_bn_stats(self):
122        self.freeze_bn = False
123        self.bn.training = True
124        return self
125
126    def freeze_bn_stats(self):
127        self.freeze_bn = True
128        self.bn.training = False
129        return self
130
131    def _forward(self, input):
132        if self._enable_slow_path_for_better_numerical_stability:
133            return self._forward_slow(input)
134        return self._forward_approximate(input)
135
136    def _forward_approximate(self, input):
137        """Approximated method to fuse conv and bn. It requires only one forward pass.
138        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
139        """
140        assert self.bn.running_var is not None
141        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
142        scale_factor = self.bn.weight / running_std
143        weight_shape = [1] * len(self.weight.shape)
144        weight_shape[0] = -1
145        bias_shape = [1] * len(self.weight.shape)
146        bias_shape[1] = -1
147        scaled_weight = self.weight_fake_quant(
148            self.weight * scale_factor.reshape(weight_shape)
149        )
150        # using zero bias here since the bias for original conv
151        # will be added later
152        if self.bias is not None:
153            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
154        else:
155            zero_bias = torch.zeros(
156                self.out_channels, device=scaled_weight.device, dtype=input.dtype
157            )
158        conv = self._conv_forward(input, scaled_weight, zero_bias)
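        # divide out the per-channel scale factor so that the batch norm below
        # operates on (approximately) the original, unscaled conv output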
159        conv_orig = conv / scale_factor.reshape(bias_shape)
160        if self.bias is not None:
161            conv_orig = conv_orig + self.bias.reshape(bias_shape)
162        conv = self.bn(conv_orig)
163        return conv
164
165    def _forward_slow(self, input):
166        """
167        A more accurate but slower method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf.
168        It requires two forward passes but handles the case bn.weight == 0.
169
170        Conv: Y = WX + B_c
171        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
172
173        Batch statistics:
174          mean_Y = Y.mean()
175                 = Y0.mean() + B_c
176          var_Y = (Y - mean_Y)^2.mean()
177                = (Y0 - Y0.mean())^2.mean()
178        BN (r: bn.weight, beta: bn.bias):
179          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
180            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
181
182        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
183          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
184            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
185
186        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
187          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
188
189        QAT with fused conv bn:
190          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
191                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
192          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
193        """
194
195        assert self.bn.running_var is not None
196        assert self.bn.running_mean is not None
197
198        # using zero bias here since the bias for original conv
199        # will be added later
200        zero_bias = torch.zeros(
201            self.out_channels, device=self.weight.device, dtype=input.dtype
202        )
203
204        weight_shape = [1] * len(self.weight.shape)
205        weight_shape[0] = -1
206        bias_shape = [1] * len(self.weight.shape)
207        bias_shape[1] = -1
208
209        if self.bn.training:
210            # needed to compute batch mean/std
211            conv_out = self._conv_forward(input, self.weight, zero_bias)
212            # update bn statistics
213            with torch.no_grad():
214                conv_out_bias = (
215                    conv_out
216                    if self.bias is None
217                    else conv_out + self.bias.reshape(bias_shape)
218                )
219                self.bn(conv_out_bias)
220
221        # fused conv + bn without bias using bn running statistics
222        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
223        scale_factor = self.bn.weight / running_std
224        scaled_weight = self.weight_fake_quant(
225            self.weight * scale_factor.reshape(weight_shape)
226        )
227        # fused conv without bias for inference: (r * W / running_std) * X
228        conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
229
230        if self.bn.training:
231            avg_dims = [0] + list(range(2, len(self.weight.shape)))
232            batch_mean = conv_out.mean(avg_dims)  # type: ignore[possibly-undefined]
233            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
234                avg_dims
235            )
236            batch_std = torch.sqrt(batch_var + self.bn.eps)
237
238            # scale to use batch std in training mode
239            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
240            unscale_factor = running_std / batch_std
241            conv_bn *= unscale_factor.reshape(bias_shape)
242
243            fused_mean = batch_mean
244            fused_std = batch_std
245        else:
246            fused_mean = self.bn.running_mean - (
247                self.bias if self.bias is not None else 0
248            )
249            fused_std = running_std
250
251        # fused bias = beta - r * mean / std
252        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
253        conv_bn += fused_bias.reshape(bias_shape)
254
255        # HACK to let conv bias participate in loss to avoid DDP error (parameters
256        #   were not used in producing loss)
257        if self.bias is not None:
258            conv_bn += (self.bias - self.bias).reshape(bias_shape)
259
260        return conv_bn
261
262    def extra_repr(self):
263        # TODO(jerryzh): extend
264        return super().extra_repr()
265
266    def forward(self, input):
267        return self._forward(input)
268
269    def train(self, mode=True):
270        """
271        BatchNorm's training behavior is controlled by the self.training flag, so prevent
272        changing it if BN is frozen. This makes sure that calling `model.train()`
273        on a model with a frozen BN will behave properly.
274        """
275        self.training = mode
276        if not self.freeze_bn:
277            for module in self.children():
278                module.train(mode)
279        return self
280
281    # ===== Serialization version history =====
282    #
283    # Version 1/None
284    #   self
285    #   |--- weight : Tensor
286    #   |--- bias : Tensor
287    #   |--- gamma : Tensor
288    #   |--- beta : Tensor
289    #   |--- running_mean : Tensor
290    #   |--- running_var : Tensor
291    #   |--- num_batches_tracked : Tensor
292    #
293    # Version 2
294    #   self
295    #   |--- weight : Tensor
296    #   |--- bias : Tensor
297    #   |--- bn : Module
298    #        |--- weight : Tensor (moved from v1.self.gamma)
299    #        |--- bias : Tensor (moved from v1.self.beta)
300    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
301    #        |--- running_var : Tensor (moved from v1.self.running_var)
302    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
303    def _load_from_state_dict(
304        self,
305        state_dict,
306        prefix,
307        local_metadata,
308        strict,
309        missing_keys,
310        unexpected_keys,
311        error_msgs,
312    ):
313        version = local_metadata.get("version", None)
314        if version is None or version == 1:
315            # BN related parameters and buffers were moved into the BN module for v2
316            v2_to_v1_names = {
317                "bn.weight": "gamma",
318                "bn.bias": "beta",
319                "bn.running_mean": "running_mean",
320                "bn.running_var": "running_var",
321                "bn.num_batches_tracked": "num_batches_tracked",
322            }
323            for v2_name, v1_name in v2_to_v1_names.items():
324                if prefix + v1_name in state_dict:
325                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
326                    state_dict.pop(prefix + v1_name)
327                elif prefix + v2_name in state_dict:
328                    # there was a brief period where forward compatibility
329                    # for this module was broken (between
330                    # https://github.com/pytorch/pytorch/pull/38478
331                    # and https://github.com/pytorch/pytorch/pull/38820)
332                    # and modules emitted the v2 state_dict format while
333                    # specifying that version == 1. This patches the forward
334                    # compatibility issue by allowing the v2 style entries to
335                    # be used.
336                    pass
337                elif strict:
338                    missing_keys.append(prefix + v2_name)
339
340        super()._load_from_state_dict(
341            state_dict,
342            prefix,
343            local_metadata,
344            strict,
345            missing_keys,
346            unexpected_keys,
347            error_msgs,
348        )
349
350    @classmethod
351    def from_float(cls, mod, use_precomputed_fake_quant=False):
352        r"""Create a qat module from a float module or qparams_dict
353
354        Args: `mod` a float module, either produced by torch.ao.quantization utilities
355        or directly from user
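
        Illustrative sketch (hedged; ``mod`` is normally produced by module fusion,
        and this classmethod is usually invoked for you by
        ``torch.ao.quantization.prepare_qat``)::

            >>> # xdoctest: +SKIP
            >>> float_mod = torch.ao.nn.intrinsic.ConvBn2d(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
            >>> float_mod.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
            >>> qat_mod = ConvBn2d.from_float(float_mod)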
356        """
357        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
358        # has no __name__ (code is fine though)
359        assert type(mod) == cls._FLOAT_MODULE, (
360            "qat."
361            + cls.__name__
362            + ".from_float only works for "
363            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
364        )
365        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
366        assert mod.qconfig, "Input float module must have a valid qconfig"
367        qconfig = mod.qconfig
368        conv, bn = mod[0], mod[1]
369        qat_convbn = cls(
370            conv.in_channels,
371            conv.out_channels,
372            conv.kernel_size,
373            conv.stride,
374            conv.padding,
375            conv.dilation,
376            conv.groups,
377            conv.bias is not None,
378            conv.padding_mode,
379            bn.eps,
380            bn.momentum,
381            False,
382            qconfig,
383        )
384        qat_convbn.weight = conv.weight
385        qat_convbn.bias = conv.bias
386        qat_convbn.bn.weight = bn.weight
387        qat_convbn.bn.bias = bn.bias
388        qat_convbn.bn.running_mean = bn.running_mean
389        qat_convbn.bn.running_var = bn.running_var
390        # mypy error: Cannot determine type of 'num_batches_tracked'
391        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
392        return qat_convbn
393
394    def to_float(self):
395        cls = type(self)
396        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
397            self.in_channels,
398            self.out_channels,
399            self.kernel_size,
400            self.stride,
401            self.padding,
402            self.dilation,
403            self.groups,
404            self.bias is not None,
405            self.padding_mode,
406        )
407        conv.weight = torch.nn.Parameter(self.weight.detach())
408        if self.bias is not None:
409            conv.bias = torch.nn.Parameter(self.bias.detach())
410
411        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
412            # fuse bn into conv
413            assert self.bn.running_var is not None and self.bn.running_mean is not None
414            conv.weight, conv.bias = fuse_conv_bn_weights(
415                conv.weight,
416                conv.bias,
417                self.bn.running_mean,
418                self.bn.running_var,
419                self.bn.eps,
420                self.bn.weight,
421                self.bn.bias,
422            )
423
424        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
425            modules = []
426            modules.append(conv)
427            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
428            modules.append(relu)
429            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
430            conv_relu.train(self.training)
431            return conv_relu
432        else:
433            conv.train(self.training)
434            return conv
435
436
437class ConvBn1d(_ConvBnNd, nn.Conv1d):
438    r"""
439    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
440    attached with FakeQuantize modules for weight,
441    used in quantization aware training.
442
443    We combined the interface of :class:`torch.nn.Conv1d` and
444    :class:`torch.nn.BatchNorm1d`.
445
446    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
447    to default.
448
449    Attributes:
450        freeze_bn: whether the BatchNorm running statistics are frozen
451        weight_fake_quant: fake quant module for weight
452
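    A minimal usage sketch (illustrative only; assumes a default QAT qconfig, the
    ``fbgemm`` backend name below is just an example choice)::

        >>> # xdoctest: +SKIP
        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        >>> m = ConvBn1d(16, 32, 3, padding=1, qconfig=qconfig)
        >>> out = m(torch.randn(4, 16, 100))
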
453    """
454    _FLOAT_BN_MODULE = nn.BatchNorm1d
455    _FLOAT_RELU_MODULE: None = None
456    _FLOAT_MODULE = nni.ConvBn1d
457    _FLOAT_CONV_MODULE = nn.Conv1d
458
459    def __init__(
460        self,
461        # Conv1d args
462        in_channels,
463        out_channels,
464        kernel_size,
465        stride=1,
466        padding=0,
467        dilation=1,
468        groups=1,
469        bias=None,
470        padding_mode="zeros",
471        # BatchNorm1d args
472        # num_features: out_channels
473        eps=1e-05,
474        momentum=0.1,
475        # affine: True
476        # track_running_stats: True
477        # Args for this module
478        freeze_bn=False,
479        qconfig=None,
480    ):
481        kernel_size = _single(kernel_size)
482        stride = _single(stride)
483        padding = _single(padding)
484        dilation = _single(dilation)
485        _ConvBnNd.__init__(
486            self,
487            in_channels,
488            out_channels,
489            kernel_size,
490            stride,
491            padding,
492            dilation,
493            False,
494            _single(0),
495            groups,
496            bias,
497            padding_mode,
498            eps,
499            momentum,
500            freeze_bn,
501            qconfig,
502            dim=1,
503        )
504
505
506class ConvBnReLU1d(ConvBn1d):
507    r"""
508    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
509    attached with FakeQuantize modules for weight,
510    used in quantization aware training.
511
512    We combined the interface of :class:`torch.nn.Conv1d`,
513    :class:`torch.nn.BatchNorm1d`, and :class:`torch.nn.ReLU`.
514
515    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
516    default.
517
518    Attributes:
519        weight_fake_quant: fake quant module for weight
520
521    """
522    # base class defines _FLOAT_MODULE as "ConvBn1d"
523    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
524    _FLOAT_CONV_MODULE = nn.Conv1d
525    _FLOAT_BN_MODULE = nn.BatchNorm1d
526    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
527    # module class after fusing bn into conv
528    _FUSED_FLOAT_MODULE = nni.ConvReLU1d
529
530    def __init__(
531        self,
532        # Conv1d args
533        in_channels,
534        out_channels,
535        kernel_size,
536        stride=1,
537        padding=0,
538        dilation=1,
539        groups=1,
540        bias=None,
541        padding_mode="zeros",
542        # BatchNorm1d args
543        # num_features: out_channels
544        eps=1e-05,
545        momentum=0.1,
546        # affine: True
547        # track_running_stats: True
548        # Args for this module
549        freeze_bn=False,
550        qconfig=None,
551    ):
552        super().__init__(
553            in_channels,
554            out_channels,
555            kernel_size,
556            stride,
557            padding,
558            dilation,
559            groups,
560            bias,
561            padding_mode,
562            eps,
563            momentum,
564            freeze_bn,
565            qconfig,
566        )
567
568    def forward(self, input):
569        return F.relu(ConvBn1d._forward(self, input))
570
571    @classmethod
572    def from_float(cls, mod, use_precomputed_fake_quant=False):
573        return super().from_float(mod, use_precomputed_fake_quant)
574
575
576class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
577    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
578    FakeQuantize modules for weight for
579    quantization aware training.
580
581    We combined the interface of :class:`~torch.nn.Conv1d` and
582    :class:`~torch.nn.BatchNorm1d`.
583
584    Attributes:
585        weight_fake_quant: fake quant module for weight
586
587    """
588    _FLOAT_MODULE = nni.ConvReLU1d  # type: ignore[assignment]
589    _FLOAT_CONV_MODULE = nn.Conv1d
590    _FLOAT_BN_MODULE: None = None
591    _FLOAT_RELU_MODULE = nn.ReLU
592
593    def __init__(
594        self,
595        in_channels,
596        out_channels,
597        kernel_size,
598        stride=1,
599        padding=0,
600        dilation=1,
601        groups=1,
602        bias=True,
603        padding_mode="zeros",
604        qconfig=None,
605    ):
606        super().__init__(
607            in_channels,
608            out_channels,
609            kernel_size,
610            stride=stride,
611            padding=padding,
612            dilation=dilation,
613            groups=groups,
614            bias=bias,
615            padding_mode=padding_mode,
616            qconfig=qconfig,
617        )
618        assert qconfig, "qconfig must be provided for QAT module"
619        self.qconfig = qconfig
620        self.weight_fake_quant = self.qconfig.weight()
621
622    def forward(self, input):
623        return F.relu(
624            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
625        )
626
627    @classmethod
628    def from_float(cls, mod, use_precomputed_fake_quant=False):
629        return super().from_float(
630            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
631        )
632
633
634class ConvBn2d(_ConvBnNd, nn.Conv2d):
635    r"""
636    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
637    attached with FakeQuantize modules for weight,
638    used in quantization aware training.
639
640    We combined the interface of :class:`torch.nn.Conv2d` and
641    :class:`torch.nn.BatchNorm2d`.
642
643    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
644    to default.
645
646    Attributes:
647        freeze_bn: whether the BatchNorm running statistics are frozen
648        weight_fake_quant: fake quant module for weight
649
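    A minimal usage sketch (illustrative only; assumes a default QAT qconfig, the
    ``fbgemm`` backend name below is just an example choice)::

        >>> # xdoctest: +SKIP
        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        >>> m = ConvBn2d(3, 16, 3, padding=1, qconfig=qconfig)
        >>> out = m(torch.randn(1, 3, 32, 32))
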
650    """
651    _FLOAT_MODULE = nni.ConvBn2d
652    _FLOAT_CONV_MODULE = nn.Conv2d
653    _FLOAT_BN_MODULE = nn.BatchNorm2d
654    _FLOAT_RELU_MODULE: None = None
655
656    def __init__(
657        self,
658        # ConvNd args
659        in_channels,
660        out_channels,
661        kernel_size,
662        stride=1,
663        padding=0,
664        dilation=1,
665        groups=1,
666        bias=None,
667        padding_mode="zeros",
668        # BatchNorm2d args
669        # num_features: out_channels
670        eps=1e-05,
671        momentum=0.1,
672        # affine: True
673        # track_running_stats: True
674        # Args for this module
675        freeze_bn=False,
676        qconfig=None,
677    ):
678        kernel_size = _pair(kernel_size)
679        stride = _pair(stride)
680        padding = _pair(padding)
681        dilation = _pair(dilation)
682        _ConvBnNd.__init__(
683            self,
684            in_channels,
685            out_channels,
686            kernel_size,
687            stride,
688            padding,
689            dilation,
690            False,
691            _pair(0),
692            groups,
693            bias,
694            padding_mode,
695            eps,
696            momentum,
697            freeze_bn,
698            qconfig,
699            dim=2,
700        )
701
702
703class ConvBnReLU2d(ConvBn2d):
704    r"""
705    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
706    attached with FakeQuantize modules for weight,
707    used in quantization aware training.
708
709    We combined the interface of :class:`torch.nn.Conv2d`,
710    :class:`torch.nn.BatchNorm2d`, and :class:`torch.nn.ReLU`.
711
712    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
713    default.
714
715    Attributes:
716        weight_fake_quant: fake quant module for weight
717
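    A minimal usage sketch (illustrative only; assumes a default QAT qconfig)::

        >>> # xdoctest: +SKIP
        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        >>> m = ConvBnReLU2d(3, 16, 3, padding=1, qconfig=qconfig)
        >>> out = m(torch.randn(1, 3, 32, 32))  # conv -> bn -> relu, weight fake-quantized
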
718    """
719    # base class defines _FLOAT_MODULE as "ConvBn2d"
720    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
721    _FLOAT_CONV_MODULE = nn.Conv2d
722    _FLOAT_BN_MODULE = nn.BatchNorm2d
723    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
724    # module class after fusing bn into conv
725    _FUSED_FLOAT_MODULE = nni.ConvReLU2d
726
727    def __init__(
728        self,
729        # Conv2d args
730        in_channels,
731        out_channels,
732        kernel_size,
733        stride=1,
734        padding=0,
735        dilation=1,
736        groups=1,
737        bias=None,
738        padding_mode="zeros",
739        # BatchNorm2d args
740        # num_features: out_channels
741        eps=1e-05,
742        momentum=0.1,
743        # affine: True
744        # track_running_stats: True
745        # Args for this module
746        freeze_bn=False,
747        qconfig=None,
748    ):
749        super().__init__(
750            in_channels,
751            out_channels,
752            kernel_size,
753            stride,
754            padding,
755            dilation,
756            groups,
757            bias,
758            padding_mode,
759            eps,
760            momentum,
761            freeze_bn,
762            qconfig,
763        )
764
765    def forward(self, input):
766        return F.relu(ConvBn2d._forward(self, input))
767
768    @classmethod
769    def from_float(cls, mod, use_precomputed_fake_quant=False):
770        return super().from_float(mod, use_precomputed_fake_quant)
771
772
773class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
774    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
775    FakeQuantize modules for weight for
776    quantization aware training.
777
778    We combined the interface of :class:`~torch.nn.Conv2d` and
779    :class:`~torch.nn.BatchNorm2d`.
780
781    Attributes:
782        weight_fake_quant: fake quant module for weight
783
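    A minimal usage sketch (illustrative only; assumes a default QAT qconfig)::

        >>> # xdoctest: +SKIP
        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
        >>> m = ConvReLU2d(3, 16, 3, padding=1, qconfig=qconfig)
        >>> out = m(torch.randn(1, 3, 32, 32))
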
784    """
785    _FLOAT_MODULE = nni.ConvReLU2d  # type: ignore[assignment]
786    _FLOAT_CONV_MODULE = nn.Conv2d
787    _FLOAT_BN_MODULE: None = None
788    _FLOAT_RELU_MODULE = nn.ReLU
789
790    def __init__(
791        self,
792        in_channels,
793        out_channels,
794        kernel_size,
795        stride=1,
796        padding=0,
797        dilation=1,
798        groups=1,
799        bias=True,
800        padding_mode="zeros",
801        qconfig=None,
802    ):
803        super().__init__(
804            in_channels,
805            out_channels,
806            kernel_size,
807            stride=stride,
808            padding=padding,
809            dilation=dilation,
810            groups=groups,
811            bias=bias,
812            padding_mode=padding_mode,
813            qconfig=qconfig,
814        )
815        assert qconfig, "qconfig must be provided for QAT module"
816        self.qconfig = qconfig
817        self.weight_fake_quant = self.qconfig.weight()
818
819    def forward(self, input):
820        return F.relu(
821            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
822        )
823
824    @classmethod
825    def from_float(cls, mod, use_precomputed_fake_quant=False):
826        return super().from_float(
827            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
828        )
829
830
831class ConvBn3d(_ConvBnNd, nn.Conv3d):
832    r"""
833    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
834    attached with FakeQuantize modules for weight,
835    used in quantization aware training.
836
837    We combined the interface of :class:`torch.nn.Conv3d` and
838    :class:`torch.nn.BatchNorm3d`.
839
840    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
841    to default.
842
843    Attributes:
844        freeze_bn: whether the BatchNorm running statistics are frozen
845        weight_fake_quant: fake quant module for weight
846
847    """
848    _FLOAT_MODULE = nni.ConvBn3d
849    _FLOAT_CONV_MODULE = nn.Conv3d
850    _FLOAT_BN_MODULE = nn.BatchNorm3d
851    _FLOAT_RELU_MODULE: None = None
852
853    def __init__(
854        self,
855        # ConvNd args
856        in_channels,
857        out_channels,
858        kernel_size,
859        stride=1,
860        padding=0,
861        dilation=1,
862        groups=1,
863        bias=None,
864        padding_mode="zeros",
865        # BatchNorm3d args
866        # num_features: out_channels
867        eps=1e-05,
868        momentum=0.1,
869        # affine: True
870        # track_running_stats: True
871        # Args for this module
872        freeze_bn=False,
873        qconfig=None,
874    ):
875        kernel_size = _triple(kernel_size)
876        stride = _triple(stride)
877        padding = _triple(padding)
878        dilation = _triple(dilation)
879        _ConvBnNd.__init__(
880            self,
881            in_channels,
882            out_channels,
883            kernel_size,
884            stride,
885            padding,
886            dilation,
887            False,
888            _triple(0),
889            groups,
890            bias,
891            padding_mode,
892            eps,
893            momentum,
894            freeze_bn,
895            qconfig,
896            dim=3,
897        )
898
899
900class ConvBnReLU3d(ConvBn3d):
901    r"""
902    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
903    attached with FakeQuantize modules for weight,
904    used in quantization aware training.
905
906    We combined the interface of :class:`torch.nn.Conv3d`,
907    :class:`torch.nn.BatchNorm3d`, and :class:`torch.nn.ReLU`.
908
909    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
910    default.
911
912    Attributes:
913        weight_fake_quant: fake quant module for weight
914
915    """
916    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
917    _FLOAT_CONV_MODULE = nn.Conv3d
918    _FLOAT_BN_MODULE = nn.BatchNorm3d
919    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
920    # module class after fusing bn into conv
921    _FUSED_FLOAT_MODULE = nni.ConvReLU3d
922
923    def __init__(
924        self,
925        # Conv3d args
926        in_channels,
927        out_channels,
928        kernel_size,
929        stride=1,
930        padding=0,
931        dilation=1,
932        groups=1,
933        bias=None,
934        padding_mode="zeros",
935        # BatchNorm3d args
936        # num_features: out_channels
937        eps=1e-05,
938        momentum=0.1,
939        # affine: True
940        # track_running_stats: True
941        # Args for this module
942        freeze_bn=False,
943        qconfig=None,
944    ):
945        super().__init__(
946            in_channels,
947            out_channels,
948            kernel_size,
949            stride,
950            padding,
951            dilation,
952            groups,
953            bias,
954            padding_mode,
955            eps,
956            momentum,
957            freeze_bn,
958            qconfig,
959        )
960
961    def forward(self, input):
962        return F.relu(ConvBn3d._forward(self, input))
963
964    @classmethod
965    def from_float(cls, mod, use_precomputed_fake_quant=False):
966        return super().from_float(
967            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
968        )
969
970
971class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
972    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
973    FakeQuantize modules for weight for
974    quantization aware training.
975
976    We combined the interface of :class:`~torch.nn.Conv3d` and
977    :class:`~torch.nn.BatchNorm3d`.
978
979    Attributes:
980        weight_fake_quant: fake quant module for weight
981
982    """
983    _FLOAT_MODULE = nni.ConvReLU3d  # type: ignore[assignment]
984    _FLOAT_CONV_MODULE = nn.Conv3d
985    _FLOAT_BN_MODULE: None = None
986    _FLOAT_RELU_MODULE = nn.ReLU
987
988    def __init__(
989        self,
990        in_channels,
991        out_channels,
992        kernel_size,
993        stride=1,
994        padding=0,
995        dilation=1,
996        groups=1,
997        bias=True,
998        padding_mode="zeros",
999        qconfig=None,
1000    ):
1001        super().__init__(
1002            in_channels,
1003            out_channels,
1004            kernel_size,
1005            stride=stride,
1006            padding=padding,
1007            dilation=dilation,
1008            groups=groups,
1009            bias=bias,
1010            padding_mode=padding_mode,
1011            qconfig=qconfig,
1012        )
1013        assert qconfig, "qconfig must be provided for QAT module"
1014        self.qconfig = qconfig
1015        self.weight_fake_quant = self.qconfig.weight()
1016
1017    def forward(self, input):
1018        return F.relu(
1019            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
1020        )
1021
1022    @classmethod
1023    def from_float(cls, mod, use_precomputed_fake_quant=False):
1024        return super().from_float(
1025            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
1026        )
1027
1028
1029def update_bn_stats(mod):
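    """Re-enable BatchNorm statistics updates on a fused QAT conv-bn module.

    Modules of other types are left untouched, so this is typically applied to a
    whole model, e.g. ``model.apply(update_bn_stats)`` (usage sketch).
    """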
1030    if type(mod) in {
1031        ConvBnReLU1d,
1032        ConvBnReLU2d,
1033        ConvBnReLU3d,
1034        ConvBn1d,
1035        ConvBn2d,
1036        ConvBn3d,
1037    }:
1038        mod.update_bn_stats()
1039
1040
1041def freeze_bn_stats(mod):
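    """Freeze BatchNorm running statistics on a fused QAT conv-bn module.

    Modules of other types are left untouched, so this is typically applied to a
    whole model, e.g. ``model.apply(freeze_bn_stats)`` (usage sketch).
    """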
1042    if type(mod) in {
1043        ConvBnReLU1d,
1044        ConvBnReLU2d,
1045        ConvBnReLU3d,
1046        ConvBn1d,
1047        ConvBn2d,
1048        ConvBn3d,
1049    }:
1050        mod.freeze_bn_stats()
1051