#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/Pool.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#endif

namespace at::meta {
using namespace ::at::native;

TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
(const Tensor& input,
 IntArrayRef kernel_size,
 IntArrayRef stride,
 IntArrayRef padding,
 bool ceil_mode,
 bool count_include_pad,
 std::optional<int64_t> divisor_override) {
  // #20866, #22032: Guarantee this for the official C++ API?
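  // kernel_size, stride, and padding follow the nn.AvgPool2d convention:
  // a single value applies to both spatial dims, and an empty stride
  // defaults to the kernel size (non-overlapping windows).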
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
      "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int64_t kH = kernel_size[0];
  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
      "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int64_t dH = stride.empty() ? kH : stride[0];
  const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
      "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int64_t padH = padding[0];
  const int64_t padW = padding.size() == 1 ? padH : padding[1];

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
      "divisor must not be zero");

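  // Sizes: the input may be unbatched (C, H, W) or batched (N, C, H, W);
  // negative indexing lets the same code handle both.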
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

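  // With dilation fixed at 1, pooling_output_shape computes
  //   out = (in + 2 * pad - kernel) / stride + 1,
  // rounding up instead of down when ceil_mode is true (and clamping so
  // the last window starts inside the input or its padding).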
  const int64_t outputHeight = pooling_output_shape<int64_t>(
      inputHeight, kH, padH, dH, 1, ceil_mode);
  const int64_t outputWidth = pooling_output_shape<int64_t>(
      inputWidth, kW, padW, dW, 1, ceil_mode);

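  // suggest_memory_format() preserves channels-last (NHWC) inputs so the
  // output can be allocated in the same layout; pool2d_shape_check validates
  // the geometry (kernel, stride, padding vs. input size) and raises a
  // readable error on mismatch.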
  auto memory_format = input.suggest_memory_format();
  pool2d_shape_check(
      input,
      kH, kW,
      dH, dW,
      padH, padW,
      /*dilationH=*/1,
      /*dilationW=*/1,
      nInputPlane,
      inputHeight, inputWidth,
      outputHeight, outputWidth,
      memory_format);

  /* resize output */
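  // An unbatched (3-D) output cannot carry a channels-last layout, so the
  // suggested memory format is only applied in the batched (4-D) case.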
  if (input.ndimension() == 3) {
    set_output_raw_strided(
        0,
        {nInputPlane,
         outputHeight,
         outputWidth},
        {},
        input.options());
  }
  else {
    set_output_raw_strided(
        0,
        {nbatch,
         nInputPlane,
         outputHeight,
         outputWidth},
        {},
        input.options().memory_format(memory_format));
  }

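  // Hand the parsed scalars to the impl function through the precompute
  // struct, so the kernel does not re-parse the IntArrayRef arguments.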
  return TORCH_PRECOMPUTE_STRUCT(avg_pool2d)()
      .set_kH(kH)
      .set_kW(kW)
      .set_dH(dH)
      .set_dW(dW)
      .set_padH(padH)
      .set_padW(padW);
}

TORCH_META_FUNC(avg_pool2d_backward) (
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override
) {
  // #20866, #22032: Guarantee this for the official C++ API?
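  // Unlike the forward meta function, the backward path narrows the pooling
  // parameters to int; safe_downcast raises rather than silently truncating.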
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
      "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
      "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
      stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
      "avg_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
      "divisor must not be zero");

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3); // number of channels (or colors)
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);

  auto memory_format = input.suggest_memory_format();
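  // Recompute the expected forward output shape and verify that gradOutput_
  // matches it before sizing gradInput.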
  avg_pool2d_backward_shape_check(
      input,
      gradOutput_,
      nbatch,
      kH, kW, dH, dW, padH, padW,
      nInputPlane,
      inputHeight, inputWidth,
      outputHeight, outputWidth,
      memory_format);

  /* resize output */
  set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(memory_format));
}

} // namespace at::meta

namespace at::native {

TORCH_IMPL_FUNC(avg_pool2d_out_cpu)
(const Tensor& input,
 int64_t kH,
 int64_t kW,
 int64_t dH,
 int64_t dW,
 int64_t padH,
 int64_t padW,
 bool ceil_mode,
 bool count_include_pad,
 std::optional<int64_t> divisor_override,
 const Tensor& output) {
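  // Note the argument order: the dispatch stub takes width-first (kW, kH,
  // dW, dH, padW, padH), the reverse of the meta function's H-first order.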
  avg_pool2d_kernel(
      kCPU,
      output,
      input,
      kW,
      kH,
      dW,
      dH,
      padW,
      padH,
      count_include_pad,
      divisor_override);
}

TORCH_IMPL_FUNC(avg_pool2d_backward_out_cpu) (
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  bool ceil_mode,
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& gradInput
) {
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
      stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
      "divisor must not be zero");

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
      "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

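  // The backward kernel accumulates into gradInput (pooling windows can
  // overlap when stride < kernel), so it must start from zero.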
  /* zero the gradient */
  gradInput.zero_();

  avg_pool2d_backward_kernel(
      kCPU, gradInput, gradOutput,
      kW, kH, dW, dH, padW, padH,
      count_include_pad, divisor_override);
}

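// Dispatch stubs: the device-specific kernels are registered elsewhere via
// REGISTER_DISPATCH (the vectorized CPU implementations live under
// aten/src/ATen/native/cpu).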
DEFINE_DISPATCH(avg_pool2d_kernel);
DEFINE_DISPATCH(avg_pool2d_backward_kernel);

} // namespace at::native