#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/native/Padding.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/reflection_pad1d_native.h>
#include <ATen/ops/reflection_pad2d_native.h>
#include <ATen/ops/reflection_pad3d_native.h>
#include <ATen/ops/reflection_pad1d_backward_native.h>
#include <ATen/ops/reflection_pad2d_backward_native.h>
#include <ATen/ops/reflection_pad3d_backward_native.h>
#endif

#include <thrust/pair.h>

namespace at::native {
namespace {

using at::cuda::detail::canUse32BitIndexMath;

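// Maps a 1-D output coordinate to the input coordinate it reflects from, and
// adds the flat offset of the (plane, batch) slice handled by this block
// (gridDim.y = planes, gridDim.z = batches). The input_x expression is a
// branch-free form of the piecewise reflection map: the left pad region,
// the identity (copy) region, and the right pad region collapse into one
// formula. Worked example: input_w = 4, pad_l = pad_r = 2 turns input
// [a b c d] into output [c b a b c d c b], i.e. output_x 0..7 maps to
// input_x 2 1 0 1 2 3 2 1.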
__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping1d(
    int64_t input_w, int64_t output_w,
    int64_t output_x,
    int64_t pad_l) {
  // 3D grid of 1D blocks
  auto input_offset =
    (blockIdx.y + blockIdx.z * gridDim.y) * input_w;
  auto output_offset =
    (blockIdx.y + blockIdx.z * gridDim.y) * output_w;

  auto i_start_x = ::max(int64_t(0), -pad_l);
  auto o_start_x = ::max(int64_t(0), pad_l);

  int64_t input_x = ::abs(output_x - pad_l)
                    - ::abs(output_x - (input_w + pad_l - 1))
                    - output_x
                    + 2 * pad_l + input_w - 1
                    - o_start_x + i_start_x;

  return thrust::make_pair<int64_t, int64_t>(
    input_offset + input_x, output_offset + output_x);
}

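// 2-D variant: applies the same branch-free reflection map independently
// along x and y. y_shift/z_shift/nplane let the caller re-base blockIdx.y/z
// when the launch is split into chunks of at most 65535 blocks per grid
// dimension (the CUDA limit on gridDim.y and gridDim.z).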
__device__
inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
    int64_t input_dim_x, int64_t input_dim_y,
    int64_t output_dim_x, int64_t output_dim_y,
    int64_t pad_l, int64_t pad_t,
    int64_t output_xy, int y_shift, int z_shift, int nplane) {
  // 3D grid of 1D blocks
  auto input_offset =
    ((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) * input_dim_x * input_dim_y;
  auto output_offset =
    ((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) * output_dim_x * output_dim_y;

  auto output_x = output_xy % output_dim_x;
  auto output_y = output_xy / output_dim_x;

  auto i_start_x = ::max(int64_t(0), -pad_l);
  auto i_start_y = ::max(int64_t(0), -pad_t);
  auto o_start_x = ::max(int64_t(0), pad_l);
  auto o_start_y = ::max(int64_t(0), pad_t);

  auto input_x = ::abs(output_x - pad_l)
                 - ::abs(output_x - (input_dim_x + pad_l - 1))
                 - output_x
                 + 2 * pad_l + input_dim_x - 1
                 - o_start_x + i_start_x;

  auto input_y = ::abs(output_y - pad_t)
                 - ::abs(output_y - (input_dim_y + pad_t - 1))
                 - output_y
                 + 2 * pad_t + input_dim_y - 1
                 - o_start_y + i_start_y;

  return thrust::make_pair<int64_t, int64_t>(
    input_offset + input_y * input_dim_x + input_x,
    output_offset + output_y * output_dim_x + output_x);
}

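// Forward 1-D kernel: one thread per output element. Each thread gathers the
// reflected input value and writes it to its own output position, so no
// synchronization is needed.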
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
    const scalar_t * input, scalar_t * output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  auto output_x = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_w = input_w + pad_l + pad_r;

  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    output[index_pair.second] = input[index_pair.first];
  }
}

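// Backward 1-D kernel: scatters grad_output back into grad_input. Several
// output positions can reflect onto the same input position, so accumulation
// uses atomic adds (which is why the backward op is flagged nondeterministic).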
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
    int64_t input_w,
    int64_t pad_l, int64_t pad_r) {
  auto output_x = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_w = input_w + pad_l + pad_r;

  if (output_x < output_w) {
    auto index_pair = get_index_mapping1d(input_w, output_w, output_x, pad_l);
    gpuAtomicAddNoReturn(
      &grad_input[index_pair.first], grad_output[index_pair.second]);
  }
}

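// Forward 2-D kernel: one thread per output (x, y) element of a single plane;
// blockIdx.y/z (plus y_shift/z_shift) select the plane and batch.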
template<typename scalar_t>
__global__ void reflection_pad2d_out_kernel(
    const scalar_t * input, scalar_t * output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_dim_x = input_dim_x + pad_l + pad_r;
  auto output_dim_y = input_dim_y + pad_t + pad_b;

  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
      input_dim_x, input_dim_y,
      output_dim_x, output_dim_y,
      pad_l, pad_t,
      output_xy, y_shift, z_shift, nplane);

    output[index_pair.second] = input[index_pair.first];
  }
}

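// Backward 2-D kernel: same index mapping as the forward kernel, but the
// gradient is scattered into grad_input with atomic adds.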
template <typename scalar_t>
__global__ void reflection_pad2d_backward_out_kernel(
    scalar_t * grad_input, const scalar_t * grad_output,
    int64_t input_dim_x, int64_t input_dim_y,
    int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift, int nplane) {
  auto output_xy = threadIdx.x + blockIdx.x * blockDim.x;
  auto output_dim_x = input_dim_x + pad_l + pad_r;
  auto output_dim_y = input_dim_y + pad_t + pad_b;

  if (output_xy < output_dim_x * output_dim_y) {
    auto index_pair = get_index_mapping2d(
      input_dim_x, input_dim_y,
      output_dim_x, output_dim_y,
      pad_l, pad_t,
      output_xy, y_shift, z_shift, nplane);

    gpuAtomicAddNoReturn(&grad_input[index_pair.first], grad_output[index_pair.second]);
  }
}
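
// Shared 3-D index computation for the forward and backward kernels. It maps
// the flattened output id to (z, y, x) output coordinates, applies the
// reflection formula along each spatial axis, and then invokes the callback f
// with both the output and the mapped input coordinates, so the copy
// (forward) and atomic-add (backward) kernels can share this code.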
template <typename input_scalar_t, typename output_scalar_t, typename F>
__device__ inline void parallel_reflection_pad3d(
    PackedTensorAccessor64<input_scalar_t, 5> input,
    PackedTensorAccessor64<output_scalar_t, 5> output,
    int64_t pad_left,
    int64_t pad_top,
    int64_t pad_front,
    int64_t y_shift,
    int64_t z_shift,
    const F& f) {
  int64_t output_id = threadIdx.x + blockIdx.x * blockDim.x;

  if (output_id >= (output.size(2) * output.size(3) * output.size(4))) {
    return;
  }

  int64_t output_x = output_id % output.size(4);
  int64_t output_y = (output_id / output.size(4)) % output.size(3);
  int64_t output_z = output_id / (output.size(3) * output.size(4));

  int64_t i_start_x = ::max(int64_t(0), -pad_left);
  int64_t o_start_x = ::max(int64_t(0), pad_left);
  int64_t i_start_y = ::max(int64_t(0), -pad_top);
  int64_t o_start_y = ::max(int64_t(0), pad_top);
  int64_t i_start_z = ::max(int64_t(0), -pad_front);
  int64_t o_start_z = ::max(int64_t(0), pad_front);

  int64_t input_x = ::abs(output_x - pad_left)
                    - ::abs(output_x - (input.size(4) + pad_left - 1))
                    - output_x
                    + 2 * pad_left + input.size(4) - 1
                    - o_start_x + i_start_x;
  int64_t input_y = ::abs(output_y - pad_top)
                    - ::abs(output_y - (input.size(3) + pad_top - 1))
                    - output_y
                    + 2 * pad_top + input.size(3) - 1
                    - o_start_y + i_start_y;

  int64_t input_z = ::abs(output_z - pad_front)
                    - ::abs(output_z - (input.size(2) + pad_front - 1))
                    - output_z
                    + 2 * pad_front + input.size(2) - 1
                    - o_start_z + i_start_z;

  int64_t plane = blockIdx.y + y_shift;
  int64_t batch = blockIdx.z + z_shift;
  f(plane, batch, output_z, output_y, output_x, input_z, input_y, input_x);
}

template<typename scalar_t>
__global__ void reflection_pad3d_out_kernel(
    PackedTensorAccessor64<const scalar_t, 5> input,
    PackedTensorAccessor64<scalar_t, 5> output,
    int64_t pad_left, int64_t pad_top, int64_t pad_front,
    int64_t y_shift, int64_t z_shift
    ) {
  parallel_reflection_pad3d(
    input,
    output,
    pad_left,
    pad_top,
    pad_front,
    y_shift,
    z_shift,
    [&] __device__(
        int64_t plane,
        int64_t batch,
        int64_t output_z,
        int64_t output_y,
        int64_t output_x,
        int64_t input_z,
        int64_t input_y,
        int64_t input_x) {
      auto value_to_copy = input[batch][plane][input_z][input_y][input_x];
      output[batch][plane][output_z][output_y][output_x] = value_to_copy;
    });
}

template <typename scalar_t>
__global__ void reflection_pad3d_backward_out_kernel(
    PackedTensorAccessor64<scalar_t, 5> grad_input,
    PackedTensorAccessor64<const scalar_t, 5> grad_output,
    int64_t pad_left, int64_t pad_top, int64_t pad_front,
    int64_t y_shift, int64_t z_shift
    ) {
  parallel_reflection_pad3d(
    grad_input,
    grad_output,
    pad_left,
    pad_top,
    pad_front,
    y_shift,
    z_shift,
    [&] __device__(
        int64_t plane,
        int64_t batch,
        int64_t output_z,
        int64_t output_y,
        int64_t output_x,
        int64_t input_z,
        int64_t input_y,
        int64_t input_x) {
      auto value_to_add = grad_output[batch][plane][output_z][output_y][output_x];
      auto target = &grad_input[batch][plane][input_z][input_y][input_x];
      gpuAtomicAddNoReturn(target, value_to_add);
    });
}

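// Host-side launcher for the 2-D forward kernel. Validates the padding,
// resizes the output, and launches the kernel over the flattened H*W plane
// with up to 256 threads per block. Because gridDim.y and gridDim.z are
// capped at 65535, the plane and batch dimensions are processed in chunks of
// 65535, with block_y/block_z passed as y_shift/z_shift so the kernel knows
// which chunk it is handling.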
void reflection_pad2d_out_template(
    Tensor &output, const Tensor &input_, IntArrayRef padding) {

  TORCH_CHECK(canUse32BitIndexMath(input_),
    "input tensor must fit into 32-bit index math");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  at::native::padding::check_valid_input<2>(input_, padding);

  if (input_.ndimension() == 4) {
    nbatch = input_.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input_.size(plane_dim);
  int input_h = input_.size(dim_h);
  int input_w = input_.size(dim_w);

  TORCH_CHECK(pad_l < input_w && pad_r < input_w,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_l, ", ", pad_r, ") at dimension ", dim_w,
    " of input ", input_.sizes());

  TORCH_CHECK(pad_t < input_h && pad_b < input_h,
    "Padding size should be less than the corresponding input dimension, but "
    "got: padding (", pad_t, ", ", pad_b, ") at dimension ", dim_h,
    " of input ", input_.sizes());

  int output_h = input_h + pad_t + pad_b;
  int output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w >= 1 || output_h >= 1,
    "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
    "output H: ", output_h, " W: ", output_w);

  if (input_.ndimension() == 3) {
    output.resize_({nplane, output_h, output_w});
  } else {
    output.resize_({nbatch, nplane, output_h, output_w});
  }
  if (output.numel() == 0) {
    return;
  }

  Tensor input = input_.contiguous();

  int64_t output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

  int64_t size_y = nplane;
  int64_t size_z = nbatch;

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
    input.scalar_type(), "reflection_pad2d_out_template", [&] {

      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);

          reflection_pad2d_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              input.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
              input_w, input_h,
              pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );
}

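// Host-side launcher for the 2-D backward kernel. Mirrors the forward
// launcher (same block size and grid chunking) but scatters grad_output into
// grad_input, so callers must zero grad_input beforehand.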
void reflection_pad2d_backward_out_template(
    Tensor &grad_input, const Tensor &grad_output_,
    const Tensor &input, IntArrayRef padding) {

  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int plane_dim = 0;
  int dim_h = 1;
  int dim_w = 2;
  int nbatch = 1;

  if (input.ndimension() == 4) {
    nbatch = input.size(0);
    plane_dim++;
    dim_h++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int nplane = input.size(plane_dim);
  int input_h = input.size(dim_h);
  int input_w = input.size(dim_w);

  int output_h = input_h + pad_t + pad_b;
  int output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w == grad_output_.size(dim_w), "grad_output width "
    "unexpected. Expected: ", output_w, ", Got: ", grad_output_.size(dim_w));
  TORCH_CHECK(output_h == grad_output_.size(dim_h), "grad_output height "
    "unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h));

  Tensor grad_output = grad_output_.contiguous();

  int64_t output_plane_size = output_h * output_w;
  dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

  int64_t size_y = nplane;
  int64_t size_z = nbatch;

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
    input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {

      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)), block_y_size, block_z_size);

          reflection_pad2d_backward_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input.mutable_data_ptr<scalar_t>(), grad_output.const_data_ptr<scalar_t>(),
              input_w, input_h,
              pad_t, pad_b, pad_l, pad_r, block_y, block_z, nplane);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    }
  );
}

} // namespace

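// The structured-kernel implementations and out/backward wrappers below are
// the entry points reached through the dispatcher; for example, a call such
// as torch.nn.functional.pad(x, (2, 2), mode="reflect") on a CUDA tensor is
// expected to lower to reflection_pad1d and land in the kernels above.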
TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
(const Tensor& input_, IntArrayRef padding, const Tensor& output) {
  TORCH_CHECK(
      canUse32BitIndexMath(input_),
      "input tensor must fit into 32-bit index math");

  if (output.numel() == 0) {
    return;
  }

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input_.ndimension() == 3) {
    nbatch = input_.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input_.size(dim_plane);
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);

  Tensor input = input_.contiguous();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
        reflection_pad1d_out_kernel<<<
            grid_size,
            block_size,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            input.const_data_ptr<scalar_t>(),
            output.mutable_data_ptr<scalar_t>(),
            input_w,
            pad_l,
            pad_r);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
                                                    const Tensor& input,
                                                    IntArrayRef padding,
                                                    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad1d_backward_out_cuda");
  grad_input.zero_();

  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(canUse32BitIndexMath(input),
    "input tensor must fit into 32-bit index math");

  TORCH_CHECK(canUse32BitIndexMath(grad_output_),
    "output gradient tensor must fit into 32-bit index math");

  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dim_plane++;
    dim_w++;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];

  int64_t nplane = input.size(dim_plane);
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  Tensor grad_output = grad_output_.contiguous();

  dim3 block_size(output_w > 256 ? 256 : output_w);
  dim3 grid_size((int) ::ceil(output_w / 256.0), nplane, nbatch);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
    grad_input.scalar_type(), "reflection_pad1d_backward_out_cuda", [&] {
      reflection_pad1d_backward_out_kernel<<<
        grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
          grad_input.mutable_data_ptr<scalar_t>(), grad_output.const_data_ptr<scalar_t>(),
          input_w, pad_l, pad_r);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  );
}

Tensor& reflection_pad2d_out_cuda(const Tensor& input, IntArrayRef padding,
                                  Tensor& output) {
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor reflection_pad2d_cuda(const Tensor& input, IntArrayRef padding) {
  auto output = at::empty({0}, input.options());
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor& reflection_pad2d_backward_out_cuda(const Tensor& grad_output,
                                           const Tensor& input,
                                           IntArrayRef padding,
                                           Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad2d_backward_out_cuda");
  grad_input.resize_as_(input);
  grad_input.zero_();
  reflection_pad2d_backward_out_template(
    grad_input, grad_output, input, padding);
  return grad_input;
}

Tensor reflection_pad2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("reflection_pad2d_backward_cuda");
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  reflection_pad2d_backward_out_template(
    grad_input, grad_output, input, padding);
  return grad_input;
}


TORCH_IMPL_FUNC(reflection_pad3d_out_cuda) (
    const Tensor& input_, IntArrayRef padding, const Tensor& output
    ) {
  TORCH_CHECK(
      canUse32BitIndexMath(input_),
      "input tensor must fit into 32-bit index math");

  if (output.numel() == 0) {
    return;
  }

  int64_t pad_left = padding[0];
  int64_t pad_top = padding[2];
  int64_t pad_front = padding[4];

  auto input = input_.contiguous();
  bool batch_mode = (input.dim() == 5);

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
    input.scalar_type(), "reflection_pad3d_out_cuda", [&] {
      auto input_inner = input;
      auto output_inner = output;
      if (!batch_mode) {
        // non-batch mode
        input_inner = input.unsqueeze(0);
        output_inner = output.unsqueeze(0);
      }

      auto input_packed = input_inner.packed_accessor64<const scalar_t, 5>();
      auto output_packed = output_inner.packed_accessor64<scalar_t, 5>();

      int64_t output_plane_size = output_packed.size(2) * output_packed.size(3) * output_packed.size(4);
      int64_t size_y = input_packed.size(1);
      int64_t size_z = input_packed.size(0);
      dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)),
                         block_y_size, block_z_size);

          reflection_pad3d_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              input_packed, output_packed, pad_left, pad_top, pad_front,
              block_y, block_z);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    });
}

TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cuda) (
    const Tensor& grad_output, const Tensor& input, IntArrayRef padding,
    const Tensor& grad_input) {
  globalContext().alertNotDeterministic("reflection_pad3d_backward_out_cuda");
  TORCH_CHECK(canUse32BitIndexMath(input), "input tensor must fit into 32-bit index math");
  TORCH_CHECK(canUse32BitIndexMath(grad_output), "output gradient tensor must fit into 32-bit index math");

  if (grad_input.numel() == 0) {
    return;
  }
  grad_input.zero_();

  int64_t pad_left = padding[0];
  int64_t pad_top = padding[2];
  int64_t pad_front = padding[4];

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
    input.scalar_type(), "reflection_pad3d_backward_out_cuda", [&] {
      auto grad_input_ = grad_input;
      auto grad_output_ = grad_output;
      if (input.dim() == 4) {
        // non-batch mode
        grad_input_ = grad_input.unsqueeze(0);
        grad_output_ = grad_output.unsqueeze(0);
      }

      auto grad_input_packed = grad_input_.packed_accessor64<scalar_t, 5>();
      auto grad_output_packed = grad_output_.packed_accessor64<const scalar_t, 5>();

      int64_t output_plane_size = grad_output_packed.size(2) *
          grad_output_packed.size(3) * grad_output_packed.size(4);
      int64_t size_y = grad_input_packed.size(1);
      int64_t size_z = grad_input_packed.size(0);
      dim3 block_size(output_plane_size > 256 ? 256 : output_plane_size);

      for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
        int64_t block_y_size = std::min(size_y - block_y, static_cast<int64_t>(65535));
        for (int64_t block_z = 0; block_z < size_z; block_z += 65535) {
          int64_t block_z_size = std::min(size_z - block_z, static_cast<int64_t>(65535));

          dim3 grid_size(at::ceil_div(output_plane_size, static_cast<int64_t>(256)),
                         block_y_size, block_z_size);

          reflection_pad3d_backward_out_kernel<<<
            grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad_input_packed, grad_output_packed, pad_left, pad_top, pad_front,
              block_y, block_z);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      }
    });
}

} // namespace at::native