#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/replication_pad1d_native.h>
#include <ATen/ops/replication_pad1d_backward_native.h>
#include <ATen/ops/replication_pad2d_native.h>
#include <ATen/ops/replication_pad2d_backward_native.h>
#include <ATen/ops/replication_pad3d_native.h>
#include <ATen/ops/replication_pad3d_backward_native.h>
#endif

#include <algorithm>
#include <cfloat>
#include <cmath>


namespace at::native {
__host__ __device__ __forceinline__ int imin(int a, int b) {
  return a > b ? b : a;
}

__host__ __device__ __forceinline__ int imax(int a, int b) {
  return a > b ? a : b;
}

namespace {
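// The forward kernels below map each output coordinate back to the input
// coordinate whose value is replicated into it:
//   inputPoint = imin(imax(pad, outputPoint), inputSize + pad - 1) - oStart + iStart
// For non-negative padding this reduces to clamp(outputPoint - pad, 0, inputSize - 1):
// interior points copy their corresponding input element, points in the padded
// border copy the nearest edge element, and negative padding crops instead.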
template <typename scalar_t>
__global__ void replication_pad_forward_kernel1d(
    PackedTensorAccessor64<const scalar_t, 3> input,
    PackedTensorAccessor64<scalar_t, 3> output,
    const int padL,
    const int y_shift,
    const int z_shift) {
  const int64_t outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
  const int64_t plane = blockIdx.y + y_shift;
  const int64_t batch = blockIdx.z + z_shift;
  if (outputPointId >= output.size(2)) {
    return;
  }
  const auto outputPointX = outputPointId % output.size(2);

  const int iStartX = imax(0, -padL);
  const int oStartX = imax(0, padL);

  const auto inputPointX = imin(imax(padL, outputPointX), input.size(2) + padL - 1) - oStartX + iStartX;

  scalar_t valueToCopy = input[batch][plane][inputPointX];
  output[batch][plane][outputPointX] = valueToCopy;
}

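// The backward kernels invert the forward mapping: each gradOutput element is
// accumulated into the gradInput element it was replicated from. Several
// output points can map to the same (edge) input point, so the accumulation
// uses gpuAtomicAddNoReturn, which is why the backward ops are reported as
// nondeterministic.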
template <typename scalar_t>
__global__ void replication_pad_backward_kernel(
    PackedTensorAccessor64<scalar_t, 3> gradInput,
    PackedTensorAccessor64<const scalar_t, 3> gradOutput,
    const int padL,
    const int y_shift,
    const int z_shift) {
  const int64_t outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
  const int64_t plane = blockIdx.y + y_shift;
  const int64_t batch = blockIdx.z + z_shift;
  if (outputPointId >= gradOutput.size(2)) {
    return;
  }
  const auto outputPointX = outputPointId % gradOutput.size(2);

  const int iStartX = imax(0, -padL);
  const int oStartX = imax(0, padL);

  const auto inputPointX = imin(imax(padL, outputPointX), gradInput.size(2) + padL - 1) - oStartX + iStartX;

  scalar_t valueToCopy = gradOutput[batch][plane][outputPointX];
  gpuAtomicAddNoReturn(&gradInput[batch][plane][inputPointX], valueToCopy);
}

template <typename scalar_t>
__global__ void replication_pad_forward_kernel2d(
    PackedTensorAccessor64<const scalar_t, 4> input,
    PackedTensorAccessor64<scalar_t, 4> output,
    const int padT,
    const int padL,
    const int y_shift,
    const int z_shift) {
  const int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
  const int plane = blockIdx.y + y_shift;
  const int batch = blockIdx.z + z_shift;
  if (outputPointId >= output.size(2) * output.size(3)) {
    return;
  }
  const int outputPointX = outputPointId % output.size(3);
  const int outputPointY = outputPointId / output.size(3);

  const int iStartX = imax(0, -padL);
  const int iStartY = imax(0, -padT);
  const int oStartX = imax(0, padL);
  const int oStartY = imax(0, padT);

  const int inputPointX = imin(imax(padL, outputPointX), input.size(3) + padL - 1) - oStartX + iStartX;
  const int inputPointY = imin(imax(padT, outputPointY), input.size(2) + padT - 1) - oStartY + iStartY;

  scalar_t valueToCopy = input[batch][plane][inputPointY][inputPointX];
  output[batch][plane][outputPointY][outputPointX] = valueToCopy;
}

template <typename scalar_t>
__global__ void replication_pad_backward_kernel(
    PackedTensorAccessor64<scalar_t, 4> gradInput,
    PackedTensorAccessor64<const scalar_t, 4> gradOutput,
    const int padT,
    const int padL,
    const int y_shift,
    const int z_shift) {
  const int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
  const int plane = blockIdx.y + y_shift;
  const int batch = blockIdx.z + z_shift;
  if (outputPointId >= gradOutput.size(2) * gradOutput.size(3)) {
    return;
  }
  const int outputPointX = outputPointId % gradOutput.size(3);
  const int outputPointY = outputPointId / gradOutput.size(3);

  const int iStartX = imax(0, -padL);
  const int iStartY = imax(0, -padT);
  const int oStartX = imax(0, padL);
  const int oStartY = imax(0, padT);

  const int inputPointX = imin(imax(padL, outputPointX), gradInput.size(3) + padL - 1) - oStartX + iStartX;
  const int inputPointY = imin(imax(padT, outputPointY), gradInput.size(2) + padT - 1) - oStartY + iStartY;

  scalar_t valueToCopy = gradOutput[batch][plane][outputPointY][outputPointX];
  gpuAtomicAddNoReturn(&gradInput[batch][plane][inputPointY][inputPointX], valueToCopy);
}

template <typename scalar_t>
__global__ void replication_pad_forward_kernel3d(
    PackedTensorAccessor64<const scalar_t, 5> input,
    PackedTensorAccessor64<scalar_t, 5> output,
    const int pfront,
    const int ptop,
    const int pleft,
    const int y_shift,
    const int z_shift) {
  const int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
  const int plane = blockIdx.y + y_shift;
  const int batch = blockIdx.z + z_shift;
  if (outputPointId >= (output.size(2) * output.size(3) *
        output.size(4))) {
    return;
  }
  const int outputPointX = outputPointId % output.size(4);
  const int outputPointY = (outputPointId / output.size(4)) % output.size(3);
  const int outputPointZ = outputPointId / (output.size(3) * output.size(4));

  const int iStartX = imax(0, -pleft);
  const int iStartY = imax(0, -ptop);
  const int iStartZ = imax(0, -pfront);
  const int oStartX = imax(0, pleft);
  const int oStartY = imax(0, ptop);
  const int oStartZ = imax(0, pfront);

  const int inputPointX = imin(imax(pleft, outputPointX),
      input.size(4) + pleft - 1) - oStartX + iStartX;
  const int inputPointY = imin(imax(ptop, outputPointY),
      input.size(3) + ptop - 1) - oStartY + iStartY;
  const int inputPointZ = imin(imax(pfront, outputPointZ),
      input.size(2) + pfront - 1) - oStartZ + iStartZ;

  scalar_t valueToCopy =
      input[batch][plane][inputPointZ][inputPointY][inputPointX];
  output[batch][plane][outputPointZ][outputPointY][outputPointX] = valueToCopy;
}

template <typename scalar_t>
__global__ void replication_pad_backward_kernel(
    PackedTensorAccessor64<scalar_t, 5> gradInput,
    PackedTensorAccessor64<const scalar_t, 5> gradOutput,
    const int pfront,
    const int ptop,
    const int pleft,
    const int y_shift,
    const int z_shift) {
  const int outputPointId = threadIdx.x + blockIdx.x * blockDim.x;
  const int plane = blockIdx.y + y_shift;
  const int batch = blockIdx.z + z_shift;

  if (outputPointId >= (gradOutput.size(2) * gradOutput.size(3) *
        gradOutput.size(4))) {
    return;
  }
  const int outputPointX = outputPointId % gradOutput.size(4);
  const int outputPointY = (outputPointId / gradOutput.size(4)) %
      gradOutput.size(3);
  const int outputPointZ = outputPointId / (gradOutput.size(3) *
      gradOutput.size(4));

  const int iStartX = imax(0, -pleft);
  const int iStartY = imax(0, -ptop);
  const int iStartZ = imax(0, -pfront);
  const int oStartX = imax(0, pleft);
  const int oStartY = imax(0, ptop);
  const int oStartZ = imax(0, pfront);

  const int inputPointX = imin(imax(pleft, outputPointX),
      gradInput.size(4) + pleft - 1) - oStartX + iStartX;
  const int inputPointY = imin(imax(ptop, outputPointY),
      gradInput.size(3) + ptop - 1) - oStartY + iStartY;
  const int inputPointZ = imin(imax(pfront, outputPointZ),
      gradInput.size(2) + pfront - 1) - oStartZ + iStartZ;

  scalar_t valueToCopy =
      gradOutput[batch][plane][outputPointZ][outputPointY][outputPointX];
  gpuAtomicAddNoReturn(&gradInput[batch][plane][inputPointZ][inputPointY][inputPointX],
      valueToCopy);
}

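// Host-side launcher for the 2d backward kernel: checks that both tensors are
// addressable with 32-bit indexing, validates gradOutput's shape against the
// padded input shape, zeroes gradInput, and dispatches the kernel per dtype.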
void replication_pad2d_backward_out_cuda_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{

  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
      "input tensor must fit into 32-bit index math");
  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
      "output gradient tensor must fit into 32-bit index math");
  TORCH_CHECK(paddingSize.size() == 4, "padding Size is expected to be 4");

  const auto padL = paddingSize[0];
  const auto padR = paddingSize[1];
  const auto padT = paddingSize[2];
  const auto padB = paddingSize[3];
  int dimh = 1;
  int dimw = 2;

  int numInputDims = input.dim();
  if (numInputDims == 4) {
    dimh++;
    dimw++;
  }
  const auto iheight = input.size(dimh);
  const auto iwidth = input.size(dimw);
  const auto oheight = iheight + padT + padB;
  const auto owidth = iwidth + padL + padR;

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
  TORCH_CHECK(oheight == gradOutput.size(dimh),
      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
      gradOutput.size(dimh));

  gradInput.resize_as_(input);
  if (gradInput.numel() == 0) {
    return;
  }
  gradInput.zero_();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "replication_pad2d_backward_cuda", [&] {

    auto gradInput_ = gradInput;
    auto gradOutput_ = gradOutput;
    if (numInputDims == 3) {
      gradInput_ = gradInput.unsqueeze(0);
      gradOutput_ = gradOutput.unsqueeze(0);
    }
    auto devGradInput = gradInput_.packed_accessor64<scalar_t, 4>();
    auto devGradOutput = gradOutput_.packed_accessor64<const scalar_t, 4>();

    int64_t outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3);
    int64_t size1 = devGradOutput.size(1);
    int64_t size0 = devGradOutput.size(0);

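    // gridDim.y and gridDim.z are limited to 65535 on CUDA devices, so planes
    // (grid y) and batch entries (grid z) are processed in chunks of at most
    // 65535; y_shift/z_shift tell the kernel which chunk it is handling.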
    for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
      int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
      for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
        int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));

        dim3 gridSize(ceil_div(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);

        replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            devGradInput, devGradOutput, padT, padL, block_y, block_z);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  }
  );
}

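// Shared shape validation for the 3d backward path: the input must be a 4d or
// 5d (batched) tensor whose non-batch dimensions are all non-zero, both
// tensors must be 32-bit indexable, and gradOutput must match the padded
// output shape in planes, depth, height, and width.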
static inline void shapeAndGradOutputCheck3d(
    const Tensor& input,
    const Tensor& gradOutput,
    int pleft, int pright,
    int ptop, int pbottom,
    int pfront, int pback) {
  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
      "input tensor must fit into 32-bit index math");
  int numInputDims = input.dim();

  bool valid_dims = input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0;
  TORCH_CHECK(
      (numInputDims == 4 && valid_dims) ||
      (numInputDims == 5 && valid_dims && input.size(4) != 0),
      "Expected 4D or 5D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
      input.sizes());

  int planeDim = 0;
  int dimd = 1;
  int dimh = 2;
  int dimw = 3;
  if (numInputDims == 5) {
    planeDim++;
    dimd++;
    dimh++;
    dimw++;
  }

  int numPlanes = input.size(planeDim);
  int idepth = input.size(dimd);
  int iheight = input.size(dimh);
  int iwidth = input.size(dimw);
  int odepth = idepth + pfront + pback;
  int oheight = iheight + ptop + pbottom;
  int owidth = iwidth + pleft + pright;
  TORCH_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
      "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
      ") is too small."
      " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);

  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(gradOutput),
      "output gradient tensor must fit into 32-bit index math");

  TORCH_CHECK(numPlanes == gradOutput.size(planeDim),
      "gradOutput plane count unexpected. Expected: ", numPlanes, ", Got: ",
      gradOutput.size(planeDim));
  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
  TORCH_CHECK(oheight == gradOutput.size(dimh),
      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
      gradOutput.size(dimh));
  TORCH_CHECK(odepth == gradOutput.size(dimd),
      "gradOutput depth unexpected. Expected: ", odepth, ", Got: ",
      gradOutput.size(dimd));
}

void replication_pad3d_backward_out_cuda_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  TORCH_CHECK(paddingSize.size() == 6, "padding Size is expected to be 6");
  const auto pleft = paddingSize[0];
  const auto pright = paddingSize[1];
  const auto ptop = paddingSize[2];
  const auto pbottom = paddingSize[3];
  const auto pfront = paddingSize[4];
  const auto pback = paddingSize[5];
  shapeAndGradOutputCheck3d(input, gradOutput, pleft, pright, ptop,
      pbottom, pfront, pback);


  int numInputDims = input.dim();

  gradInput.resize_as_(input);
  if (gradInput.numel() == 0) {
    return;
  }
  gradInput.zero_();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "replication_pad3d_backward_cuda", [&] {
    auto gradInput_ = gradInput;
    auto gradOutput_ = gradOutput;
    if (numInputDims == 4) {
      gradInput_ = gradInput.unsqueeze(0);
      gradOutput_ = gradOutput.unsqueeze(0);
    }
    auto devGradInput = gradInput_.packed_accessor64<scalar_t, 5>();
    auto devGradOutput = gradOutput_.packed_accessor64<const scalar_t, 5>();

    const int64_t outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) * devGradOutput.size(4);
    const int64_t size1 = devGradOutput.size(1);
    const int64_t size0 = devGradOutput.size(0);

    for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
      int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
      for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
        int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));

        dim3 gridSize(ceil_div(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);

        replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            devGradInput, devGradOutput, pfront, ptop, pleft, block_y, block_z);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  }
  );
}
} // namespace

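// The TORCH_IMPL_FUNC definitions below provide the CUDA backends for the
// structured replication_pad{1,2,3}d operators; the structured-kernel meta
// function has already validated the padding arguments and allocated `output`
// with the padded shape before these implementations run.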
TORCH_IMPL_FUNC(replication_pad1d_out_cuda) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  TORCH_CHECK(input.numel() < std::numeric_limits<int64_t>::max(),
      "replication_pad1d only supports input tensors with less than 2^63 - 1 elements");

  int64_t padL = paddingSize[0];
  int64_t padR = paddingSize[1];
  constexpr int64_t planeDim = -2;
  constexpr int64_t dimw = -1;

  int numInputDims = input.ndimension();

  int64_t numPlanes = input.size(planeDim);
  int64_t inputW = input.size(dimw);
  int64_t outputW = output.size(dimw);

  if (input.numel() == 0) {
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
      input.scalar_type(), "replication_pad1d_cuda", [&] {
    at::Tensor input_ = input;
    at::Tensor output_ = output;
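    // Non-batched (2-D) input: view input/output as a batch of one so the
    // 3-D accessors and kernel below can be reused unchanged.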
    if (numInputDims == 2) {
      input_ = input.unsqueeze(0);
      output_ = output.unsqueeze(0);
    }

    auto devInput = input_.packed_accessor64<const scalar_t, 3>();
    auto devOutput = output_.packed_accessor64<scalar_t, 3>();

    int64_t outputPlaneSize = devOutput.size(2);
    int64_t size1 = devOutput.size(1);
    int64_t size0 = devOutput.size(0);

    for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
      int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
      for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
        int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));

        dim3 gridSize(ceil_div(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);

        replication_pad_forward_kernel1d <<<gridSize, blockSize, 0,
          at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, block_y, block_z);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  }
  );
}

TORCH_IMPL_FUNC(replication_pad1d_backward_out_cuda) (
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef paddingSize,
  const Tensor& gradInput
) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("replication_pad1d_backward_cuda");

  TORCH_CHECK(input.numel() < std::numeric_limits<int64_t>::max(),
      "replication_pad1d only supports input tensors with less than 2^63 - 1 elements");
  TORCH_CHECK(gradOutput.numel() < std::numeric_limits<int64_t>::max(),
      "replication_pad1d only supports output tensors with less than 2^63 - 1 elements");

  const int64_t padL = paddingSize[0];
  int64_t dimw = 1;

  int64_t numInputDims = input.ndimension();
  if (numInputDims == 3) {
    dimw++;
  }
  int64_t iwidth = input.size(dimw);

  if (gradInput.numel() == 0) {
    return;
  }
  gradInput.zero_();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "replication_pad1d_backward_cuda", [&] {

    auto gradInput_ = gradInput;
    auto gradOutput_ = gradOutput;
    if (numInputDims == 2) {
      gradInput_ = gradInput.unsqueeze(0);
      gradOutput_ = gradOutput.unsqueeze(0);
    }
    auto devGradInput = gradInput_.packed_accessor64<scalar_t, 3>();
    auto devGradOutput = gradOutput_.packed_accessor64<const scalar_t, 3>();

    int64_t outputPlaneSize = devGradOutput.size(2);
    int64_t size1 = devGradOutput.size(1);
    int64_t size0 = devGradOutput.size(0);

    for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
      int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
      for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
        int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));

        dim3 gridSize(ceil_div(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);

        replication_pad_backward_kernel <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            devGradInput, devGradOutput, padL, block_y, block_z);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  });
}

TORCH_IMPL_FUNC(replication_pad2d_out_cuda) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  TORCH_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
      "input tensor must fit into 32-bit index math");
  if (input.numel() == 0) {
    return;
  }
  const auto padL = paddingSize[0];
  // const auto padR = paddingSize[1]; // This padding is ignored here
  const auto padT = paddingSize[2];
  // const auto padB = paddingSize[3]; // This padding is ignored here
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
      input.scalar_type(), "replication_pad2d_cuda", [&] {
    at::Tensor input_ = input;
    at::Tensor output_ = output;
    if (input.dim() == 3) {
      input_ = input.unsqueeze(0);
      output_ = output.unsqueeze(0);
    }
    auto devInput = input_.packed_accessor64<const scalar_t, 4>();
    auto devOutput = output_.packed_accessor64<scalar_t, 4>();
    int64_t outputPlaneSize = devOutput.size(2) * devOutput.size(3);
    int64_t size1 = devOutput.size(1);
    int64_t size0 = devOutput.size(0);
    for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
      int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
      for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
        int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));
        dim3 gridSize(ceil_div(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
        replication_pad_forward_kernel2d <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            devInput, devOutput, padT, padL, block_y, block_z);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  }
  );
}

Tensor& replication_pad2d_backward_out_cuda(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize,
    Tensor& gradInput)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("replication_pad2d_backward_out_cuda");
  replication_pad2d_backward_out_cuda_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

Tensor replication_pad2d_backward_cuda(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("replication_pad2d_backward_cuda");
  auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  replication_pad2d_backward_out_cuda_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}


TORCH_IMPL_FUNC(replication_pad3d_out_cuda) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  const auto pleft = paddingSize[0];
  // const auto pright = paddingSize[1]; // Ignored here
  const auto ptop = paddingSize[2];
  // const auto pbottom = paddingSize[3]; // Ignored here
  const auto pfront = paddingSize[4];
  // const auto pback = paddingSize[5]; // Ignored here

  int planeDim = 0;
  int dimd = 1;
  int dimh = 2;
  int dimw = 3;

  int numInputDims = input.dim();

  if (numInputDims == 5) {
    planeDim++;
    dimd++;
    dimh++;
    dimw++;
  }

  const auto numPlanes = input.size(planeDim);
  const auto inputD = input.size(dimd);
  const auto inputH = input.size(dimh);
  const auto inputW = input.size(dimw);
  const auto outputD = output.size(dimd);
  const auto outputH = output.size(dimh);
  const auto outputW = output.size(dimw);

  if (input.numel() == 0) {
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
      input.scalar_type(), "replication_pad3d_cuda", [&] {
    at::Tensor input_ = input;
    at::Tensor output_ = output;
    if (numInputDims == 4) {
      input_ = input.unsqueeze(0);
      output_ = output.unsqueeze(0);
    }

    auto devInput = input_.packed_accessor64<const scalar_t, 5>();
    auto devOutput = output_.packed_accessor64<scalar_t, 5>();

    const int64_t outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4);
    const int64_t size1 = devOutput.size(1);
    const int64_t size0 = devOutput.size(0);

    for (int64_t block_y = 0; block_y < size1; block_y += 65535) {
      int64_t block_y_size = std::min(size1 - block_y, static_cast<int64_t>(65535));
      for (int64_t block_z = 0; block_z < size0; block_z += 65535) {
        int64_t block_z_size = std::min(size0 - block_z, static_cast<int64_t>(65535));

        dim3 gridSize(ceil_div(outputPlaneSize, static_cast<int64_t>(256)), block_y_size, block_z_size);
        dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);

        replication_pad_forward_kernel3d <<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            devInput, devOutput, pfront, ptop, pleft, block_y, block_z);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  }
  );
}

Tensor& replication_pad3d_backward_out_cuda(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize,
    Tensor& gradInput)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("replication_pad3d_backward_out_cuda");
  replication_pad3d_backward_out_cuda_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

Tensor replication_pad3d_backward_cuda(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("replication_pad3d_backward_cuda");
  auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  replication_pad3d_backward_out_cuda_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

} // at::native