// aten/src/ATen/native/cuda/ReflectionPad.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/native/Padding.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/reflection_pad1d_native.h>
#include <ATen/ops/reflection_pad2d_native.h>
#include <ATen/ops/reflection_pad3d_native.h>
#include <ATen/ops/reflection_pad1d_backward_native.h>
#include <ATen/ops/reflection_pad2d_backward_native.h>
#include <ATen/ops/reflection_pad3d_backward_native.h>
#endif

#include <thrust/pair.h>

namespace at::native {
namespace {

using at::cuda::detail::canUse32BitIndexMath;

__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping1d(
    int64_t input_w, int64_t output_w,
    int64_t output_x,
    int64_t pad_l) {
  // 3D grid of 1D blocks
  auto input_offset =
    (blockIdx.y + blockIdx.z * gridDim.y) * input_w;
  auto output_offset =
    (blockIdx.y + blockIdx.z * gridDim.y) * output_w;

  auto i_start_x = ::max(int64_t(0), -pad_l);
  auto o_start_x = ::max(int64_t(0), pad_l);

  int64_t input_x = ::abs(output_x - pad_l)
                    - ::abs(output_x - (input_w + pad_l - 1))
                    - output_x
                    + 2 * pad_l + input_w - 1
                    - o_start_x + i_start_x;

  return thrust::make_pair<int64_t, int64_t>(
    input_offset + input_x, output_offset + output_x);
}
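
// Worked example for get_index_mapping1d (illustrative): with input_w = 5 and
// pad_l = pad_r = 2, the padded row of [a, b, c, d, e] is
// [c, b, a, b, c, d, e, d, c], so output_x = 0, 4 and 8 all map to input_x = 2.
// Several output positions can therefore refer to the same input element, which
// is why the backward kernels below must accumulate with atomics.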


__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
    int64_t input_dim_x, int64_t input_dim_y,
    int64_t output_dim_x, int64_t output_dim_y,
    int64_t pad_l, int64_t pad_t,
    int64_t output_xy, int y_shift, int z_shift, int nplane) {
  // 3D grid of 1D blocks
  auto input_offset =
    ((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) * input_dim_x * input_dim_y;
  auto output_offset =
    ((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) * output_dim_x * output_dim_y;

  auto output_x = output_xy % output_dim_x;
  auto output_y = output_xy / output_dim_x;

  auto i_start_x = ::max(int64_t(0), -pad_l);
  auto i_start_y = ::max(int64_t(0), -pad_t);
  auto o_start_x = ::max(int64_t(0), pad_l);
  auto o_start_y = ::max(int64_t(0), pad_t);

  auto input_x = ::abs(output_x - pad_l)
                 - ::abs(output_x - (input_dim_x + pad_l - 1))
                 - output_x
                 + 2 * pad_l + input_dim_x - 1
                 - o_start_x + i_start_x;

  auto input_y = ::abs(output_y - pad_t)
                 - ::abs(output_y - (input_dim_y + pad_t - 1))
                 - output_y
                 + 2 * pad_t + input_dim_y - 1
                 - o_start_y + i_start_y;

  return thrust::make_pair<int64_t, int64_t>(
    input_offset + input_y * input_dim_x + input_x,
    output_offset + output_y * output_dim_x + output_x);
}
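
// get_index_mapping2d applies the same reflection formula independently along x
// and y. The (y_shift, z_shift, nplane) arguments exist because a single kernel
// launch may only cover a 65535-wide slice of the plane/batch dimensions (the
// CUDA limit on gridDim.y / gridDim.z), so the host code passes the offset of
// the current slice and the kernel reconstructs the absolute plane/batch index.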

template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
    const scalar_t * input, scalar_t * output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  auto output_x = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_w = input_w + pad_l + pad_r;

  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    output[index_pair.second] = input[index_pair.first];
  }
}
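
// Each thread of reflection_pad1d_out_kernel writes exactly one element of the
// padded output row; blockIdx.y selects the plane (channel) and blockIdx.z the
// batch, matching the grid set up in reflection_pad1d_out_cuda below.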

template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  auto output_x = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_w = input_w + pad_l + pad_r;

  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    gpuAtomicAddNoReturn(
      &grad_input[index_pair.first], grad_output[index_pair.second]);
  }
}
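
// The backward kernel is the forward mapping run in reverse: every padded
// output position contributes its gradient back to the input position it was
// reflected from. Because reflected positions collide (see the example above),
// gpuAtomicAddNoReturn is required, which also makes the backward pass
// nondeterministic for floating-point types.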

template<typename scalar_t>
__global__ void reflection_pad2d_out_kernel(
    const scalar_t * input, scalar_t * output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_dim_x = input_dim_x + pad_l + pad_r;
  auto output_dim_y = input_dim_y + pad_t + pad_b;

  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
      input_dim_x, input_dim_y,
      output_dim_x, output_dim_y,
      pad_l, pad_t,
      output_xy, y_shift, z_shift, nplane);

    output[index_pair.second] = input[index_pair.first];
  }
}

template <typename scalar_t>
__global__ void reflection_pad2d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_dim_x = input_dim_x + pad_l + pad_r;
  auto output_dim_y = input_dim_y + pad_t + pad_b;

  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
      input_dim_x, input_dim_y,
      output_dim_x, output_dim_y,
      pad_l, pad_t,
      output_xy, y_shift, z_shift, nplane);

    gpuAtomicAddNoReturn(&grad_input[index_pair.first], grad_output[index_pair.second]);
  }
}
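
// The 2D kernels flatten each output plane to a single linear index
// (output_xy) and let one thread handle one output element; the backward
// variant again accumulates gradients with atomics because of reflection
// collisions.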

template <typename input_scalar_t, typename output_scalar_t, typename F>
__device__ inline void parallel_reflection_pad3d(
    PackedTensorAccessor64<input_scalar_t, 5> input,
    PackedTensorAccessor64<output_scalar_t, 5> output,
    int64_t pad_left,
    int64_t pad_top,
    int64_t pad_front,
    int64_t y_shift,
    int64_t z_shift,
    const F& f) {
  int64_t output_id = threadIdx.x + blockIdx.x * blockDim.x;

  if (output_id >= (output.size(2) * output.size(3) * output.size(4))) {
    return;
  }

  int64_t output_x = output_id % output.size(4);
  int64_t output_y = (output_id / output.size(4)) % output.size(3);
  int64_t output_z = output_id / (output.size(3) * output.size(4));

  int64_t i_start_x = ::max(int64_t(0), -pad_left);
  int64_t o_start_x = ::max(int64_t(0), pad_left);
  int64_t i_start_y = ::max(int64_t(0), -pad_top);
  int64_t o_start_y = ::max(int64_t(0), pad_top);
  int64_t i_start_z = ::max(int64_t(0), -pad_front);
  int64_t o_start_z = ::max(int64_t(0), pad_front);

  int64_t input_x = ::abs(output_x - pad_left)
                 - ::abs(output_x - (input.size(4) + pad_left - 1))
                 - output_x
                 + 2 * pad_left + input.size(4) - 1
                 - o_start_x + i_start_x;

  int64_t input_y = ::abs(output_y - pad_top)
                 - ::abs(output_y - (input.size(3) + pad_top - 1))
                 - output_y
                 + 2 * pad_top + input.size(3) - 1
                 - o_start_y + i_start_y;

  int64_t input_z = ::abs(output_z - pad_front)
                 - ::abs(output_z - (input.size(2) + pad_front - 1))
                 - output_z
                 + 2 * pad_front + input.size(2) - 1
                 - o_start_z + i_start_z;

  int64_t plane = blockIdx.y + y_shift;
  int64_t batch = blockIdx.z + z_shift;
  f(plane, batch, output_z, output_y, output_x, input_z, input_y, input_x);
}
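
// parallel_reflection_pad3d factors out the 3D index arithmetic: it decodes the
// linear thread id into (output_z, output_y, output_x), computes the reflected
// (input_z, input_y, input_x), and hands both coordinates to the callback f, so
// the forward kernel can copy and the backward kernel can atomically accumulate
// without duplicating the mapping logic.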

template<typename scalar_t>
__global__ void reflection_pad3d_out_kernel(
    PackedTensorAccessor64<const scalar_t, 5> input,
    PackedTensorAccessor64<scalar_t, 5> output,
    int64_t pad_left, int64_t pad_top, int64_t pad_front,
    int64_t y_shift, int64_t z_shift
){
  parallel_reflection_pad3d(
      input,
      output,
      pad_left,
      pad_top,
      pad_front,
      y_shift,
      z_shift,
      [&] __device__(
          int64_t plane,
          int64_t batch,
          int64_t output_z,
          int64_t output_y,
          int64_t output_x,
          int64_t input_z,
          int64_t input_y,
          int64_t input_x) {
        auto value_to_copy = input[batch][plane][input_z][input_y][input_x];
        output[batch][plane][output_z][output_y][output_x] = value_to_copy;
      });
}

template <typename scalar_t>
__global__ void reflection_pad3d_backward_out_kernel(
    PackedTensorAccessor64<scalar_t, 5> grad_input,
    PackedTensorAccessor64<const scalar_t, 5> grad_output,
    int64_t pad_left, int64_t pad_top, int64_t pad_front,
    int64_t y_shift, int64_t z_shift
) {
  parallel_reflection_pad3d(
      grad_input,
      grad_output,
      pad_left,
      pad_top,
      pad_front,
      y_shift,
      z_shift,
      [&] __device__(
          int64_t plane,
          int64_t batch,
          int64_t output_z,
          int64_t output_y,
          int64_t output_x,
          int64_t input_z,
          int64_t input_y,
          int64_t input_x) {
        auto value_to_add = grad_output[batch][plane][output_z][output_y][output_x];
        auto target = &grad_input[batch][plane][input_z][input_y][input_x];
        gpuAtomicAddNoReturn(target, value_to_add);
      });
}

void reflection_pad2d_out_template(
    Tensor &output, const Tensor &input_, IntArrayRef padding) {

  TORCH_CHECK(canUse32BitIndexMath(input_),
    "input tensor must fit into 32-bit index math");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  at::native::padding::check_valid_input<2>(input_, padding);

  if (input_.ndimension() == 4) {
    nbatch = input_.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input_.size(plane_dim);
  int input_h = input_.size(dim_h);
  int input_w = input_.size(dim_w);

  TORCH_CHECK(pad_l < input_w && pad_r < input_w,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
    " of input ", input_.sizes());

  TORCH_CHECK(pad_t < input_h && pad_b < input_h,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
    " of input ", input_.sizes());

  int output_h = input_h + pad_t + pad_b;
  int output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w >= 1 || output_h >= 1,
    "input (H: ", input_h, ", W: ", input_w, ") is too small.  Calculated "
    "output H: ", output_h, " W: ", output_w);

  if (input_.ndimension() == 3) {
    output.resize_({nplane, output_h, output_w});
  } else {
    output.resize_({nbatch, nplane, output_h, output_w});
  }
  if (output.numel() == 0) {
    return;
  }

  Tensor input = input_.contiguous();

  int64_t output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

  int64_t size_y = nplane;
  int64_t size_z = nbatch;

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
    input.scalar_type(), "reflection_pad2d_out_template", [&] {

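      // gridDim.y and gridDim.z are limited to 65535, so planes and batches are
      // processed in chunks of at most 65535; block_y / block_z are forwarded to
      // the kernel as y_shift / z_shift so it can recover absolute indices.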
      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);

          reflection_pad2d_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              input.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
              input_w, input_h,
              pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );
}

void reflection_pad2d_backward_out_template(
    Tensor &grad_input, const Tensor &grad_output_,
    const Tensor &input, IntArrayRef padding) {

  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  if (input.ndimension() == 4) {
    nbatch = input.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input.size(plane_dim);
  int input_h = input.size(dim_h);
  int input_w = input.size(dim_w);

  int output_h = input_h + pad_t + pad_b;
  int output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
    "unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
  TORCH_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
    "unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));

  Tensor grad_output = grad_output_.contiguous();

  int64_t output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

  int64_t size_y = nplane;
  int64_t size_z = nbatch;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
    input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {

      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);

          reflection_pad2d_backward_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input.mutable_data_ptr<scalar_t>(), grad_output.const_data_ptr<scalar_t>(),
              input_w, input_h,
              pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );
}

} // namespace

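// The TORCH_IMPL_FUNC entry points below are the structured-kernel
// implementations for the 1D and 3D ops: the output (or grad_input) tensor is
// assumed to have already been allocated with the shape computed by the
// corresponding meta function, so only shape-independent checks and the kernel
// launch happen here.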
TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
(const Tensor& input_, IntArrayRef padding, const Tensor& output) {
  TORCH_CHECK(
      canUse32BitIndexMath(input_),
      "input tensor must fit into 32-bit index math");

  if (output.numel() == 0) {
    return;
  }

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input_.ndimension() == 3) {
    nbatch = input_.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input_.size(dim_plane);
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

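  // Launch geometry: one thread per element of the padded width, with
  // blockIdx.y indexing the plane (channel) and blockIdx.z the batch.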
  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);

  Tensor input = input_.contiguous();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
        reflection_pad1d_out_kernel<<<
            grid_size,
            block_size,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            input.const_data_ptr<scalar_t>(),
            output.mutable_data_ptr<scalar_t>(),
            input_w,
            pad_l,
            pad_r);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
    const Tensor& input,
    IntArrayRef padding,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad1d_backward_out_cuda");
  grad_input.zero_();

  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");

  TORCH_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input.size(dim_plane);
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  Tensor grad_output = grad_output_.contiguous();

  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
    grad_input.scalar_type(), "reflection_pad1d_backward_out_cuda", [&] {
      reflection_pad1d_backward_out_kernel<<<
        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
          grad_input.mutable_data_ptr<scalar_t>(), grad_output.const_data_ptr<scalar_t>(),
          input_w, pad_l, pad_r);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

Tensor& reflection_pad2d_out_cuda(const Tensor& input, IntArrayRef padding,
    Tensor& output) {
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor reflection_pad2d_cuda(const Tensor& input, IntArrayRef padding) {
  auto output = at::empty({0}, input.options());
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor& reflection_pad2d_backward_out_cuda(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad2d_backward_out_cuda");
  grad_input.resize_as_(input);
  grad_input.zero_();
  reflection_pad2d_backward_out_template(
    grad_input, grad_output, input, padding);
  return grad_input;
}

Tensor reflection_pad2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad2d_backward_cuda");
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  reflection_pad2d_backward_out_template(
    grad_input, grad_output, input, padding);
  return grad_input;
}


TORCH_IMPL_FUNC(reflection_pad3d_out_cuda) (
  const Tensor& input_, IntArrayRef padding, const Tensor& output
  ) {
  TORCH_CHECK(
      canUse32BitIndexMath(input_),
      "input tensor must fit into 32-bit index math");

  if (output.numel() == 0) {
    return;
  }

  int64_t pad_left = padding[0];
  int64_t pad_top = padding[2];
  int64_t pad_front = padding[4];

  auto input = input_.contiguous();
  bool batch_mode = (input.dim() == 5);

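  // A single kernel handles both the batched (5D) and non-batched (4D) case:
  // non-batched inputs are viewed as a batch of one via unsqueeze(0) before the
  // packed accessors are built.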
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
      input.scalar_type(), "reflection_pad3d_out_cuda", [&] {
        auto input_inner = input;
        auto output_inner = output;
        if (!batch_mode) {
          // non-batch mode
          input_inner = input.unsqueeze(0);
          output_inner = output.unsqueeze(0);
        }

        auto input_packed = input_inner.packed_accessor64<const scalar_t, 5>();
        auto output_packed = output_inner.packed_accessor64<scalar_t, 5>();

        int64_t output_plane_size = output_packed.size(2) * output_packed.size(3) * output_packed.size(4);
        int64_t size_y = input_packed.size(1);
        int64_t size_z = input_packed.size(0);
        dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

        for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
          int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
          for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
            int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

            dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)),
                           block_y_size, block_z_size);

            reflection_pad3d_out_kernel<<<
                grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
                input_packed, output_packed, pad_left, pad_top, pad_front,
                block_y, block_z);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        }
      });
}

TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cuda) (
  const Tensor& grad_output, const Tensor& input, IntArrayRef padding,
  const Tensor& grad_input) {
  globalContext().alertNotDeterministic("reflection_pad3d_backward_out_cuda");
  TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output), "output gradient tensor must fit into 32-bit index math");

  if (grad_input.numel() == 0) {
    return;
  }
  grad_input.zero_();

  int64_t pad_left = padding[0];
  int64_t pad_top = padding[2];
  int64_t pad_front = padding[4];

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "reflection_pad3d_backward_out_cuda", [&] {
        auto grad_input_ = grad_input;
        auto grad_output_ = grad_output;
        if (input.dim() == 4) {
          // non-batch mode
          grad_input_ = grad_input.unsqueeze(0);
          grad_output_ = grad_output.unsqueeze(0);
        }

        auto grad_input_packed = grad_input_.packed_accessor64<scalar_t, 5>();
        auto grad_output_packed = grad_output_.packed_accessor64<const scalar_t, 5>();

        int64_t output_plane_size = grad_output_packed.size(2) *
            grad_output_packed.size(3) * grad_output_packed.size(4);
        int64_t size_y = grad_input_packed.size(1);
        int64_t size_z = grad_input_packed.size(0);
        dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

        for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
          int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
          for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
            int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

            dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)),
                           block_y_size, block_z_size);

            reflection_pad3d_backward_out_kernel<<<
                grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input_packed, grad_output_packed, pad_left, pad_top, pad_front,
                block_y, block_z);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }
        }
      });
}

} // namespace at::native