xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/modules/batchnorm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3import torch.ao.nn.intrinsic as nni
4
5
6__all__ = ["BatchNorm2d", "BatchNorm3d"]
7
8
9class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
10    def __init__(
11        self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None
12    ) -> None:
13        factory_kwargs = {"device": device, "dtype": dtype}
14        super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
15        self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
16        self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
17
18    @staticmethod
19    def from_float(cls, mod, use_precomputed_fake_quant=False):
20        activation_post_process = mod.activation_post_process
21        if type(mod) == cls._NNI_BN_RELU_MODULE:
22            mod = mod[0]
23        scale, zero_point = activation_post_process.calculate_qparams()
24        new_mod = cls(mod.num_features, mod.eps)
25        new_mod.weight = mod.weight
26        new_mod.bias = mod.bias
27        new_mod.running_mean = mod.running_mean
28        new_mod.running_var = mod.running_var
29        new_mod.scale = scale
30        new_mod.zero_point = zero_point
31        return new_mod
32
33    @classmethod
34    def from_reference(cls, bn, output_scale, output_zero_point):
35        qbn = cls(
36            bn.num_features,
37            bn.eps,
38            bn.momentum,
39            device=bn.weight.device,
40            dtype=bn.weight.dtype,
41        )
42        qbn.weight = bn.weight
43        qbn.bias = bn.bias
44        qbn.running_mean = bn.running_mean
45        qbn.running_var = bn.running_var
46        qbn.scale = output_scale
47        qbn.zero_point = output_zero_point
48        return qbn
49
50
51class BatchNorm2d(_BatchNorm):
52    r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`."""
53
54    _NNI_BN_RELU_MODULE = nni.BNReLU2d
55
56    def __init__(
57        self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None
58    ) -> None:
59        factory_kwargs = {"device": device, "dtype": dtype}
60        super().__init__(num_features, eps, momentum, **factory_kwargs)
61
62    def _get_name(self):
63        return "QuantizedBatchNorm2d"
64
65    def _check_input_dim(self, input):
66        # Temporarily using len(shape) instead of ndim due to JIT issue
67        # https://github.com/pytorch/pytorch/issues/23890
68        if len(input.shape) != 4:
69            raise ValueError("Input shape must be `(N, C, H, W)`!")
70
71    def forward(self, input: torch.Tensor) -> torch.Tensor:
72        # disabling this since this is not symbolically traceable
73        # self._check_input_dim(input)
74        return torch.ops.quantized.batch_norm2d(
75            input,
76            self.weight,
77            self.bias,
78            self.running_mean,
79            self.running_var,
80            self.eps,
81            self.scale,
82            self.zero_point,
83        )
84
85    @classmethod
86    def from_float(cls, mod, use_precomputed_fake_quant=False):
87        return _BatchNorm.from_float(
88            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
89        )
90
91
92class BatchNorm3d(_BatchNorm):
93    r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`."""
94
95    _NNI_BN_RELU_MODULE = nni.BNReLU3d
96
97    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
98        factory_kwargs = {"device": device, "dtype": dtype}
99        super().__init__(num_features, eps, momentum, **factory_kwargs)
100
101    def _get_name(self):
102        return "QuantizedBatchNorm3d"
103
104    def _check_input_dim(self, input):
105        # Temporarily using len(shape) instead of ndim due to JIT issue
106        # https://github.com/pytorch/pytorch/issues/23890
107        if len(input.shape) != 5:
108            raise ValueError("Input shape must be `(N, C, H, W)`!")
109
110    def forward(self, input: torch.Tensor) -> torch.Tensor:
111        # disabling this since this is not symbolically traceable
112        # self._check_input_dim(input)
113        return torch.ops.quantized.batch_norm3d(
114            input,
115            self.weight,
116            self.bias,
117            self.running_mean,
118            self.running_var,
119            self.eps,
120            self.scale,
121            self.zero_point,
122        )
123
124    @classmethod
125    def from_float(cls, mod, use_precomputed_fake_quant=False):
126        return _BatchNorm.from_float(
127            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
128        )
129