1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <c10/util/irange.h>
6 #include <tuple>
7
8 #include <ATen/native/AdaptivePooling.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/NativeFunctions.h>
12 #else
13 #include <ATen/ops/adaptive_max_pool3d_backward_native.h>
14 #include <ATen/ops/adaptive_max_pool3d_native.h>
15 #endif
16
17 namespace at::meta {
TORCH_META_FUNC(adaptive_max_pool3d)18 TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_size) {
19 auto ndim = input.ndimension();
20 TORCH_CHECK(
21 ndim == 4 || ndim == 5,
22 "adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: ", input.sizes());
23 for (const auto i : c10::irange(1, ndim)) {
24 TORCH_CHECK(
25 input.size(i) > 0,
26 "adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
27 "but input has sizes ",
28 input.sizes(),
29 " with dimension ",
30 i,
31 " being "
32 "empty");
33 }
34
35 TORCH_CHECK(
36 output_size.size() == 3,
37 "adaptive_max_pool3d(): internal error: output_size.size() must be 3");
38
39 int dimD = 0;
40 int64_t sizeB = 1;
41 int64_t sizeD = 0;
42
43 if (ndim == 5) {
44 sizeB = input.size(0);
45 dimD++;
46 }
47
48 /* sizes */
49 sizeD = input.size(dimD);
50
51 int64_t osizeT = output_size[0];
52 int64_t osizeH = output_size[1];
53 int64_t osizeW = output_size[2];
54
55 /* resize output */
56 if (ndim == 4) {
57 set_output_raw_strided(0, {sizeD, osizeT, osizeH, osizeW}, {}, input.options());
58 /* indices will contain max input locations for each output point */
59 set_output_raw_strided(1, {sizeD, osizeT, osizeH, osizeW}, {}, input.options().dtype(kLong));
60 } else {
61 set_output_raw_strided(0, {sizeB, sizeD, osizeT, osizeH, osizeW}, {}, input.options());
62 /* indices will contain max input locations for each output point */
63 set_output_raw_strided(1, {sizeB, sizeD, osizeT, osizeH, osizeW}, {}, input.options().dtype(kLong));
64 }
65 }
66
TORCH_META_FUNC(adaptive_max_pool3d_backward)67 TORCH_META_FUNC(adaptive_max_pool3d_backward)
68 (const Tensor& gradOutput, const Tensor& input, const Tensor& indices) {
69 at::native::adaptive_pool_empty_output_check(gradOutput, "adaptive_max_pool3d_backward");
70 set_output_raw_strided(0, input.sizes(), {}, input.options());
71 }
72 } // namespace meta
73
74 namespace at::native {
75
76 namespace {
77
// Reference formulas for the adaptive window bounds (not compiled; kept for documentation):
// #define START_IND(a,b,c) a * c / b
// #define END_IND(a,b,c) (a + 1) * c / b + (((a + 1) * c % b > 0)?1:0)
80
81 // 5d tensor B x D x T x H x W
82
83 template <typename scalar_t>
adaptive_max_pool3d_single_out_frame(const scalar_t * input_p,scalar_t * output_p,int64_t * ind_p,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW,int64_t istrideD,int64_t istrideT,int64_t istrideH,int64_t istrideW)84 static void adaptive_max_pool3d_single_out_frame(
85 const scalar_t *input_p,
86 scalar_t *output_p,
87 int64_t *ind_p,
88 int64_t sizeD,
89 int64_t isizeT,
90 int64_t isizeH,
91 int64_t isizeW,
92 int64_t osizeT,
93 int64_t osizeH,
94 int64_t osizeW,
95 int64_t istrideD,
96 int64_t istrideT,
97 int64_t istrideH,
98 int64_t istrideW)
99 {
100 at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
101 for (const auto d : c10::irange(start, end)) {
102 /* loop over output */
103 int64_t ot = 0, oh = 0, ow = 0;
104 for(ot = 0; ot < osizeT; ot++)
105 {
106 int64_t istartT = start_index(ot, osizeT, isizeT);
107 int64_t iendT = end_index(ot, osizeT, isizeT);
108 int64_t kT = iendT - istartT;
109
110 for(oh = 0; oh < osizeH; oh++)
111 {
112 int64_t istartH = start_index(oh, osizeH, isizeH);
113 int64_t iendH = end_index(oh, osizeH, isizeH);
114 int64_t kH = iendH - istartH;
115
116 for(ow = 0; ow < osizeW; ow++)
117 {
118
119 int64_t istartW = start_index(ow, osizeW, isizeW);
120 int64_t iendW = end_index(ow, osizeW, isizeW);
121 int64_t kW = iendW - istartW;
122
123 /* local pointers */
124 const scalar_t *ip = input_p + d*istrideD + istartT *istrideT + istartH*istrideH + istartW*istrideW;
125 scalar_t *op = output_p + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow;
126 int64_t *indp = ind_p + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow;
127
128 /* compute local max: */
129 int64_t it = 0, ih = 0, iw = 0;
130 int64_t maxindex = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + (iw+istartW);
131 scalar_t maxval = -std::numeric_limits<scalar_t>::infinity();
132 for(it = 0; it < kT; it++)
133 {
134 for(ih = 0; ih < kH; ih++)
135 {
136 for(iw = 0; iw < kW; iw++)
137 {
138 scalar_t val = *(ip + it*istrideT + ih*istrideH + iw*istrideW);
139 if ((val > maxval) || std::isnan(val))
140 {
141 maxval = val;
142 maxindex = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + (iw+istartW);
143 }
144 }
145 }
146 }
147
148 /* set output to local max */
149 *op = maxval;
150
151 /* store location of max */
152 *indp = maxindex;
153 }
154 }
155 }
156 }
157 });
158 }
159
160 template <typename scalar_t>
adaptive_max_pool3d_out_frame(const scalar_t * input_data,scalar_t * output_data,int64_t * indices_data,int64_t sizeB,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW,int64_t istrideB,int64_t istrideD,int64_t istrideT,int64_t istrideH,int64_t istrideW)161 static void adaptive_max_pool3d_out_frame(
162 const scalar_t *input_data,
163 scalar_t *output_data,
164 int64_t *indices_data,
165 int64_t sizeB,
166 int64_t sizeD,
167 int64_t isizeT,
168 int64_t isizeH,
169 int64_t isizeW,
170 int64_t osizeT,
171 int64_t osizeH,
172 int64_t osizeW,
173 int64_t istrideB,
174 int64_t istrideD,
175 int64_t istrideT,
176 int64_t istrideH,
177 int64_t istrideW)
178 {
179 at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
180 for (const auto b : c10::irange(start, end)) {
181 adaptive_max_pool3d_single_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
182 indices_data+b*sizeD*osizeT*osizeH*osizeW,
183 sizeD,
184 isizeT, isizeH, isizeW,
185 osizeT, osizeH, osizeW,
186 istrideD, istrideT,
187 istrideH, istrideW);
188 }
189 });
190 }
191
192 template <typename scalar_t>
adaptive_max_pool3d_backward_single_out_frame(scalar_t * gradInput_p,const scalar_t * gradOutput_p,const int64_t * ind_p,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW)193 static void adaptive_max_pool3d_backward_single_out_frame(
194 scalar_t *gradInput_p,
195 const scalar_t *gradOutput_p,
196 const int64_t *ind_p,
197 int64_t sizeD,
198 int64_t isizeT,
199 int64_t isizeH,
200 int64_t isizeW,
201 int64_t osizeT,
202 int64_t osizeH,
203 int64_t osizeW)
204 {
205 at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
206 for (const auto d : c10::irange(start, end)) {
207 scalar_t *gradInput_p_d = gradInput_p + d*isizeT*isizeH*isizeW;
208 const scalar_t *gradOutput_p_d = gradOutput_p + d*osizeT*osizeH*osizeW;
209 const int64_t *ind_p_d = ind_p + d*osizeT*osizeH*osizeW;
210
211 /* calculate max points */
212 int64_t ot = 0, oh = 0, ow = 0;
213 for(ot = 0; ot < osizeT; ot++)
214 {
215 for(oh = 0; oh < osizeH; oh++)
216 {
217 for(ow = 0; ow < osizeW; ow++)
218 {
219 /* retrieve position of max */
220 int64_t maxp = ind_p_d[ot*osizeH*osizeW + oh*osizeW + ow];
221
222 /* update gradient */
223 gradInput_p_d[maxp] += gradOutput_p_d[ot*osizeH*osizeW + oh*osizeW + ow];
224 }
225 }
226 }
227 }
228 });
229 }
230
231 template <typename scalar_t>
adaptive_max_pool3d_backward_out_frame(scalar_t * gradInput_data,const scalar_t * gradOutput_data,const int64_t * indices_data,int64_t sizeB,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW)232 static void adaptive_max_pool3d_backward_out_frame(
233 scalar_t *gradInput_data,
234 const scalar_t *gradOutput_data,
235 const int64_t *indices_data,
236 int64_t sizeB,
237 int64_t sizeD,
238 int64_t isizeT,
239 int64_t isizeH,
240 int64_t isizeW,
241 int64_t osizeT,
242 int64_t osizeH,
243 int64_t osizeW)
244 {
245 at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
246 for (const auto b : c10::irange(start, end)) {
247 adaptive_max_pool3d_backward_single_out_frame<scalar_t>(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
248 indices_data+b*sizeD*osizeT*osizeH*osizeW,
249 sizeD,
250 isizeT, isizeH, isizeW,
251 osizeT, osizeH, osizeW);
252 }
253 });
254 }
255 } // namespace
256
// CPU implementation of adaptive_max_pool3d.
//
// Reads sizes and strides of the (possibly strided) input, which is either
// D x T x H x W or B x D x T x H x W. It then dispatches on dtype (floating
// types and BFloat16) to the unbatched or batched frame kernel. output and
// indices are written through raw pointers using dense offsets.
TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu)
(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
  const int64_t ndim = input.ndimension();

  // A leading batch dimension shifts every other dimension index by one.
  const bool has_batch = ndim == 5;
  const int lead = has_batch ? 1 : 0;

  // Unbatched input behaves like a single batch entry with zero stride.
  const int64_t sizeB = has_batch ? input.size(0) : 1;
  const int64_t istrideB = has_batch ? input.stride(0) : 0;

  const int64_t sizeD = input.size(lead);
  const int64_t isizeT = input.size(lead + 1);
  const int64_t isizeH = input.size(lead + 2);
  const int64_t isizeW = input.size(lead + 3);

  const int64_t istrideD = input.stride(lead);
  const int64_t istrideT = input.stride(lead + 1);
  const int64_t istrideH = input.stride(lead + 2);
  const int64_t istrideW = input.stride(lead + 3);

  const int64_t osizeT = output_size[0];
  const int64_t osizeH = output_size[1];
  const int64_t osizeW = output_size[2];

  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16,
      input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
    const scalar_t* input_data = input.const_data_ptr<scalar_t>();
    scalar_t* output_data = output.data_ptr<scalar_t>();
    int64_t* indices_data = indices.data_ptr<int64_t>();

    if (ndim == 4) {
      adaptive_max_pool3d_single_out_frame<scalar_t>(
          input_data, output_data, indices_data,
          sizeD,
          isizeT, isizeH, isizeW,
          osizeT, osizeH, osizeW,
          istrideD, istrideT, istrideH, istrideW);
    } else {
      adaptive_max_pool3d_out_frame<scalar_t>(
          input_data, output_data, indices_data,
          sizeB, sizeD,
          isizeT, isizeH, isizeW,
          osizeT, osizeH, osizeW,
          istrideB, istrideD, istrideT, istrideH, istrideW);
    }
  });
}
349
// CPU implementation of adaptive_max_pool3d_backward.
//
// Zeroes gradInput, then scatters every gradOutput element into gradInput at
// the flattened in-plane position recorded in `indices`, accumulating with +=.
// Both 4-D (D x T x H x W) and 5-D (B x D x T x H x W) layouts are supported.
//
// The frame kernels address gradOutput and indices with dense row-major
// offsets derived from input's batch/channel extents. Shapes are therefore
// validated up front, and both tensors are made contiguous. Previously a
// strided `indices` was read as if it were dense, and mismatched shapes
// caused out-of-bounds reads.
// NOTE(review): gradInput is also written with dense offsets; an out= caller
// passing a non-contiguous grad_input is not handled here. Stored index
// values are not bounds-checked against T*H*W either.
TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu)
(const Tensor& gradOutput,
 const Tensor& input,
 const Tensor& indices,
 const Tensor& gradInput) {
  const int64_t ndim = input.ndimension();
  TORCH_CHECK(
      ndim == 4 || ndim == 5,
      "adaptive_max_pool3d_backward(): Expected 4D or 5D input tensor, but got: ",
      input.sizes());
  TORCH_CHECK(
      gradOutput.ndimension() == ndim,
      "adaptive_max_pool3d_backward(): Expected gradOutput to have the same number of "
      "dimensions as input (", ndim, "), but got gradOutput with sizes ",
      gradOutput.sizes());
  TORCH_CHECK(
      indices.sizes() == gradOutput.sizes(),
      "adaptive_max_pool3d_backward(): Expected indices to have the same sizes as gradOutput, "
      "but got indices sizes ", indices.sizes(),
      " and gradOutput sizes ", gradOutput.sizes());

  const bool has_batch = ndim == 5;
  const int dimD = has_batch ? 1 : 0;

  // Planes of gradOutput are located using input's batch/channel extents, so
  // those leading extents must agree to avoid reading past gradOutput.
  for (const auto i : c10::irange(dimD + 1)) {
    TORCH_CHECK(
        gradOutput.size(i) == input.size(i),
        "adaptive_max_pool3d_backward(): Expected gradOutput and input to agree in dimension ",
        i, ", but got gradOutput sizes ", gradOutput.sizes(),
        " and input sizes ", input.sizes());
  }

  // Dense views for raw-pointer access (no copy when already contiguous).
  const Tensor gradOutput_ = gradOutput.contiguous();
  const Tensor indices_ = indices.contiguous();

  // Gradients are accumulated with +=, so start from zero.
  gradInput.zero_();

  const int64_t sizeB = has_batch ? input.size(0) : 1;
  const int64_t sizeD = input.size(dimD);
  const int64_t isizeT = input.size(dimD + 1);
  const int64_t isizeH = input.size(dimD + 2);
  const int64_t isizeW = input.size(dimD + 3);
  const int64_t osizeT = gradOutput_.size(dimD + 1);
  const int64_t osizeH = gradOutput_.size(dimD + 2);
  const int64_t osizeW = gradOutput_.size(dimD + 3);

  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16,
      input.scalar_type(), "adaptive_max_pool3d_backward", [&] {
    scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
    const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
    const int64_t* indices_data = indices_.const_data_ptr<int64_t>();

    if (has_batch) {
      adaptive_max_pool3d_backward_out_frame<scalar_t>(
          gradInput_data, gradOutput_data, indices_data,
          sizeB, sizeD,
          isizeT, isizeH, isizeW,
          osizeT, osizeH, osizeW);
    } else {
      adaptive_max_pool3d_backward_single_out_frame<scalar_t>(
          gradInput_data, gradOutput_data, indices_data,
          sizeD,
          isizeT, isizeH, isizeW,
          osizeT, osizeH, osizeW);
    }
  });
}
435 } // namespace at::native
436