# torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
# mypy: allow-untyped-defs

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights


__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
]

_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
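# `F.pad` expects padding per spatial dim in reverse order, each value
# repeated twice, while conv modules store a single value per dim.  As an
# illustration (observed behavior of this private helper, not a documented
# contract): _reverse_repeat_padding((1, 2)) returns [2, 2, 1, 1].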


# TODO: factor out the common parts to ConvNd
class ConvReLU1d(nnq.Conv1d):
    r"""
    A ConvReLU1d module is a fused module of Conv1d and ReLU.

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

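    Example (a minimal usage sketch; assumes the input is already quantized
    and the active backend implements ``quantized::conv1d_relu``)::

        >>> # xdoctest: +SKIP
        >>> m = ConvReLU1d(16, 33, 3, stride=2)
        >>> x = torch.randn(20, 16, 50)
        >>> qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0,
        ...                                dtype=torch.quint8)
        >>> out = m(qx)  # conv1d and ReLU in one fused quantized op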
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Padding in quantized Conv1d is stored as (p, p); F.pad needs only (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
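            # Standard batch-norm folding: with s = bn.weight / sqrt(running_var + eps),
            # fuse_conv_bn_weights returns weights W * s (broadcast over output
            # channels) and bias (b - running_mean) * s + bn.bias, collapsing
            # Conv+BN into a single conv before quantization.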
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
        ), "BatchNorm1d should be fused into Conv1d before converting to reference module"
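        # ref_qconv is a fused (Sequential-style) module; index 0 is the conv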
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvReLU2d(nnq.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU.

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

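    Example (a minimal usage sketch; assumes the input is already quantized
    and the active backend implements ``quantized::conv2d_relu``)::

        >>> # xdoctest: +SKIP
        >>> m = ConvReLU2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> x = torch.randn(20, 16, 50, 100)
        >>> qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0,
        ...                                dtype=torch.quint8)
        >>> out = m(qx)  # conv2d and ReLU in one fused quantized op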
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv2d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
        ), "BatchNorm2d should be fused into Conv2d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)


class ConvReLU3d(nnq.Conv3d):
    r"""
    A ConvReLU3d module is a fused module of Conv3d and ReLU.

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv3d

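    Example (a minimal usage sketch; assumes the input is already quantized
    and the active backend implements ``quantized::conv3d_relu``)::

        >>> # xdoctest: +SKIP
        >>> m = ConvReLU3d(16, 33, (3, 5, 5), stride=(1, 2, 2))
        >>> x = torch.randn(20, 16, 10, 50, 100)
        >>> qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0,
        ...                                dtype=torch.quint8)
        >>> out = m(qx)  # conv3d and ReLU in one fused quantized op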
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
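        # Reject reflection padding up front instead of failing later in the
        # non-"zeros" F.pad path in forward.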
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv3d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
        ), "BatchNorm3d should be fused into Conv3d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)