# /aosp_15_r20/external/pytorch/torch/ao/nn/qat/modules/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import Tuple, TypeVar, Union

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple


__all__ = ["Conv1d", "Conv2d", "Conv3d"]

MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)


class _ConvNd(nn.modules.conv._ConvNd):
    _FLOAT_MODULE = MOD

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module.

        Args:
           `mod`: a float module, either produced by torch.ao.quantization
           utilities or created directly by the user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if issubclass(type(mod), _FusedModule):
            mod = mod[0]  # type: ignore[index]
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert the qat module to a floating point module.

        This works both for a single qat conv and for a fused qat conv-relu
        module.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        # fused conv-relu case: rebuild the float fused module from the plain conv and a fresh relu
        if issubclass(cls, _FusedModule):
            modules = [conv]
            assert hasattr(cls, "_FLOAT_RELU_MODULE")
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            fused = cls._FLOAT_MODULE(*modules)  # type: ignore[arg-type, attr-defined, operator]
            fused.train(self.training)
            return fused
        else:
            return conv


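# A minimal sketch of the from_float / to_float round trip defined above. It assumes
# the float conv already carries a qconfig (for example one produced by
# torch.ao.quantization.get_default_qat_qconfig); `qat_cls` stands for whichever
# concrete subclass below (Conv1d, Conv2d or Conv3d) matches the float module's
# type, and the helper name itself is illustrative only.
def _from_float_round_trip_sketch(qat_cls, float_conv):
    # Swap the float conv for its fake-quantized QAT counterpart.
    qat_conv = qat_cls.from_float(float_conv)
    # After simulated-quantization training, convert back to a plain float conv.
    return qat_conv.to_float()

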
class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`.

    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


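# A minimal usage sketch for the QAT Conv1d above. It assumes the default QAT
# qconfig from torch.ao.quantization (the "fbgemm" backend) is acceptable; the
# helper name and the tensor shapes are illustrative only.
def _example_conv1d_qat():
    import torch.ao.quantization as tq

    qconfig = tq.get_default_qat_qconfig("fbgemm")
    qat_conv = Conv1d(16, 33, kernel_size=3, stride=2, qconfig=qconfig)
    x = torch.randn(20, 16, 50)
    # The forward pass fake-quantizes the weight before running the convolution.
    return qat_conv(x)

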
class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv2d
    _FLOAT_CONV_MODULE = nn.Conv2d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


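# A minimal sketch of converting a float nn.Conv2d into the QAT Conv2d above via
# from_float. It assumes the default QAT qconfig from torch.ao.quantization; the
# helper name and the conv dimensions are illustrative only.
def _example_conv2d_from_float():
    import torch.ao.quantization as tq

    float_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    # from_float requires a qconfig on the incoming float module.
    float_conv.qconfig = tq.get_default_qat_qconfig("fbgemm")
    qat_conv = Conv2d.from_float(float_conv)
    # The float parameters are reused as-is; only weight fake quantization is added.
    assert qat_conv.weight is float_conv.weight
    return qat_conv

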
class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv3d
    _FLOAT_CONV_MODULE = nn.Conv3d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

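
# An end-to-end sketch under the usual eager-mode QAT workflow, where prepare_qat
# swaps float nn.Conv1d / nn.Conv2d / nn.Conv3d submodules for the QAT classes in
# this file. The model, its layers and the helper name are illustrative only.
def _example_prepare_qat_workflow():
    import torch.ao.quantization as tq

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
    model.train()  # prepare_qat expects a model in training mode
    model.qconfig = tq.get_default_qat_qconfig("fbgemm")
    tq.prepare_qat(model, inplace=True)
    # The first submodule is now the QAT Conv2d defined in this file.
    assert isinstance(model[0], Conv2d)
    return model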