// aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
#include <ATen/ops/adaptive_avg_pool3d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

#include <ATen/native/AdaptivePooling.h>

#include <algorithm>
#include <cfloat>
#include <cmath>


namespace at::native {

namespace {

__device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
  return (a / b) * c + ((a % b) * c) / b;
}

__device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
  return 1 + ((a + 1) * c - 1) / b;
}
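
// start_index/end_index compute the half-open input range [start, end) that
// output index a pools over, where b is the output size and c is the input
// size; they are integer forms of floor(a*c/b) and ceil((a+1)*c/b).
// Worked example (input size c = 5, output size b = 3):
//   a = 0 -> [0, 2)   a = 1 -> [1, 4)   a = 2 -> [3, 5)
// Adjacent ranges may overlap when c is not a multiple of b.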

// 5D tensor B x D x T x H x W
// All kernels view batch dim B and dim D as collapsed.

/*
 * Description:
 *    This function adaptively average pools an input 5D tensor along
 *    dimensions 2, 3, and 4 (5D input, 5D output).
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 */
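// Work decomposition (descriptive note): each value of (blockIdx.x + offsetZ)
// selects one output plane, i.e. one (slice d, output frame ot) pair. Within
// that plane, threads tile the osizeH x osizeW grid: rows advance in steps of
// gridDim.y * blockDim.y and columns in steps of blockDim.x, so every output
// pixel is computed by exactly one thread.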
template <typename scalar_t, typename accscalar_t>
__global__ void adaptiveaveragepool(
    const scalar_t *input, scalar_t *output,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t istrideD,
    int64_t istrideT, int64_t istrideH, int64_t istrideW,
    int64_t offsetZ) {
  // iterates on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT; // output frame/time
  int d = o_plane / osizeT; // slice/feature

  // input frame/time range is fixed.
  int istartT = start_index(ot, osizeT, isizeT);
  int iendT = end_index(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // input offset by slice/feature and earliest relevant frame/time
  const scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
  // output offset by slice/feature and frame/time
  scalar_t *output_dt = output + o_plane*osizeH*osizeW;

  // For all output pixels...
  for (oh = ostartH; oh < oendH; oh += ostepH) {
    int istartH = start_index(oh, osizeH, isizeH);
    int iendH = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for (ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = start_index(ow, osizeW, isizeW);
      int iendW = end_index(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the average pooling from corresponding input pixels
      const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
      scalar_t *ptr_output = output_dt + oh*osizeW + ow;
      accscalar_t sum = static_cast<accscalar_t>(0);

      int it, ih, iw;
      for (it = 0; it < kT; ++it) {
        for (ih = 0; ih < kH; ++ih) {
          for (iw = 0; iw < kW; ++iw) {
            scalar_t val = ptr_input[ih*istrideH + iw*istrideW];
            sum += static_cast<accscalar_t>(val);
          }
        }
        ptr_input += istrideT; // next input frame
      }
      // Update output
      const accscalar_t divide_factor = static_cast<accscalar_t>(kT * kH * kW);
      *ptr_output = static_cast<scalar_t>(sum / divide_factor);
    }
  }
}

template <typename scalar_t, typename accscalar_t>
void adaptiveaveragepool_loop(
    const scalar_t *input_data, scalar_t *output_data,
    int64_t totalZ,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  // each H*W plane is processed by blocksH thread blocks
  int blocksH = std::max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    adaptiveaveragepool<scalar_t, accscalar_t>
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        input_data, output_data,
        isizeT, isizeH, isizeW,
        osizeT, osizeH, osizeW,
        istrideD,
        istrideT, istrideH, istrideW,
        offsetZ);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    totalZ -= 65535;
    offsetZ += 65535;
  }
}
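
// Launch-shaping note (descriptive, based on the loop above): totalZ is the
// number of output (slice, frame) planes. Each launch covers at most 65535
// planes in gridDim.x (the conservative/legacy grid-dimension limit), and the
// loop advances offsetZ so later launches resume where the previous one
// stopped. When totalZ is small (< 16), blocksH assigns several blocks along
// gridDim.y to the same H*W plane to keep the GPU occupied.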

/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D gradInput plane specified
 *    by (blockIdx.x + offsetZ).
 */
template <typename scalar_t, typename accscalar_t>
__global__ void adaptiveaveragegradinput(
    scalar_t *gradInput, const scalar_t *gradOutput,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t offsetZ)
{
  // iterators on input pixels
  int it, ih, iw;

  // compute offsets based on thread/block ID
  int istartH = blockIdx.y * blockDim.y + threadIdx.y;
  int iendH = isizeH;
  int istepH = gridDim.y * blockDim.y;
  int istartW = threadIdx.x;
  int iendW = isizeW;
  int istepW = blockDim.x;

  // select input plane
  int64_t i_plane = blockIdx.x + offsetZ;
  it = i_plane % isizeT; // input frame/time
  int d = i_plane / isizeT; // slice/feature

  // output frame/time range is fixed.
  int ostartT = start_index(it, isizeT, osizeT);
  int oendT = end_index(it, isizeT, osizeT);

  // gradInput offset by slice/feature and frame/time.
  scalar_t *gradInput_dt = gradInput + i_plane*isizeH*isizeW;
  // gradOutput offset by slice/feature and earliest relevant frame/time
  const scalar_t *gradOutput_dt = gradOutput + (d*osizeT + ostartT)*osizeH*osizeW;

  // For all input pixels...
  for (ih = istartH; ih < iendH; ih += istepH) {
    int ostartH = start_index(ih, isizeH, osizeH);
    int oendH = end_index(ih, isizeH, osizeH);

    for (iw = istartW; iw < iendW; iw += istepW) {
      int ostartW = start_index(iw, isizeW, osizeW);
      int oendW = end_index(iw, isizeW, osizeW);

      // Compute the gradients from corresponding output pixels
      scalar_t *ptr_gradInput = gradInput_dt + ih*isizeW + iw;
      const scalar_t *ptr_gradOutput = gradOutput_dt;

      // for all relevant output pixels
      int ot, oh, ow;
      for (ot = ostartT; ot < oendT; ++ot) {
        int kT = end_index(ot, osizeT, isizeT) - start_index(ot, osizeT, isizeT);
        for (oh = ostartH; oh < oendH; ++oh) {
          int kH = end_index(oh, osizeH, isizeH) - start_index(oh, osizeH, isizeH);
          for (ow = ostartW; ow < oendW; ++ow) {
            int kW = end_index(ow, osizeW, isizeW) - start_index(ow, osizeW, isizeW);
            const accscalar_t divide_factor = kW * kH * kT;
            accscalar_t grad_delta = static_cast<accscalar_t>(ptr_gradOutput[oh*osizeW + ow] / divide_factor);
            *ptr_gradInput += static_cast<scalar_t>(grad_delta);
          }
        }
        ptr_gradOutput += osizeH*osizeW; // next output frame
      }
    }
  }
}
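
// Note on this (non-atomic) backward kernel: it is indexed by *input* pixels,
// so each gradInput element is accumulated by exactly one thread and no
// atomics are required. The host code below selects this kernel only when
// every input size is divisible by the corresponding output size; otherwise
// the atomic scatter kernel further down is used.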

template <typename scalar_t, typename accscalar_t>
void adaptiveaveragegradinput_loop(
    scalar_t *gradInput_data, const scalar_t *gradOutput_data,
    int64_t totalZ,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW) {
  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  // each H*W plane is processed by blocksH thread blocks
  int blocksH = std::max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    adaptiveaveragegradinput<scalar_t, accscalar_t>
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        gradInput_data, gradOutput_data,
        isizeT, isizeH, isizeW,
        osizeT, osizeH, osizeW,
        offsetZ);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    totalZ -= 65535;
    offsetZ += 65535;
  }
}

/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ).
 *
 *    (uses atomic add)
 *
 */
template <typename scalar_t>
__global__ void atomicadaptiveaveragegradinput(
    scalar_t *gradInput, const scalar_t *gradOutput,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW,
    int64_t offsetZ)
{
  // iterators on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT; // output frame/time
  int d = o_plane / osizeT; // output slice/feature

  // input frame/time range is fixed.
  int istartT = start_index(ot, osizeT, isizeT);
  int iendT = end_index(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // gradInput offset by slice/feature and earliest relevant frame/time
  scalar_t *gradInput_nt = gradInput + (d*isizeT + istartT)*isizeH*isizeW;
  // gradOutput offset by slice/feature and frame/time
  const scalar_t *gradOutput_nt = gradOutput + o_plane*osizeH*osizeW;

  // For all output pixels...
  for (oh = ostartH; oh < oendH; oh += ostepH) {
    int istartH = start_index(oh, osizeH, isizeH);
    int iendH = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for (ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = start_index(ow, osizeW, isizeW);
      int iendW = end_index(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the gradients from corresponding input pixels
      scalar_t *ptr_gradInput = gradInput_nt + istartH*isizeW + istartW;
      const scalar_t *ptr_gradOutput = gradOutput_nt + oh*osizeW + ow;
      scalar_t grad_delta = *ptr_gradOutput / kT / kH / kW;

      int it, ih, iw;
      for (it = 0; it < kT; ++it) {
        for (ih = 0; ih < kH; ++ih) {
          for (iw = 0; iw < kW; ++iw) {
            gpuAtomicAddNoReturn(&(ptr_gradInput[ih*isizeW + iw]), grad_delta);
          }
        }
        ptr_gradInput += isizeH*isizeW; // next input frame
      }
    }
  }
}
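
// Note on the atomic backward kernel: it is indexed by *output* pixels and
// scatters grad_delta into every input pixel of the corresponding pooling
// region. When the input sizes are not divisible by the output sizes,
// adjacent pooling regions overlap, so several threads may update the same
// gradInput element; gpuAtomicAddNoReturn keeps those concurrent updates
// correct, at the cost of a nondeterministic summation order (see the
// alertNotDeterministic calls in the public wrappers below).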

template <typename scalar_t>
void atomicadaptiveaveragegradinput_loop(
    scalar_t* gradInput_data, const scalar_t* gradOutput_data,
    int64_t totalZ,
    int isizeT, int isizeH, int isizeW,
    int osizeT, int osizeH, int osizeW) {
  int64_t offsetZ = 0;
  dim3 threads(32, 8);
  int blocksH = std::max((int)(16L / totalZ), 1);
  while (totalZ > 0) {
    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
    atomicadaptiveaveragegradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        gradInput_data, gradOutput_data,
        isizeT, isizeH, isizeW,
        osizeT, osizeH, osizeW,
        offsetZ);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    totalZ -= 65535;
    offsetZ += 65535;
  }
}

// 5D tensor B x D x T x H x W

void adaptive_avg_pool3d_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef& output_size) {
  TensorArg output_arg{output, "output", 1};
  TensorArg input_arg{input_, "input_", 2};

  checkAllSameGPU("adaptive_avg_pool3d_cuda", {output_arg, input_arg});

  for (int64_t i = 1; i < input_.ndimension(); i++) {
    TORCH_CHECK(
        input_.size(i) > 0,
        "adaptive_avg_pool3d_cuda(): Expected input to have non-zero size for non-batch dimensions, "
        "but input has sizes ", input_.sizes(),
        " with dimension ", i, " being empty");
  }

  TORCH_CHECK(
      (input_.ndimension() == 4 || input_.ndimension() == 5),
      "adaptive_avg_pool3d_cuda(): Expected 4D or 5D tensor, but got ", input_.sizes());

  // the jit sometimes passes output_size.size() == 1
  TORCH_CHECK(
      output_size.size() == 1 || output_size.size() == 3,
      "adaptive_avg_pool3d: internal error: output_size.size() must be 1 or 3");

  int64_t osizeT = output_size[0];
  int64_t osizeH = output_size[1];
  int64_t osizeW = output_size[2];

  int64_t sizeD, isizeT, isizeH, isizeW;
  int64_t istrideD, istrideT, istrideH, istrideW;
  int64_t totalZ;

  const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous();

  if (input.ndimension() == 4) {
    sizeD = input.size(0);
    isizeT = input.size(1);
    isizeH = input.size(2);
    isizeW = input.size(3);

    istrideD = input.stride(0);
    istrideT = input.stride(1);
    istrideH = input.stride(2);
    istrideW = input.stride(3);

    output.resize_({sizeD, osizeT, osizeH, osizeW});

    totalZ = sizeD * osizeT;
  } else {
    int64_t sizeB = input.size(0);
    sizeD = input.size(1);
    isizeT = input.size(2);
    isizeH = input.size(3);
    isizeW = input.size(4);

    istrideD = input.stride(1);
    istrideT = input.stride(2);
    istrideH = input.stride(3);
    istrideW = input.stride(4);

    output.resize_({sizeB, sizeD, osizeT, osizeH, osizeW});

    totalZ = sizeB * sizeD * osizeT;
  }

  if (output.numel() == 0) {
    return;
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "adaptive_avg_pool3d_cuda", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        const scalar_t* input_data = input.const_data_ptr<scalar_t>();
        scalar_t* output_data = output.mutable_data_ptr<scalar_t>();

        adaptiveaveragepool_loop<scalar_t, accscalar_t>(
            input_data, output_data,
            totalZ,
            isizeT, isizeH, isizeW,
            osizeT, osizeH, osizeW,
            istrideD, istrideT, istrideH, istrideW);
      });
}
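
// Note on the contiguity choice above: for 4D (non-batched) input the kernel
// receives explicit istride* values, so arbitrary strides are supported with
// no copy. For 5D input the kernel indexes the collapsed B*D dimension as
// d*istrideD with istrideD = stride(1), which is only valid when
// stride(0) == size(1) * stride(1); calling .contiguous() guarantees that.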

void adaptive_avg_pool3d_backward_out_cuda_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input) {
  TensorArg grad_input_arg{gradInput, "gradInput", 1};
  TensorArg grad_output_arg{gradOutput_, "gradOutput_", 2};
  TensorArg input_arg{input, "input", 3};

  adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward");

  checkAllSameGPU(
      "adaptive_avg_pool3d_out_cuda",
      {grad_input_arg, grad_output_arg, input_arg});

  const Tensor gradOutput = gradOutput_.contiguous();

  gradInput.resize_as_(input);
  if (gradInput.numel() == 0) {
    return;
  }

  gradInput.zero_();

  int64_t sizeD, isizeT, isizeH, isizeW;
  int64_t osizeT, osizeH, osizeW;
  int64_t totalZ;

  if (input.ndimension() == 4) {
    sizeD = input.size(0);
    isizeT = input.size(1);
    isizeH = input.size(2);
    isizeW = input.size(3);

    osizeT = gradOutput.size(1);
    osizeH = gradOutput.size(2);
    osizeW = gradOutput.size(3);
  } else {
    sizeD = input.size(1);
    isizeT = input.size(2);
    isizeH = input.size(3);
    isizeW = input.size(4);

    osizeT = gradOutput.size(2);
    osizeH = gradOutput.size(3);
    osizeW = gradOutput.size(4);
  }

  bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0) || (isizeT%osizeT != 0);

  if (input.ndimension() == 4) {
    totalZ = atomic ? sizeD * osizeT : sizeD * isizeT;
  } else {
    int sizeB = input.size(0);
    totalZ = atomic ? sizeB * sizeD * osizeT : sizeB * sizeD * isizeT;
  }

  if (atomic) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
          scalar_t* gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

          atomicadaptiveaveragegradinput_loop(
              gradInput_data, gradOutput_data,
              totalZ,
              isizeT, isizeH, isizeW,
              osizeT, osizeH, osizeW);
        });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
        input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;

          scalar_t* gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
          const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();

          adaptiveaveragegradinput_loop<scalar_t, accscalar_t>(
              gradInput_data, gradOutput_data,
              totalZ,
              isizeT, isizeH, isizeW,
              osizeT, osizeH, osizeW);
        });
  }
}

} // namespace

Tensor& adaptive_avg_pool3d_out_cuda(const Tensor& input,
    IntArrayRef output_size,
    Tensor& output) {
  adaptive_avg_pool3d_out_cuda_template(output, input, output_size);
  return output;
}

Tensor adaptive_avg_pool3d_cuda(
    const Tensor& input,
    IntArrayRef output_size) {
  auto output = at::empty({0}, input.options());
  adaptive_avg_pool3d_out_cuda_template(output, input, output_size);
  return output;
}
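
// Illustrative usage (a sketch, not part of this file's API surface): on a
// CUDA tensor the public op is expected to route to the functions above, e.g.
//   at::Tensor x = at::randn({2, 3, 8, 16, 16},
//                            at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor y = at::adaptive_avg_pool3d(x, {4, 4, 4});  // -> 2 x 3 x 4 x 4 x 4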

Tensor& adaptive_avg_pool3d_backward_out_cuda(const Tensor& gradOutput_,
    const Tensor& input,
    Tensor& gradInput) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("adaptive_avg_pool3d_backward_out_cuda");
  adaptive_avg_pool3d_backward_out_cuda_template(gradInput, gradOutput_, input);
  return gradInput;
}

Tensor adaptive_avg_pool3d_backward_cuda(
    const Tensor& gradOutput_,
    const Tensor& input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("adaptive_avg_pool3d_backward_cuda");
  auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  adaptive_avg_pool3d_backward_out_cuda_template(gradInput, gradOutput_, input);
  return gradInput;
}

} // namespace at::native