xref: /aosp_15_r20/external/pytorch/torch/nn/modules/upsampling.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Optional
3
4import torch.nn.functional as F
5from torch import Tensor
6from torch.nn.common_types import _ratio_2_t, _ratio_any_t, _size_2_t, _size_any_t
7
8from .module import Module
9
10
# Names exported by a wildcard import of this module.
__all__ = [
    "Upsample",
    "UpsamplingNearest2d",
    "UpsamplingBilinear2d",
]
12
13
class Upsample(Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.

    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.
    Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.

    The algorithms available for upsampling are nearest neighbor (3D, 4D or 5D
    input Tensor), linear (3D input Tensor), bilinear and bicubic (4D input
    Tensor), and trilinear (5D input Tensor).

    One can either give a :attr:`scale_factor` or the target output :attr:`size` to
    calculate the output size. (You cannot give both, as it is ambiguous)

    Args:
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
            output spatial sizes
        scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
            multiplier for spatial size. If it is a tuple (a list is also accepted and
            converted to a tuple), its length has to match the number of spatial
            dimensions of the input.
        mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
            ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
            Default: ``'nearest'``
        align_corners (bool, optional): if ``True``, the corner pixels of the input
            and output tensors are aligned, and thus preserving the values at
            those pixels. This only has effect when :attr:`mode` is
            ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``.
            Default: ``False``
        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
            interpolation calculation. If `recompute_scale_factor` is ``True``, then
            `scale_factor` must be passed in and `scale_factor` is used to compute the
            output `size`. The computed output `size` will be used to infer new scales for
            the interpolation. Note that when `scale_factor` is floating-point, it may differ
            from the recomputed `scale_factor` due to rounding and precision issues.
            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
            be used directly for interpolation.

    Shape:
        - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
          or :math:`(N, C, D_{out}, H_{out}, W_{out})`. When :attr:`size` is given,
          the output spatial sizes are :attr:`size`; when :attr:`scale_factor` is
          given (taken per dimension if it is a tuple), they are

    .. math::
        D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor

    .. math::
        H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor

    .. math::
        W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor

    .. warning::
        With ``align_corners = True``, the linearly interpolating modes
        (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
        align the output and input pixels, and thus the output values can depend
        on the input size. This was the default behavior for these modes up to
        version 0.3.1. Since then, the default behavior is
        ``align_corners = False``. See below for concrete examples on how this
        affects the outputs.

    .. note::
        If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.

    Examples::

        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
        >>> input
        tensor([[[[1., 2.],
                  [3., 4.]]]])

        >>> m = nn.Upsample(scale_factor=2, mode='nearest')
        >>> m(input)
        tensor([[[[1., 1., 2., 2.],
                  [1., 1., 2., 2.],
                  [3., 3., 4., 4.],
                  [3., 3., 4., 4.]]]])

        >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
        >>> m = nn.Upsample(scale_factor=2, mode='bilinear')  # align_corners=False
        >>> m(input)
        tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
                  [1.5000, 1.7500, 2.2500, 2.5000],
                  [2.5000, 2.7500, 3.2500, 3.5000],
                  [3.0000, 3.2500, 3.7500, 4.0000]]]])

        >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        >>> m(input)
        tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
                  [1.6667, 2.0000, 2.3333, 2.6667],
                  [2.3333, 2.6667, 3.0000, 3.3333],
                  [3.0000, 3.3333, 3.6667, 4.0000]]]])

        >>> # Try scaling the same data in a larger tensor
        >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
        >>> input_3x3[:, :, :2, :2].copy_(input)
        tensor([[[[1., 2.],
                  [3., 4.]]]])
        >>> input_3x3
        tensor([[[[1., 2., 0.],
                  [3., 4., 0.],
                  [0., 0., 0.]]]])

        >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session")
        >>> m = nn.Upsample(scale_factor=2, mode='bilinear')  # align_corners=False
        >>> # Notice that values in top left corner are the same with the small input (except at boundary)
        >>> m(input_3x3)
        tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
                  [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
                  [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
                  [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
                  [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
                  [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

        >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        >>> # Notice that values in top left corner are now changed
        >>> m(input_3x3)
        tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
                  [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
                  [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
                  [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
                  [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
                  [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
    """

    __constants__ = [
        "size",
        "scale_factor",
        "mode",
        "align_corners",
        "name",
        "recompute_scale_factor",
    ]
    name: str
    size: Optional[_size_any_t]
    scale_factor: Optional[_ratio_any_t]
    mode: str
    align_corners: Optional[bool]
    recompute_scale_factor: Optional[bool]

    def __init__(
        self,
        size: Optional[_size_any_t] = None,
        scale_factor: Optional[_ratio_any_t] = None,
        mode: str = "nearest",
        align_corners: Optional[bool] = None,
        recompute_scale_factor: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.name = type(self).__name__
        self.size = size
        if isinstance(scale_factor, (tuple, list)):
            # Per-dimension factors are normalized to an immutable tuple of
            # floats. Lists are accepted as well: the scalar branch below would
            # otherwise call ``float()`` on them and raise ``TypeError``.
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            # A falsy scalar (``None`` or ``0``) is stored as ``None``, i.e.
            # "no scale factor"; any other scalar is coerced to ``float``.
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor

    def forward(self, input: Tensor) -> Tensor:
        """Resample ``input`` with :func:`torch.nn.functional.interpolate`.

        The stored ``size``, ``scale_factor``, ``mode``, ``align_corners`` and
        ``recompute_scale_factor`` are forwarded unchanged.
        """
        return F.interpolate(
            input,
            self.size,
            self.scale_factor,
            self.mode,
            self.align_corners,
            recompute_scale_factor=self.recompute_scale_factor,
        )

    def __setstate__(self, state):
        # Pickled state that lacks ``recompute_scale_factor`` (presumably saved
        # before the attribute existed) is restored with it set to ``True``.
        if "recompute_scale_factor" not in state:
            state["recompute_scale_factor"] = True

        super().__setstate__(state)

    def extra_repr(self) -> str:
        """Describe the module: ``scale_factor`` if set, else ``size``, then ``mode``.

        ``align_corners`` and ``recompute_scale_factor`` are not included.
        """
        if self.scale_factor is not None:
            info = "scale_factor=" + repr(self.scale_factor)
        else:
            info = "size=" + repr(self.size)
        info += ", mode=" + repr(self.mode)
        return info
194
195
class UpsamplingNearest2d(Upsample):
    r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.

    To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
    as its constructor argument.

    When :attr:`size` is given, it is the output size of the image `(h, w)`.

    Args:
        size (int or Tuple[int, int], optional): output spatial sizes
        scale_factor (float or Tuple[float, float], optional): multiplier for
            spatial size.

    .. warning::
        This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
        equivalent to ``nn.functional.interpolate(..., mode='nearest')``.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

    .. math::
        H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor

    .. math::
        W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor

    Examples::

        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
        >>> input
        tensor([[[[1., 2.],
                  [3., 4.]]]])

        >>> m = nn.UpsamplingNearest2d(scale_factor=2)
        >>> m(input)
        tensor([[[[1., 1., 2., 2.],
                  [1., 1., 2., 2.],
                  [3., 3., 4., 4.],
                  [3., 3., 4., 4.]]]])
    """

    def __init__(
        self,
        size: Optional[_size_2_t] = None,
        scale_factor: Optional[_ratio_2_t] = None,
    ) -> None:
        # Only the mode is fixed here; ``align_corners`` and
        # ``recompute_scale_factor`` keep the base-class defaults.
        super().__init__(size=size, scale_factor=scale_factor, mode="nearest")
243
244
class UpsamplingBilinear2d(Upsample):
    r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels.

    To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
    as its constructor argument.

    When :attr:`size` is given, it is the output size of the image `(h, w)`.

    Args:
        size (int or Tuple[int, int], optional): output spatial sizes
        scale_factor (float or Tuple[float, float], optional): multiplier for
            spatial size.

    .. warning::
        This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
        equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where

    .. math::
        H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor

    .. math::
        W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor

    Examples::

        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
        >>> input
        tensor([[[[1., 2.],
                  [3., 4.]]]])

        >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
        >>> m = nn.UpsamplingBilinear2d(scale_factor=2)
        >>> m(input)
        tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
                  [1.6667, 2.0000, 2.3333, 2.6667],
                  [2.3333, 2.6667, 3.0000, 3.3333],
                  [3.0000, 3.3333, 3.6667, 4.0000]]]])
    """

    def __init__(
        self,
        size: Optional[_size_2_t] = None,
        scale_factor: Optional[_ratio_2_t] = None,
    ) -> None:
        # Corner alignment is always on for this legacy module, unlike the
        # ``align_corners`` default of the generic ``Upsample``.
        super().__init__(
            size=size,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=True,
        )
294