// aten/src/ATen/native/AveragePool2d.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/Pool.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#endif

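// Structured-kernel meta functions: they validate the arguments, compute the
// output shape, and allocate the output tensor; the actual pooling work is done
// by the impl functions in at::native further down in this file.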
namespace at::meta {
using namespace ::at::native;

TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
(const Tensor& input,
 IntArrayRef kernel_size,
 IntArrayRef stride,
 IntArrayRef padding,
 bool ceil_mode,
 bool count_include_pad,
 std::optional<int64_t> divisor_override) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int64_t kH = kernel_size[0];
  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int64_t dH = stride.empty() ? kH : stride[0];
  const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int64_t padH = padding[0];
  const int64_t padW = padding.size() == 1 ? padH : padding[1];

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
    "divisor must not be zero");

  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  const int64_t outputHeight = pooling_output_shape<int64_t>(
      inputHeight, kH, padH, dH, 1, ceil_mode);
  const int64_t outputWidth =
      pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
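  // Roughly, pooling_output_shape (dilation is fixed at 1 here) computes
  //   out = floor((in + 2 * pad - kernel + (ceil_mode ? stride - 1 : 0)) / stride) + 1
  // and, in ceil mode, drops the last window if it would start at or beyond in + pad.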

  auto memory_format = input.suggest_memory_format();
  pool2d_shape_check(
      input,
      kH,
      kW,
      dH,
      dW,
      padH,
      padW,
      1,
      1,
      nInputPlane,
      inputHeight,
      inputWidth,
      outputHeight,
      outputWidth,
      memory_format);

  /* resize output */
  if (input.ndimension() == 3) {
    set_output_raw_strided(
        0,
        {nInputPlane,
         outputHeight,
         outputWidth},
        {},
        input.options());
  }
  else {
    set_output_raw_strided(
        0,
        {nbatch,
         nInputPlane,
         outputHeight,
         outputWidth},
        {},
        input.options().memory_format(memory_format));
  }

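  // The precomputed values are handed to the structured impl functions
  // (e.g. avg_pool2d_out_cpu below), so the kernels receive kH/kW/dH/dW/padH/padW
  // directly instead of re-parsing kernel_size/stride/padding.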
  return TORCH_PRECOMPUTE_STRUCT(avg_pool2d)().set_kH(kH).set_kW(kW).set_dH(dH).set_dW(dW).set_padH(padH).set_padW(padW);
}

TORCH_META_FUNC(avg_pool2d_backward) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override
) {
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
    "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must not be zero");

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3); // number of channels (or colors)
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);

  auto memory_format = input.suggest_memory_format();
  avg_pool2d_backward_shape_check(
    input,
    gradOutput_,
    nbatch,
    kH, kW, dH, dW, padH, padW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight, outputWidth,
    memory_format);

  /* resize output */
  set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(memory_format));
}

} // namespace at::meta

namespace at::native {

TORCH_IMPL_FUNC(avg_pool2d_out_cpu)
(const Tensor& input,
 int64_t kH,
 int64_t kW,
 int64_t dH,
 int64_t dW,
 int64_t padH,
 int64_t padW,
 bool ceil_mode,
 bool count_include_pad,
 std::optional<int64_t> divisor_override,
 const Tensor& output) {
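  // Note the kernel's argument order is width-first (kW, kH, dW, dH, padW, padH),
  // the reverse of the height-first naming used in the meta function above.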
  avg_pool2d_kernel(
      kCPU,
      output,
      input,
      kW,
      kH,
      dW,
      dH,
      padW,
      padH,
      count_include_pad,
      divisor_override);
}

TORCH_IMPL_FUNC(avg_pool2d_backward_out_cpu) (
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& gradInput
) {
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must not be zero");

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

  /* zero the gradient */
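  // (the backward kernel accumulates per-window contributions into gradInput,
  //  so it has to start from zero)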
  gradInput.zero_();

  avg_pool2d_backward_kernel(
      kCPU, gradInput, gradOutput,
      kW, kH, dW, dH, padW, padH,
      count_include_pad, divisor_override);
}

DEFINE_DISPATCH(avg_pool2d_kernel);
DEFINE_DISPATCH(avg_pool2d_backward_kernel);

} // namespace at::native
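
// Minimal usage sketch (not part of this file): calling the op through the
// ATen C++ API dispatches to the structured meta/impl functions defined above.
//
//   at::Tensor x = at::randn({1, 3, 8, 8});
//   // stride defaults to kernel_size when omitted
//   at::Tensor y = at::avg_pool2d(x, /*kernel_size=*/{2, 2});
//   at::Tensor z = at::avg_pool2d(x, {3, 3}, /*stride=*/{1, 1}, /*padding=*/{1, 1},
//                                 /*ceil_mode=*/false, /*count_include_pad=*/true);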