xref: /aosp_15_r20/external/pytorch/torch/nn/modules/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3
4import torch
5from torch import Tensor
6from torch.nn.parameter import Parameter, UninitializedParameter
7from .. import functional as F
8from .. import init
9from .lazy import LazyModuleMixin
10from .module import Module
11from .utils import _single, _pair, _triple, _reverse_repeat_tuple
12from torch._torch_docs import reproducibility_notes
13
14from ..common_types import _size_1_t, _size_2_t, _size_3_t
15from typing import Optional, List, Tuple, Union
16from typing_extensions import deprecated
17
18__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
19           'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d',
20           'LazyConvTranspose3d']
21
# Shared reStructuredText fragments for the convolution module docstrings.
# The keys are the placeholder names (``{groups_note}``,
# ``{depthwise_separable_note}``) that the ``__doc__`` strings below fill in
# through ``str.format(**convolution_notes)``.
convolution_notes = \
    {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters (of size
          :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""",

        "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also known as a "depthwise convolution".

        In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
        a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments
        :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""}  # noqa: B950
42
43
class _ConvNd(Module):
    """Common base class of the convolution modules defined in this file.

    The constructor validates the hyper-parameters, records them as
    attributes, allocates ``weight`` (and ``bias`` when requested) and
    initializes both through :meth:`reset_parameters`. ``_conv_forward`` is
    only a stub here; subclasses such as ``Conv1d``/``Conv2d``/``Conv3d``
    override it.
    """

    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size']
    __annotations__ = {'bias': Optional[torch.Tensor]}

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:  # type: ignore[empty-body]
        # Stub only: the concrete convolution is supplied by subclasses.
        ...

    in_channels: int
    _reversed_padding_repeated_twice: List[int]
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Union[str, Tuple[int, ...]]
    dilation: Tuple[int, ...]
    transposed: bool
    output_padding: Tuple[int, ...]
    groups: int
    padding_mode: str
    weight: Tensor
    bias: Optional[Tensor]

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, ...],
                 stride: Tuple[int, ...],
                 padding: Union[str, Tuple[int, ...]],
                 dilation: Tuple[int, ...],
                 transposed: bool,
                 output_padding: Tuple[int, ...],
                 groups: int,
                 bias: bool,
                 padding_mode: str,
                 device=None,
                 dtype=None) -> None:
        """Validate the arguments and create the ``weight``/``bias`` parameters.

        Args:
            in_channels: number of input channels; must be divisible by ``groups``.
            out_channels: number of output channels; must be divisible by ``groups``.
            kernel_size: per-spatial-dimension kernel size.
            stride: per-spatial-dimension stride.
            padding: per-spatial-dimension padding, or one of the strings
                ``'same'`` / ``'valid'``.
            dilation: per-spatial-dimension dilation.
            transposed: selects the weight layout
                ``(in_channels, out_channels // groups, *kernel_size)`` instead of
                ``(out_channels, in_channels // groups, *kernel_size)``.
            output_padding: stored as-is on the module.
            groups: positive number of groups.
            bias: whether to allocate a bias parameter of size ``out_channels``.
            padding_mode: ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
            device, dtype: forwarded to ``torch.empty`` for the parameters.

        Raises:
            ValueError: if ``groups`` is not positive, a channel count is not
                divisible by ``groups``, the padding string or padding mode is
                unknown, or ``padding='same'`` is combined with a stride != 1.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if groups <= 0:
            raise ValueError('groups must be a positive integer')
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_strings = {'same', 'valid'}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError(
                    f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}")
            if padding == 'same' and any(s != 1 for s in stride):
                raise ValueError("padding='same' is not supported for strided convolutions")

        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError(f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # `_reversed_padding_repeated_twice` is the padding to be passed to
        # `F.pad` if needed (e.g., for non-zero padding types that are
        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
        # reverse order than the dimension.
        if isinstance(self.padding, str):
            # 'valid' keeps the all-zero padding; 'same' fills it in below.
            self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
            if padding == 'same':
                # `i` counts down so that the first spatial dimension lands in
                # the last (before, after) slot pair of the list.
                for d, k, i in zip(dilation, kernel_size,
                                   range(len(kernel_size) - 1, -1, -1)):
                    total_padding = d * (k - 1)
                    # Split as evenly as possible; an odd remainder goes to
                    # the second (trailing) entry of the pair.
                    left_pad = total_padding // 2
                    self._reversed_padding_repeated_twice[2 * i] = left_pad
                    self._reversed_padding_repeated_twice[2 * i + 1] = (
                        total_padding - left_pad)
        else:
            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

        if transposed:
            self.weight = Parameter(torch.empty(
                (in_channels, out_channels // groups, *kernel_size), **factory_kwargs))
        else:
            self.weight = Parameter(torch.empty(
                (out_channels, in_channels // groups, *kernel_size), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize ``weight`` and, when present, ``bias`` in place."""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
        # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # A zero fan-in would divide by zero; the bias is left untouched then.
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        """Return the argument summary string, omitting values left at their defaults."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        # A string padding never equals a tuple of zeros, so it is always shown.
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        """Restore ``state``, defaulting ``padding_mode`` to ``'zeros'`` when it is absent."""
        super().__setstate__(state)
        # Presumably covers state saved before `padding_mode` existed.
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'
173
174
class Conv1d(_ConvNd):
    # NOTE: the three literals inside the parentheses below are formatted
    # together. ``.format`` binds tighter than ``+``, so without the
    # parentheses only the last literal would be formatted and the escaped
    # ``{{'valid', 'same'}}`` would be rendered with doubled braces. The first
    # and last literals contain raw LaTeX braces and must stay unformatted.
    __doc__ = r"""Applies a 1D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
    precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`L` is a length of signal sequence.
    """ + (r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a one-element tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.
""" + """
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.
""" + r"""
    {groups_note}

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to both sides of
            the input. Default: 0
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``

    """).format(**reproducibility_notes, **convolution_notes) + r"""

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                        \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels},
            \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
        bias (Tensor):   the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`

    Examples::

        >>> m = nn.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # we create new variables below to make mypy happy since kernel_size has
        # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        # String paddings ('same'/'valid') are passed through unchanged.
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        # Not transposed, and output_padding is fixed to zero.
        super().__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _single(0), groups, bias, padding_mode, **factory_kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Run ``F.conv1d`` with the given weight/bias and this module's settings."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are done as an explicit F.pad followed
            # by a convolution without padding.
            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with the module's own ``weight`` and ``bias``."""
        return self._conv_forward(input, self.weight, self.bias)
309
310
class Conv2d(_ConvNd):
    # NOTE: the three literals inside the parentheses below are formatted
    # together. ``.format`` binds tighter than ``+``, so without the
    # parentheses only the last literal would be formatted and the escaped
    # ``{{'valid', 'same'}}`` would be rendered with doubled braces. The first
    # and last literals contain raw LaTeX braces and must stay unformatted.
    __doc__ = r"""Applies a 2D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)


    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.
    """ + (r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the
      amount of implicit padding applied on both sides.
""" + """
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.
""" + r"""

    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of
            the input. Default: 0
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
    """).format(**reproducibility_notes, **convolution_notes) + r"""

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples:

        >>> # With square kernels and equal stride
        >>> m = nn.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Expand scalar arguments to 2-tuples; string paddings
        # ('same'/'valid') are passed through unchanged.
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        # Not transposed, and output_padding is fixed to zero.
        super().__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Run ``F.conv2d`` with the given weight/bias and this module's settings."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are done as an explicit F.pad followed
            # by a convolution without padding.
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with the module's own ``weight`` and ``bias``."""
        return self._conv_forward(input, self.weight, self.bias)
460
class Conv3d(_ConvNd):
    # NOTE: the three literals inside the parentheses below are formatted
    # together. ``.format`` binds tighter than ``+``, so without the
    # parentheses only the last literal would be formatted and the escaped
    # ``{{'valid', 'same'}}`` would be rendered with doubled braces. The first
    # and last literals contain raw LaTeX braces and must stay unformatted.
    __doc__ = r"""Applies a 3D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
    and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:

    .. math::
        out(N_i, C_{out_j}) = bias(C_{out_j}) +
                                \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)

    where :math:`\star` is the valid 3D `cross-correlation`_ operator
    """ + (r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.
""" + """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
""" + r"""

    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all six sides of
            the input. Default: 0
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
    """).format(**reproducibility_notes, **convolution_notes) + r"""

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`,
          where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                    \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                    \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
                    \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
                         then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Expand scalar arguments to 3-tuples; string paddings
        # ('same'/'valid') are passed through unchanged.
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        # Not transposed, and output_padding is fixed to zero.
        super().__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Run ``F.conv3d`` with the given weight/bias and this module's settings."""
        if self.padding_mode != "zeros":
            # Non-zero padding modes are done as an explicit F.pad followed
            # by a convolution without padding.
            return F.conv3d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _triple(0),
                self.dilation,
                self.groups,
            )
        return F.conv3d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with the module's own ``weight`` and ``bias``."""
        return self._conv_forward(input, self.weight, self.bias)
611
612
class _ConvTransposeNd(_ConvNd):
    # Base for transposed convolutions: restricts `padding_mode` to 'zeros'
    # and provides `_output_padding`, which turns a requested output size
    # into a per-dimension output padding.
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode, device=None, dtype=None) -> None:
        """Forward all arguments to ``_ConvNd`` after checking ``padding_mode``.

        Raises:
            ValueError: if ``padding_mode`` is anything other than ``'zeros'``.
        """
        if padding_mode != 'zeros':
            raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')

        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, transposed, output_padding,
            groups, bias, padding_mode, **factory_kwargs)

    # dilation being an optional parameter is for backwards
    # compatibility
    def _output_padding(self, input: Tensor, output_size: Optional[List[int]],
                        stride: List[int], padding: List[int], kernel_size: List[int],
                        num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]:
        """Compute the output padding that yields ``output_size``.

        Args:
            input: the tensor about to be convolved; a leading batch dimension
                is detected from ``input.dim() == num_spatial_dims + 2``.
            output_size: requested output size, either spatial dimensions only
                or including the non-spatial ones; ``None`` selects the
                module's own ``output_padding``.
            stride, padding, kernel_size: per-spatial-dimension settings.
            num_spatial_dims: number of spatial dimensions.
            dilation: per-spatial-dimension dilation; treated as 1 when ``None``.

        Returns:
            ``self.output_padding`` (passed through ``_single``) when
            ``output_size`` is ``None``, otherwise
            ``output_size[d] - min_size[d]`` for every spatial dimension.

        Raises:
            ValueError: if ``output_size`` has the wrong number of elements or
                lies outside ``[min_size, min_size + stride - 1]`` in some
                dimension.
        """
        if output_size is None:
            ret = _single(self.output_padding)  # converting to list if was not already
        else:
            has_batch_dim = input.dim() == num_spatial_dims + 2
            num_non_spatial_dims = 2 if has_batch_dim else 1
            # Accept a full-size specification by dropping its non-spatial part.
            if len(output_size) == num_non_spatial_dims + num_spatial_dims:
                output_size = output_size[num_non_spatial_dims:]
            if len(output_size) != num_spatial_dims:
                raise ValueError(
                    f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} "
                    f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})")

            min_sizes = torch.jit.annotate(List[int], [])
            max_sizes = torch.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                # Output size obtained with zero output padding.
                dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] -
                            2 * padding[d] +
                            (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1)
                min_sizes.append(dim_size)
                max_sizes.append(min_sizes[d] + stride[d] - 1)

            for i in range(len(output_size)):
                size = output_size[i]
                min_size = min_sizes[i]
                max_size = max_sizes[i]
                if size < min_size or size > max_size:
                    # Slice by the actual number of non-spatial dims so the
                    # reported input size is also right for unbatched input.
                    raise ValueError(
                        f"requested an output size of {output_size}, but valid sizes range "
                        f"from {min_sizes} to {max_sizes} (for an input of {input.size()[num_non_spatial_dims:]})")

            res = torch.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                res.append(output_size[d] - min_sizes[d])

            ret = res
        return ret
667
668
class ConvTranspose1d(_ConvTransposeNd):
    __doc__ = r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
""" + """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
""" + r"""
    {groups_note}

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(**reproducibility_notes, **convolution_notes) + r"""

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

          .. math::
              L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
                        \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels).
                         If :attr:`bias` is ``True``, then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`

    Examples::

        >>> m = nn.ConvTranspose1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)
        >>> output.size()
        torch.Size([20, 33, 101])

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        """Normalize the hyper-parameters to 1-tuples and delegate to the base class."""
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Each per-dimension argument may be given as a bare int; `_single`
        # turns it into its one-element form so the base class sees a uniform shape.
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)
        # NOTE(review): the positional ``True`` is presumably the base class's
        # ``transposed`` flag -- confirm against `_ConvTransposeNd.__init__`.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the transposed convolution to ``input``.

        Args:
            input: tensor to upsample, with or without a leading batch dimension.
            output_size: optional requested output size; when given, the output
                padding is derived from it by ``_output_padding`` instead of
                using the module's ``output_padding``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        # Narrow the type of `self.padding` to a tuple before it is passed on.
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]
        return F.conv_transpose1d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
798
799
class ConvTranspose2d(_ConvTransposeNd):
    __doc__ = r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv2d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
""" + """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
""" + r"""
    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimensions
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(**reproducibility_notes, **convolution_notes) + r"""

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels).
                         If :attr:`bias` is ``True``, then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_2_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        """Normalize the hyper-parameters to 2-tuples and delegate to the base class."""
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Each per-dimension argument may be a bare int or a pair; `_pair`
        # brings both spellings to the two-element (height, width) form.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        # NOTE(review): the positional ``True`` is presumably the base class's
        # ``transposed`` flag -- confirm against `_ConvTransposeNd.__init__`.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the transposed convolution to ``input``.

        Args:
            input: tensor to upsample, with or without a leading batch dimension.
            output_size: optional requested output size; when given, the output
                padding is derived from it by ``_output_padding`` instead of
                using the module's ``output_padding``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        # Narrow the type of `self.padding` to a tuple before it is passed on.
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        return F.conv_transpose2d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
953
954
class ConvTranspose3d(_ConvTransposeNd):
    __doc__ = r"""Applies a 3D transposed convolution operator over an input image composed of several input
    planes.
    The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
    and sums over the outputs from all input feature planes.

    This module can be seen as the gradient of Conv3d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
""" + """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
""" + r"""
    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(**reproducibility_notes, **convolution_notes) + r"""

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
          :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where

        .. math::
              D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2]
                        \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1


    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels).
                         If :attr:`bias` is ``True``, then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        """Normalize the hyper-parameters to 3-tuples and delegate to the base class."""
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Each per-dimension argument may be a bare int or a triple; `_triple`
        # brings both spellings to the three-element (depth, height, width) form.
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        # NOTE(review): the positional ``True`` is presumably the base class's
        # ``transposed`` flag -- confirm against `_ConvTransposeNd.__init__`.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the transposed convolution to ``input``.

        Args:
            input: tensor to upsample, with or without a leading batch dimension.
            output_size: optional requested output size; when given, the output
                padding is derived from it by ``_output_padding`` instead of
                using the module's ``output_padding``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        # Narrow the type of `self.padding` to a tuple before it is passed on.
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        return F.conv_transpose3d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
1105
1106
# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`.
#
# `_ConvTransposeMixin` was a mixin that has since been removed.  It was meant
# to be used with `_ConvNd` to construct actual module classes that implement
# conv transpose ops:
#
#   class MyConvTranspose(_ConvNd, _ConvTransposeMixin):
#       ...
#
# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper
# subclass of `_ConvNd`.  However, some user code in the wild still (incorrectly)
# uses the internal class `_ConvTransposeMixin`.  Hence, we provide this alias
# for BC, because it is cheap and easy for us to do so, even though
# `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as
# above would still work).
class _ConvTransposeMixin(_ConvTransposeNd):
    """Deprecated backward-compatibility alias for ``_ConvTransposeNd``.

    The class adds no behavior of its own: its constructor only forwards all
    arguments to the next class in the MRO and is wrapped in ``@deprecated``
    with ``category=FutureWarning``.
    """

    @deprecated(
        "`_ConvTransposeMixin` is a deprecated internal class. "
        "Please consider using public APIs.",
        category=FutureWarning,
    )
    def __init__(self, *args, **kwargs):
        # Cooperative `super()` call: keeps the multiple-inheritance usage
        # shown in the comment above this class working.
        super().__init__(*args, **kwargs)
1131
1132
1133# TODO: Conv2dLocal
1134# TODO: Conv2dMap
1135# TODO: ConvTranspose2dMap
1136
1137
class _LazyConvXdMixin(LazyModuleMixin):
    """Mixin implementing lazy ``in_channels`` inference for convolution modules.

    It is meant to be combined with a concrete convolution class. ``weight``
    and the optional ``bias`` start out as :class:`UninitializedParameter` and
    are materialized by :meth:`initialize_parameters`, once the number of input
    channels can be read from an actual input tensor.

    Subclasses must implement :meth:`_get_num_spatial_dims`.
    """

    groups: int
    transposed: bool
    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    weight: UninitializedParameter
    bias: UninitializedParameter

    def reset_parameters(self) -> None:
        """Re-initialize the parameters, but only once they actually exist.

        Nothing happens while any parameter is still uninitialized or while
        ``in_channels`` is still ``0``.
        """
        # has_uninitialized_params is defined in parent class and it is using a protocol on self
        if not self.has_uninitialized_params() and self.in_channels != 0:  # type: ignore[misc]
            # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined
            # in super class. Turns out that it is defined in _ConvND which is inherited by any class
            # that also inherits _LazyConvXdMixin
            super().reset_parameters()  # type: ignore[misc]

    # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin
    def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None:  # type: ignore[override]
        """Infer ``in_channels`` from ``input`` and materialize ``weight``/``bias``.

        Does nothing when no parameter is uninitialized.

        Args:
            input: the first input seen by the module; only its number of
                dimensions and its channel dimension are inspected.

        Raises:
            RuntimeError: if ``input`` has an unsupported number of dimensions.
            ValueError: if ``in_channels`` or ``out_channels`` is not divisible
                by ``groups``.
        """
        # defined by parent class but using a protocol
        if self.has_uninitialized_params():  # type: ignore[misc]
            self.in_channels = self._get_in_channels(input)
            if self.in_channels % self.groups != 0:
                raise ValueError('in_channels must be divisible by groups')
            # Without this check a transposed module would silently be built
            # with `groups * (out_channels // groups)` output channels, because
            # the weight shape below floors the division.
            if self.out_channels % self.groups != 0:
                raise ValueError('out_channels must be divisible by groups')
            assert isinstance(self.weight, UninitializedParameter)
            if self.transposed:
                self.weight.materialize((
                    self.in_channels, self.out_channels // self.groups, *self.kernel_size))
            else:
                self.weight.materialize((
                    self.out_channels, self.in_channels // self.groups, *self.kernel_size))
            if self.bias is not None:
                assert isinstance(self.bias, UninitializedParameter)
                self.bias.materialize((self.out_channels,))
            self.reset_parameters()

    # Function to extract in_channels from first input.
    def _get_in_channels(self, input: Tensor) -> int:
        """Return the size of the channel dimension of ``input``.

        The channel dimension is index 1 for batched input and index 0 for
        unbatched input.

        Raises:
            RuntimeError: if ``input`` is neither batched nor unbatched for the
                module's number of spatial dimensions.
        """
        num_spatial_dims = self._get_num_spatial_dims()
        num_dims_no_batch = num_spatial_dims + 1  # +1 for channels dim
        num_dims_batch = num_dims_no_batch + 1
        if input.dim() not in (num_dims_no_batch, num_dims_batch):
            raise RuntimeError(f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input "
                               f"to {self.__class__.__name__}, but "
                               f"got input of size: {input.shape}")
        return input.shape[1] if input.dim() == num_dims_batch else input.shape[0]

    # Function to return the number of spatial dims expected for inputs to the module.
    # This is expected to be implemented by subclasses.
    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions; subclasses must override."""
        raise NotImplementedError
1189
1190
# Conv1d defines weight as a Tensor but this lazy subclass redefines it as an UninitializedParameter
class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv1d` is inferred from the channel
    dimension of the first input: ``input.size(1)`` for batched input, or
    ``input.size(0)`` for unbatched input.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``

    Raises:
        ValueError: if ``out_channels`` is not divisible by ``groups``.

    .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # super class define this variable as None. "type: ignore[..] is required
    # since we are redefining the variable.
    cls_to_become = Conv1d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # in_channels and out_channels are passed as 0 placeholders; the real
        # out_channels is stored below and in_channels is inferred later.
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs
        )
        # The parent constructor only ever saw the placeholder 0, so the real
        # out_channels has not been checked against groups yet. Do it here,
        # after the parent has had the chance to reject an invalid `groups`.
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions handled by this module (1)."""
        return 1
1258
1259
# Conv2d defines weight as a Tensor but this lazy subclass redefines it as an UninitializedParameter
class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv2d` is inferred from the channel
    dimension of the first input: ``input.size(1)`` for batched input, or
    ``input.size(0)`` for unbatched input.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``

    Raises:
        ValueError: if ``out_channels`` is not divisible by ``groups``.

    .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # super class define this variable as None. "type: ignore[..] is required
    # since we are redefining the variable.
    cls_to_become = Conv2d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # in_channels and out_channels are passed as 0 placeholders; the real
        # out_channels is stored below and in_channels is inferred later.
        super().__init__(
            0,
            0,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs
        )
        # The parent constructor only ever saw the placeholder 0, so the real
        # out_channels has not been checked against groups yet. Do it here,
        # after the parent has had the chance to reject an invalid `groups`.
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions handled by this module (2)."""
        return 2
1327
1328
# Conv3d declares weight as a Tensor, but the derived LazyConv3d redefines it as an UninitializedParameter
class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv3d` is not given to the
    constructor; it is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The super class defines this variable as None, so a "type: ignore[...]"
    # is required because it is redefined here.
    cls_to_become = Conv3d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        tensor_opts = {'device': device, 'dtype': dtype}
        # Both channel counts are passed as 0 placeholders. bias is forced to
        # False so the parent does not create a tensor that would be
        # overwritten immediately afterwards.
        super().__init__(
            in_channels=0,
            out_channels=0,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            **tensor_opts
        )
        self.out_channels = out_channels
        # Replace the parent's parameters with uninitialized placeholders.
        self.weight = UninitializedParameter(**tensor_opts)
        if bias:
            self.bias = UninitializedParameter(**tensor_opts)

    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions of this convolution (3)."""
        return 3
1397
1398
# ConvTranspose1d declares weight as a Tensor, but the derived LazyConvTranspose1d redefines it as an UninitializedParameter
class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose1d` is not given
    to the constructor; it is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        padding_mode (str, optional): Forwarded unchanged to the parent
            constructor. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The super class defines this variable as None, so a "type: ignore[...]"
    # is required because it is redefined here.
    cls_to_become = ConvTranspose1d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        tensor_opts = {'device': device, 'dtype': dtype}
        # Both channel counts are passed as 0 placeholders. bias is forced to
        # False so the parent does not create a tensor that would be
        # overwritten immediately afterwards.
        super().__init__(
            in_channels=0,
            out_channels=0,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=False,
            dilation=dilation,
            padding_mode=padding_mode,
            **tensor_opts
        )
        self.out_channels = out_channels
        # Replace the parent's parameters with uninitialized placeholders.
        self.weight = UninitializedParameter(**tensor_opts)
        if bias:
            self.bias = UninitializedParameter(**tensor_opts)

    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions of this convolution (1)."""
        return 1
1466
1467
# ConvTranspose2d declares weight as a Tensor, but the derived LazyConvTranspose2d redefines it as an UninitializedParameter
class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        padding_mode (str, optional): Forwarded unchanged to the parent
            constructor. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The super class defines this variable as None, so a "type: ignore[...]"
    # is required because it is redefined here.
    cls_to_become = ConvTranspose2d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        # Annotated as "int or tuple" to agree with the docstring and with the
        # other size-like parameters of this constructor.
        dilation: _size_2_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Both channel counts are passed as 0 placeholders. bias is forced to
        # False so the parent does not create a tensor that would be
        # overwritten immediately afterwards.
        super().__init__(
            in_channels=0,
            out_channels=0,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=False,
            dilation=dilation,
            padding_mode=padding_mode,
            **factory_kwargs
        )
        # Replace the parent's parameters with uninitialized placeholders.
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions of this convolution (2)."""
        return 2
1535
1536
# ConvTranspose3d declares weight as a Tensor, but the derived LazyConvTranspose3d redefines it as an UninitializedParameter
class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose3d` is not given
    to the constructor; it is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        padding_mode (str, optional): Forwarded unchanged to the parent
            constructor. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The super class defines this variable as None, so a "type: ignore[...]"
    # is required because it is redefined here.
    cls_to_become = ConvTranspose3d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        tensor_opts = {'device': device, 'dtype': dtype}
        # Both channel counts are passed as 0 placeholders. bias is forced to
        # False so the parent does not create a tensor that would be
        # overwritten immediately afterwards.
        super().__init__(
            in_channels=0,
            out_channels=0,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=False,
            dilation=dilation,
            padding_mode=padding_mode,
            **tensor_opts
        )
        self.out_channels = out_channels
        # Replace the parent's parameters with uninitialized placeholders.
        self.weight = UninitializedParameter(**tensor_opts)
        if bias:
            self.bias = UninitializedParameter(**tensor_opts)

    def _get_num_spatial_dims(self) -> int:
        """Return the number of spatial dimensions of this convolution (3)."""
        return 3
1604