xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/PixelShuffle.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

#include <limits>
3 
4 namespace at {
5 namespace native {
6 
check_pixel_shuffle_shapes(const Tensor & self,int64_t upscale_factor)7 inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
8   TORCH_CHECK(self.dim() >= 3,
9               "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
10               self.dim(), " dimension(s)");
11   TORCH_CHECK(upscale_factor > 0,
12               "pixel_shuffle expects a positive upscale_factor, but got ",
13               upscale_factor);
14   int64_t c = self.size(-3);
15   int64_t upscale_factor_squared = upscale_factor * upscale_factor;
16   TORCH_CHECK(c % upscale_factor_squared == 0,
17               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
18               "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
19 }
20 
check_pixel_unshuffle_shapes(const Tensor & self,int64_t downscale_factor)21 inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
22   TORCH_CHECK(
23       self.dim() >= 3,
24       "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
25       self.dim(),
26       " dimension(s)");
27   TORCH_CHECK(
28       downscale_factor > 0,
29       "pixel_unshuffle expects a positive downscale_factor, but got ",
30       downscale_factor);
31   int64_t h = self.size(-2);
32   int64_t w = self.size(-1);
33   TORCH_CHECK(
34       h % downscale_factor == 0,
35       "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
36       h,
37       " is not divisible by ",
38       downscale_factor);
39   TORCH_CHECK(
40       w % downscale_factor == 0,
41       "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
42       w,
43       " is not divisible by ",
44       downscale_factor);
45 }
46 
47 }} // namespace at::native
48