xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <c10/util/irange.h>
6 #include <tuple>
7 
8 #include <ATen/native/AdaptivePooling.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/NativeFunctions.h>
12 #else
13 #include <ATen/ops/adaptive_max_pool3d_backward_native.h>
14 #include <ATen/ops/adaptive_max_pool3d_native.h>
15 #endif
16 
17 namespace at::meta {
// Meta function for adaptive_max_pool3d: validates the arguments and declares
// the two outputs.
//
//   output 0 - pooled values, created with the options of `input`
//   output 1 - indices of the selected input elements, dtype kLong
//
// A 4D input yields outputs of shape (D, oT, oH, oW); a 5D input yields
// (B, D, oT, oH, oW), where (oT, oH, oW) are the three entries of
// `output_size`.
TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_size) {
  const auto ndim = input.ndimension();
  TORCH_CHECK(
    ndim == 4 || ndim == 5,
    "adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: ", input.sizes());

  // Dimension 0 is not inspected here; every later dimension must be non-empty.
  for (const auto i : c10::irange(1, ndim)) {
    TORCH_CHECK(
        input.size(i) > 0,
        "adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
        "but input has sizes ",
        input.sizes(),
        " with dimension ",
        i,
        " being "
        "empty");
  }

  TORCH_CHECK(
      output_size.size() == 3,
      "adaptive_max_pool3d(): internal error: output_size.size() must be 3");

  // With a leading batch dimension, the D dimension moves from index 0 to 1.
  const bool batched = (ndim == 5);
  const int64_t sizeD = input.size(batched ? 1 : 0);

  const int64_t osizeT = output_size[0];
  const int64_t osizeH = output_size[1];
  const int64_t osizeW = output_size[2];

  const auto value_options = input.options();
  /* indices will contain max input locations for each output point */
  const auto index_options = input.options().dtype(kLong);

  if (batched) {
    const int64_t sizeB = input.size(0);
    set_output_raw_strided(0, {sizeB, sizeD, osizeT, osizeH, osizeW}, {}, value_options);
    set_output_raw_strided(1, {sizeB, sizeD, osizeT, osizeH, osizeW}, {}, index_options);
  } else {
    set_output_raw_strided(0, {sizeD, osizeT, osizeH, osizeW}, {}, value_options);
    set_output_raw_strided(1, {sizeD, osizeT, osizeH, osizeW}, {}, index_options);
  }
}
66 
// Meta function for adaptive_max_pool3d_backward: runs the shared
// adaptive-pooling check on `gradOutput` and declares the single output
// (the gradient w.r.t. `input`) with the sizes and options of `input`.
// `indices` is not inspected here.
TORCH_META_FUNC(adaptive_max_pool3d_backward)
(const Tensor& gradOutput, const Tensor& input, const Tensor& indices) {
  // NOTE(review): presumably rejects gradOutput tensors with empty non-batch
  // dimensions (helper lives in AdaptivePooling.h) - confirm against its definition.
  at::native::adaptive_pool_empty_output_check(gradOutput, "adaptive_max_pool3d_backward");

  const auto grad_input_options = input.options();
  set_output_raw_strided(0, input.sizes(), {}, grad_input_options);
}
72 } // namespace meta
73 
74 namespace at::native {
75 
76 namespace {
77 
78 // #define START_IND(a,b,c) a * c / b
79 // #define END_IND(a,b,c)  (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0
80 
81 // 5d tensor B x D x T x H x W
82 
83 template <typename scalar_t>
adaptive_max_pool3d_single_out_frame(const scalar_t * input_p,scalar_t * output_p,int64_t * ind_p,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW,int64_t istrideD,int64_t istrideT,int64_t istrideH,int64_t istrideW)84 static void adaptive_max_pool3d_single_out_frame(
85           const scalar_t *input_p,
86           scalar_t *output_p,
87           int64_t *ind_p,
88           int64_t sizeD,
89           int64_t isizeT,
90           int64_t isizeH,
91           int64_t isizeW,
92           int64_t osizeT,
93           int64_t osizeH,
94           int64_t osizeW,
95           int64_t istrideD,
96           int64_t istrideT,
97           int64_t istrideH,
98           int64_t istrideW)
99 {
100   at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
101     for (const auto d : c10::irange(start, end)) {
102       /* loop over output */
103       int64_t ot = 0, oh = 0, ow = 0;
104       for(ot = 0; ot < osizeT; ot++)
105       {
106         int64_t istartT = start_index(ot, osizeT, isizeT);
107         int64_t iendT   = end_index(ot, osizeT, isizeT);
108         int64_t kT = iendT - istartT;
109 
110         for(oh = 0; oh < osizeH; oh++)
111         {
112           int64_t istartH = start_index(oh, osizeH, isizeH);
113           int64_t iendH   = end_index(oh, osizeH, isizeH);
114           int64_t kH = iendH - istartH;
115 
116           for(ow = 0; ow < osizeW; ow++)
117           {
118 
119             int64_t istartW = start_index(ow, osizeW, isizeW);
120             int64_t iendW   = end_index(ow, osizeW, isizeW);
121             int64_t kW = iendW - istartW;
122 
123             /* local pointers */
124             const scalar_t *ip = input_p   + d*istrideD + istartT *istrideT + istartH*istrideH + istartW*istrideW;
125             scalar_t *op = output_p  + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow;
126             int64_t *indp = ind_p   + d*osizeT*osizeH*osizeW + ot*osizeH*osizeW + oh*osizeW + ow;
127 
128             /* compute local max: */
129             int64_t it = 0, ih = 0, iw = 0;
130             int64_t maxindex = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + (iw+istartW);
131             scalar_t maxval = -std::numeric_limits<scalar_t>::infinity();
132             for(it = 0; it < kT; it++)
133             {
134               for(ih = 0; ih < kH; ih++)
135               {
136                 for(iw = 0; iw < kW; iw++)
137                 {
138                   scalar_t val = *(ip + it*istrideT + ih*istrideH + iw*istrideW);
139                   if ((val > maxval) || std::isnan(val))
140                   {
141                     maxval = val;
142                     maxindex = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + (iw+istartW);
143                   }
144                 }
145               }
146             }
147 
148             /* set output to local max */
149             *op = maxval;
150 
151             /* store location of max */
152             *indp = maxindex;
153           }
154         }
155       }
156     }
157   });
158 }
159 
160 template <typename scalar_t>
adaptive_max_pool3d_out_frame(const scalar_t * input_data,scalar_t * output_data,int64_t * indices_data,int64_t sizeB,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW,int64_t istrideB,int64_t istrideD,int64_t istrideT,int64_t istrideH,int64_t istrideW)161 static void adaptive_max_pool3d_out_frame(
162           const scalar_t *input_data,
163           scalar_t *output_data,
164           int64_t *indices_data,
165           int64_t sizeB,
166           int64_t sizeD,
167           int64_t isizeT,
168           int64_t isizeH,
169           int64_t isizeW,
170           int64_t osizeT,
171           int64_t osizeH,
172           int64_t osizeW,
173           int64_t istrideB,
174           int64_t istrideD,
175           int64_t istrideT,
176           int64_t istrideH,
177           int64_t istrideW)
178 {
179   at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
180     for (const auto b : c10::irange(start, end)) {
181       adaptive_max_pool3d_single_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
182                                                      indices_data+b*sizeD*osizeT*osizeH*osizeW,
183                                                      sizeD,
184                                                      isizeT, isizeH, isizeW,
185                                                      osizeT, osizeH, osizeW,
186                                                      istrideD, istrideT,
187                                                      istrideH, istrideW);
188     }
189   });
190 }
191 
192 template <typename scalar_t>
adaptive_max_pool3d_backward_single_out_frame(scalar_t * gradInput_p,const scalar_t * gradOutput_p,const int64_t * ind_p,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW)193 static void adaptive_max_pool3d_backward_single_out_frame(
194           scalar_t *gradInput_p,
195           const scalar_t *gradOutput_p,
196           const int64_t *ind_p,
197           int64_t sizeD,
198           int64_t isizeT,
199           int64_t isizeH,
200           int64_t isizeW,
201           int64_t osizeT,
202           int64_t osizeH,
203           int64_t osizeW)
204 {
205   at::parallel_for(0, sizeD, 0, [&](int64_t start, int64_t end) {
206     for (const auto d : c10::irange(start, end)) {
207       scalar_t *gradInput_p_d = gradInput_p + d*isizeT*isizeH*isizeW;
208       const scalar_t *gradOutput_p_d = gradOutput_p + d*osizeT*osizeH*osizeW;
209       const int64_t *ind_p_d = ind_p + d*osizeT*osizeH*osizeW;
210 
211       /* calculate max points */
212       int64_t ot = 0, oh = 0, ow = 0;
213       for(ot = 0; ot < osizeT; ot++)
214       {
215         for(oh = 0; oh < osizeH; oh++)
216         {
217           for(ow = 0; ow < osizeW; ow++)
218           {
219             /* retrieve position of max */
220             int64_t maxp = ind_p_d[ot*osizeH*osizeW + oh*osizeW + ow];
221 
222             /* update gradient */
223             gradInput_p_d[maxp] += gradOutput_p_d[ot*osizeH*osizeW + oh*osizeW + ow];
224           }
225         }
226       }
227     }
228   });
229 }
230 
231 template <typename scalar_t>
adaptive_max_pool3d_backward_out_frame(scalar_t * gradInput_data,const scalar_t * gradOutput_data,const int64_t * indices_data,int64_t sizeB,int64_t sizeD,int64_t isizeT,int64_t isizeH,int64_t isizeW,int64_t osizeT,int64_t osizeH,int64_t osizeW)232 static void adaptive_max_pool3d_backward_out_frame(
233           scalar_t *gradInput_data,
234           const scalar_t *gradOutput_data,
235           const int64_t *indices_data,
236           int64_t sizeB,
237           int64_t sizeD,
238           int64_t isizeT,
239           int64_t isizeH,
240           int64_t isizeW,
241           int64_t osizeT,
242           int64_t osizeH,
243           int64_t osizeW)
244 {
245   at::parallel_for(0, sizeB, 0, [&](int64_t start, int64_t end) {
246     for (const auto b : c10::irange(start, end)) {
247       adaptive_max_pool3d_backward_single_out_frame<scalar_t>(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
248                                                               indices_data+b*sizeD*osizeT*osizeH*osizeW,
249                                                               sizeD,
250                                                               isizeT, isizeH, isizeW,
251                                                               osizeT, osizeH, osizeW);
252     }
253   });
254 }
255 } // namespace
256 
TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu)257 TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu)
258 (const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
259   int dimD = 0;
260   int dimT = 1;
261   int dimH = 2;
262   int dimW = 3;
263   int64_t sizeB = 1;
264   int64_t sizeD = 0;
265   int64_t isizeT = 0;
266   int64_t isizeH = 0;
267   int64_t isizeW = 0;
268 
269   int64_t istrideB = 0;
270   int64_t istrideD = 0;
271   int64_t istrideT = 0;
272   int64_t istrideH = 0;
273   int64_t istrideW = 0;
274 
275   if (input.ndimension() == 5) {
276     istrideB = input.stride(0);
277     sizeB = input.size(0);
278     dimD++;
279     dimT++;
280     dimH++;
281     dimW++;
282   }
283 
284   /* sizes */
285   sizeD = input.size(dimD);
286   isizeT = input.size(dimT);
287   isizeH = input.size(dimH);
288   isizeW = input.size(dimW);
289   /* strides */
290   istrideD = input.stride(dimD);
291   istrideT = input.stride(dimT);
292   istrideH = input.stride(dimH);
293   istrideW = input.stride(dimW);
294 
295   int64_t osizeT = output_size[0];
296   int64_t osizeH = output_size[1];
297   int64_t osizeW = output_size[2];
298 
299   if (input.ndimension() == 4) {
300     AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16,
301         input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
302           auto input_data = input.const_data_ptr<scalar_t>();
303           auto output_data = output.data_ptr<scalar_t>();
304           auto indices_data = indices.data_ptr<int64_t>();
305 
306           adaptive_max_pool3d_single_out_frame<scalar_t>(
307               input_data,
308               output_data,
309               indices_data,
310               sizeD,
311               isizeT,
312               isizeH,
313               isizeW,
314               osizeT,
315               osizeH,
316               osizeW,
317               istrideD,
318               istrideT,
319               istrideH,
320               istrideW);
321         });
322   } else {
323     AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16,
324         input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
325           auto input_data = input.const_data_ptr<scalar_t>();
326           auto output_data = output.data_ptr<scalar_t>();
327           auto indices_data = indices.data_ptr<int64_t>();
328 
329           adaptive_max_pool3d_out_frame<scalar_t>(
330               input_data,
331               output_data,
332               indices_data,
333               sizeB,
334               sizeD,
335               isizeT,
336               isizeH,
337               isizeW,
338               osizeT,
339               osizeH,
340               osizeW,
341               istrideB,
342               istrideD,
343               istrideT,
344               istrideH,
345               istrideW);
346         });
347   }
348 }
349 
TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu)350 TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu)
351 (const Tensor& gradOutput,
352  const Tensor& input,
353  const Tensor& indices,
354  const Tensor& gradInput) {
355   int dimD = 0;
356   int dimT = 1;
357   int dimH = 2;
358   int dimW = 3;
359   int64_t sizeB = 1;
360   int64_t sizeD = 0;
361   int64_t isizeT = 0;
362   int64_t isizeH = 0;
363   int64_t isizeW = 0;
364   int64_t osizeT = 0;
365   int64_t osizeH = 0;
366   int64_t osizeW = 0;
367 
368   /* get contiguous gradOutput */
369   auto gradOutput_ = gradOutput.contiguous();
370 
371   /* resize */
372   gradInput.zero_();
373 
374   if (input.ndimension() == 5) {
375     sizeB = input.size(0);
376     dimD++;
377     dimT++;
378     dimH++;
379     dimW++;
380   }
381 
382   /* sizes */
383   sizeD = input.size(dimD);
384   isizeT = input.size(dimT);
385   isizeH = input.size(dimH);
386   isizeW = input.size(dimW);
387   osizeT = gradOutput_.size(dimT);
388   osizeH = gradOutput_.size(dimH);
389   osizeW = gradOutput_.size(dimW);
390 
391   /* backprop */
392   if (input.ndimension() == 4) {
393     AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16,
394         input.scalar_type(), "adaptive_max_pool3d_backward", [&] {
395           /* get raw pointers */
396           scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
397           const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
398           const int64_t* indices_data = indices.const_data_ptr<int64_t>();
399 
400           adaptive_max_pool3d_backward_single_out_frame<scalar_t>(
401               gradInput_data,
402               gradOutput_data,
403               indices_data,
404               sizeD,
405               isizeT,
406               isizeH,
407               isizeW,
408               osizeT,
409               osizeH,
410               osizeW);
411         });
412   } else {
413     AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16,
414         input.scalar_type(), "adaptive_max_pool3d_backward", [&] {
415           /* get raw pointers */
416           scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
417           const scalar_t* gradOutput_data = gradOutput_.const_data_ptr<scalar_t>();
418           const int64_t* indices_data = indices.const_data_ptr<int64_t>();
419 
420           adaptive_max_pool3d_backward_out_frame<scalar_t>(
421               gradInput_data,
422               gradOutput_data,
423               indices_data,
424               sizeB,
425               sizeD,
426               isizeT,
427               isizeH,
428               isizeW,
429               osizeT,
430               osizeH,
431               osizeW);
432         });
433   }
434 }
435 } // namespace at::native
436