# xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/dynamic/modules/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Dynamically quantized convolution modules."""
3
4import warnings
5
6import torch
7import torch.ao.nn.quantized as nnq
8import torch.nn as nn
9import torch.nn.functional as F
10from torch import Tensor
11from torch._ops import ops
12from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
13from torch.nn.common_types import _size_1_t
14from torch.nn.modules.utils import _pair, _single, _triple
15
16
17__all__ = [
18    "Conv1d",
19    "Conv2d",
20    "Conv3d",
21    "ConvTranspose1d",
22    "ConvTranspose2d",
23    "ConvTranspose3d",
24]
25
26
class Conv1d(nnq.Conv1d):
    r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d` and

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        # NOTE(review): `reduce_range` is accepted here but never read or
        # stored by this constructor; it is kept only for API compatibility.
        reduce_range=True,
    ):
        warnings.warn(
            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
        )
        # String padding (e.g. "same"/"valid") is forwarded untouched; scalar
        # padding is normalized to a 1-element tuple like the other args.
        if not isinstance(padding, str):
            padding = _single(padding)
        super().__init__(
            in_channels,
            out_channels,
            _single(kernel_size),
            _single(stride),
            padding,
            _single(dilation),
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _get_name(self):
        return "DynamicQuantizedConv1d"

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        # Use len(shape) rather than ndim so TorchScript can handle it, see
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Conv1d stores padding as (p, p); only the first element matters,
            # so reverse-repeat just (p,) for F.pad.
            pad_amounts = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, pad_amounts, mode=self.padding_mode)
        return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range)
106
107
class Conv2d(nnq.Conv2d):
    r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d` and

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    """
    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        # Expand scalar geometry arguments to 2-tuples before delegating to
        # the statically-quantized base class.
        super().__init__(
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            _pair(padding),
            _pair(dilation),
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _get_name(self):
        return "DynamicQuantizedConv2d"

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        # Use len(shape) rather than ndim so TorchScript can handle it, see
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied eagerly via F.pad; the
            # quantized op itself only supports zero padding.
            pad_amounts = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_amounts, mode=self.padding_mode)
        return ops.quantized.conv2d_dynamic(input, self._packed_params, reduce_range)
190
191
class Conv3d(nnq.Conv3d):
    r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d` and

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> output = m(input)

    """
    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        warnings.warn(
            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
        )
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        # Delegates to the base class's `_init` (not `__init__`) so the
        # base constructor's own setup path is bypassed.  The two extra
        # positional arguments — False and _triple(0) — are presumably
        # `transposed` and `output_padding`; TODO confirm against
        # nnq._ConvNd._init's signature.
        super()._init(
            in_channels,
            out_channels,
            _triple(kernel_size),
            _triple(stride),
            _triple(padding),
            _triple(dilation),
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _get_name(self):
        return "DynamicQuantizedConv3d"

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        # Use len(shape) rather than ndim so TorchScript can handle it, see
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            # Apply non-zero padding modes up front via F.pad.
            pad_amounts = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_amounts, mode=self.padding_mode)
        return ops.quantized.conv3d_dynamic(input, self._packed_params, reduce_range)
275
276
class ConvTranspose1d(nnq.ConvTranspose1d):
    r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv1d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose1d` for other attributes.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With square kernels and equal stride
        >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        warnings.warn(
            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
        )
        # All argument normalization happens in the statically-quantized base
        # class; this subclass only swaps in the dynamic forward below.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _get_name(self):
        return "DynamicQuantizedConvTranspose1d"

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        # Use len(shape) rather than ndim so TorchScript can handle it, see
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return torch.ops.quantized.conv_transpose1d_dynamic(
            input, self._packed_params, reduce_range
        )
357
358
class ConvTranspose2d(nnq.ConvTranspose2d):
    r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv2d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With square kernels and equal stride
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        warnings.warn(
            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
        )
        # All argument normalization happens in the statically-quantized base
        # class; this subclass only swaps in the dynamic forward below.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _get_name(self):
        return "DynamicQuantizedConvTranspose2d"

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        # Use len(shape) rather than ndim so TorchScript can handle it, see
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d_dynamic(
            input, self._packed_params, reduce_range
        )
439
440
class ConvTranspose3d(nnq.ConvTranspose3d):
    r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv3d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=True,
        dilation=1,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        warnings.warn(
            f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended"  # noqa: B950
        )
        # All argument normalization happens in the statically-quantized base
        # class; this subclass only swaps in the dynamic forward below.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device=device,
            dtype=dtype,
        )

    def _get_name(self):
        return "DynamicQuantizedConvTranspose3d"

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        # Use len(shape) rather than ndim so TorchScript can handle it, see
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d_dynamic(
            input, self._packed_params, reduce_range
        )
521