import torch.nn.functional as F
from torch import Tensor
from torch.nn.common_types import _size_any_t

from .module import Module


__all__ = ["Fold", "Unfold"]


class Fold(Module):
    r"""Combines an array of sliding local blocks into a large containing tensor.

    Consider a batched :attr:`input` tensor containing sliding local blocks,
    e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
    where :math:`N` is the batch dimension, :math:`C \times \prod(\text{kernel\_size})`
    is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})`
    spatial locations each containing a :math:`C`-channeled vector), and
    :math:`L` is the total number of blocks. (This is exactly the
    same specification as the output shape of :class:`~torch.nn.Unfold`.) This
    operation combines these local blocks into the large :attr:`output` tensor
    of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
    by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the
    arguments must satisfy

    .. math::
        L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`d` is over all spatial dimensions.
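
    For example, with :math:`\text{output\_size} = (4, 5)`, ``kernel_size=(2, 2)``,
    ``dilation=1``, ``padding=0`` and ``stride=1``, this gives
    :math:`L = (4 - 2 + 1) \times (5 - 2 + 1) = 12` blocks, matching the
    example at the end of this docstring.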

    * :attr:`output_size` describes the spatial shape of the large containing
      tensor of the sliding local blocks. It is useful to resolve the ambiguity
      when multiple input shapes map to the same number of sliding blocks, e.g.,
      with ``stride > 1``.

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.

    * :attr:`padding` controls the amount of implicit zero padding added on
      both sides of the input, for :attr:`padding` number of points in each
      dimension, before reshaping.

    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        output_size (int or tuple): the shape of the spatial dimensions of the
                                    output (i.e., ``output.sizes()[2:]``)
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in the
                                         input spatial dimensions. Default: 1

    * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,
      :attr:`padding` or :attr:`stride` is an int or a tuple of length 1, then
      their values will be replicated across all spatial dimensions.

    * For the case of two output spatial dimensions this operation is sometimes
      called ``col2im``.

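    For example, an ``int`` argument behaves like the corresponding tuple:

    >>> fold_int = nn.Fold(output_size=4, kernel_size=2)
    >>> fold_tup = nn.Fold(output_size=(4, 4), kernel_size=(2, 2))
    >>> input = torch.randn(1, 4, 9)
    >>> torch.equal(fold_int(input), fold_tup(input))
    True
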
    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, the
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to a constant divisor).
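
        As a concrete sketch of this identity with overlapping
        :math:`2 \times 2` blocks:

        >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
        >>> unfold = nn.Unfold(kernel_size=(2, 2))
        >>> input = torch.randn(1, 3, 4, 5)
        >>> divisor = fold(unfold(torch.ones_like(input)))
        >>> torch.allclose(fold(unfold(input)), divisor * input)
        True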

    .. warning::
        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.

    Shape:
        - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)`
        - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
          or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above

    Examples::

        >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
        >>> input = torch.randn(1, 3 * 2 * 2, 12)
        >>> output = fold(input)
        >>> output.size()
        torch.Size([1, 3, 4, 5])

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """

    __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"]
    output_size: _size_any_t
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        output_size: _size_any_t,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1,
    ) -> None:
        super().__init__()
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.fold(
            input,
            self.output_size,
            self.kernel_size,
            self.dilation,
            self.padding,
            self.stride,
        )

    def extra_repr(self) -> str:
        return (
            "output_size={output_size}, kernel_size={kernel_size}, "
            "dilation={dilation}, padding={padding}, stride={stride}".format(
                **self.__dict__
            )
        )


class Unfold(Module):
    r"""Extracts sliding local blocks from a batched input tensor.

    Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
    where :math:`N` is the batch dimension, :math:`C` is the channel dimension,
    and :math:`*` represents arbitrary spatial dimensions. This operation flattens
    each sliding :attr:`kernel_size`-sized block within the spatial dimensions
    of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output`
    tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where
    :math:`C \times \prod(\text{kernel\_size})` is the total number of values
    within each block (a block has :math:`\prod(\text{kernel\_size})` spatial
    locations each containing a :math:`C`-channeled vector), and :math:`L` is
    the total number of such blocks:

    .. math::
        L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`\text{spatial\_size}` is formed by the spatial dimensions
    of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial
    dimensions.
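
    For example, a :math:`3 \times 4` input with ``kernel_size=(2, 3)``,
    ``dilation=1``, ``padding=0`` and ``stride=1`` gives
    :math:`L = (3 - 2 + 1) \times (4 - 3 + 1) = 4` blocks, matching the
    example at the end of this docstring.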

    Therefore, indexing :attr:`output` at the last dimension (column dimension)
    gives all values within a certain block.
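
    For example, with ``kernel_size=(2, 3)`` the first output column can be
    reshaped back into the first block (a minimal sketch; the values within a
    column are laid out channel-first, as the convolution example below also
    relies on):

    >>> unfold = nn.Unfold(kernel_size=(2, 3))
    >>> input = torch.randn(1, 5, 3, 4)
    >>> patch = unfold(input)[0, :, 0].view(5, 2, 3)
    >>> torch.equal(patch, input[0, :, 0:2, 0:3])
    True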

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.

    * :attr:`padding` controls the amount of implicit zero padding added on
      both sides of the input, for :attr:`padding` number of points in each
      dimension, before reshaping.

    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in the input
                                         spatial dimensions. Default: 1

    * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or
      :attr:`stride` is an int or a tuple of length 1, their values will be
      replicated across all spatial dimensions.

    * For the case of two input spatial dimensions this operation is sometimes
      called ``im2col``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, the
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to a constant divisor).

    .. warning::
        Currently, only 4-D input tensors (batched image-like tensors) are
        supported.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above

    Examples::

        >>> unfold = nn.Unfold(kernel_size=(2, 3))
        >>> input = torch.randn(2, 5, 3, 4)
        >>> output = unfold(input)
        >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels)
        >>> # 4 blocks (2x3 kernels) in total in the 3x4 input
        >>> output.size()
        torch.Size([2, 30, 4])

        >>> # xdoctest: +IGNORE_WANT
        >>> # Convolution is equivalent to Unfold + Matrix Multiplication + Fold (or view to output shape)
        >>> inp = torch.randn(1, 3, 10, 12)
        >>> w = torch.randn(2, 3, 4, 5)
        >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
        >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
        >>> # or equivalently (and avoiding a copy),
        >>> # out = out_unf.view(1, 2, 7, 8)
        >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
        tensor(1.9073e-06)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """

    __constants__ = ["kernel_size", "dilation", "padding", "stride"]
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1,
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.unfold(
            input, self.kernel_size, self.dilation, self.padding, self.stride
        )

    def extra_repr(self) -> str:
        return (
            "kernel_size={kernel_size}, dilation={dilation}, padding={padding},"
            " stride={stride}".format(**self.__dict__)
        )