#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/Padding.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

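// PaddingParams gathers everything the kernels below share: the number of
// padded dims (1d/2d/3d), batch and channel sizes, input/output spatial
// shapes, per-dimension padding, and per-dimension offsets that translate
// an output index back into input coordinates.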
struct PaddingParams {
  int ndim;
  int64_t nbatch;
  int64_t channels;

  // use vectorized logic on width when the output index is in [pad, input_width + pad);
  // applies only to the Channels First format, and only when pad_l and pad_r are both non-negative.
  bool is_padding_positive_width;

  c10::SmallVector<int64_t, 3u> ishape;
  c10::SmallVector<int64_t, 3u> oshape;
  c10::SmallVector<int64_t, 3u> pads;
  c10::SmallVector<int64_t, 3u> offsets;

  PaddingParams(const Tensor& input, const Tensor& output, IntArrayRef padding) {
    ndim = padding.size() / 2;

    bool is_batch = input.dim() == ndim + 2;
    nbatch = is_batch ? input.size(0) : 1;
    channels = is_batch ? input.size(1) : input.size(0);

    is_padding_positive_width = padding[0] >= 0 && padding[1] >= 0;

    // handle sizes with batch-mode and non-batch-mode
    int ind = is_batch ? 2 : 1;
    for (const auto d : c10::irange(ndim)) {
      ishape.emplace_back(input.size(ind + d));
      oshape.emplace_back(output.size(ind + d));
    }

    // padding is organized in order of:
    //   { left, right, top, bottom, front, back }
    //
    // re-organize into order of:
    //   { depth, height, width }
    //
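    // e.g. for ndim == 2, padding = {left, right, top, bottom} becomes
    // pads = { top, left }, i.e. { height, width }.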
    if (ndim == 1) {
      pads.emplace_back(padding[0]);
    } else if (ndim == 2) {
      pads.emplace_back(padding[2]);
      pads.emplace_back(padding[0]);
    } else {
      pads.emplace_back(padding[4]);
      pads.emplace_back(padding[2]);
      pads.emplace_back(padding[0]);
    }
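    // offsets map an "interior" output index back to the matching input index;
    // they also cover negative padding (cropping), where the copied region
    // starts at -pad in the input instead of at pad in the output.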
    for (const auto d : c10::irange(ndim)) {
      int64_t pad = pads[d];
      auto i_start = std::max(int64_t(0), -pad);
      auto o_start = std::max(int64_t(0), pad);
      offsets.emplace_back(i_start - o_start);
    }
  }
};

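// maps an output index j along one dimension to the input index it reads from,
// reflecting into the pad regions without repeating the border element.
// e.g. for size = 5, pad = 2 (offset = -2):
//   output j : 0 1 2 3 4 5 6 7 8
//   input  i : 2 1 0 1 2 3 4 3 2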
struct ReflectionPad {
  static int64_t index(int64_t j, int64_t size, int64_t pad, int64_t offset) {
    int64_t i;
    if (j < pad) {
      i = pad * 2 - j;
    } else if (j >= pad && j < size + pad) {
      i = j;
    } else {
      i = (size + pad - 1) * 2 - j;
    }
    return i + offset;
  }
};

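// maps an output index j to the input index it reads from, clamping the pad
// regions to the edge value.
// e.g. for size = 5, pad = 2 (offset = -2):
//   output j : 0 1 2 3 4 5 6 7 8
//   input  i : 0 0 0 1 2 3 4 4 4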
struct ReplicationPad {
  static int64_t index(int64_t j, int64_t size, int64_t pad, int64_t offset) {
    int64_t i;
    if (j < pad) {
      i = pad;
    } else if (j >= pad && j < size + pad) {
      i = j;
    } else {
      i = size + pad - 1;
    }
    return i + offset;
  }
};

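// copy `size` contiguous elements from `in` to `out`:
// a vectorized main loop followed by a scalar tail for the remainder.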
template <typename scalar_t>
static inline void copy_stub(scalar_t* out, const scalar_t* in, int64_t size) {
  using Vec = Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec in_vec = Vec::loadu(in + d);
    in_vec.store(out + d);
  }
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
  for (; d < size; d++) {
    out[d] = in[d];
  }
}

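// accumulate `size` contiguous elements of `grad_out` into `grad_in`
// (used by the backward kernels), again vectorized with a scalar tail.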
template <typename scalar_t>
static inline void add_stub(scalar_t* grad_in, const scalar_t* grad_out, int64_t size) {
  using Vec = Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec grad_vec = Vec::loadu(grad_in + d) + Vec::loadu(grad_out + d);
    grad_vec.store(grad_in + d);
  }
#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
# pragma unroll
#endif
  for (; d < size; d++) {
    grad_in[d] += grad_out[d];
  }
}

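// Channels First (contiguous) forward kernel:
//   1d: parallel on N*C*W, one indexed load per output element;
//   2d: parallel on N*C*H, each task fills one output row via `loop`;
//   3d: parallel on N*C*D*H, likewise one output row per task.
// `loop` copies the interior of a row with copy_stub and only gathers
// element-by-element in the pad regions when the width padding is
// non-negative; otherwise it falls back to a fully indexed per-element copy.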
template <typename scalar_t, typename PaddingType>
void cpu_padding(
    const Tensor& output_,
    const Tensor& input_,
    PaddingParams& p) {

  auto input = input_.contiguous();
  auto output = output_.contiguous();

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  // fold nbatch and channels into single dimension for channels first.
  int64_t channels = p.nbatch * p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  // do a vectorized copy where the output overlaps the input on W;
  // only applies when the width padding is non-negative.
  auto loop = [=](scalar_t* out, const scalar_t* in, bool positive_padding) {
    if (positive_padding) {
      for (const auto ow : c10::irange(pad_w)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        out[ow] = in[iw];
      }
      copy_stub(out + pad_w, in, input_width);
      for (const auto ow : c10::irange(input_width + pad_w, output_width)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        out[ow] = in[iw];
      }
    } else {
      for (const auto ow : c10::irange(output_width)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        out[ow] = in[iw];
      }
    }
  };

179
180 if (ndim == 1) {
181 // parallel on N,C,W
182 at::parallel_for(0, channels * output_width, 1, [&](int64_t begin, int64_t end) {
183 int64_t c{0}, ow{0};
184 data_index_init(begin, c, channels, ow, output_width);
185
186 for (const auto i : c10::irange(begin, end)) {
187 int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
188 output_data[i] = input_data[c * input_width + iw];
189 data_index_step(c, channels, ow, output_width);
190 }
191 });
192 } else if (ndim == 2) {
193 // parallel on N,C,H, vectorize on W
194 at::parallel_for(0, channels * output_height, 1, [&](int64_t begin, int64_t end) {
195 int64_t c{0}, oh{0};
196 data_index_init(begin, c, channels, oh, output_height);
197
198 for (const auto i : c10::irange(begin, end)) {
199 int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
200 scalar_t* output_ptr = output_data + i * output_width;
201 const scalar_t* input_ptr = input_data + c * input_height * input_width + ih * input_width;
202
203 loop(output_ptr, input_ptr, p.is_padding_positive_width);
204 data_index_step(c, channels, oh, output_height);
205 }
206 });
207 } else if (ndim == 3) {
208 // parallel on N,C,D,H, vectorize on W
209 at::parallel_for(0, channels * output_depth * output_height, 1, [&](int64_t begin, int64_t end) {
210 int64_t c{0}, od{0}, oh{0};
211 data_index_init(begin, c, channels, od, output_depth, oh, output_height);
212
213 for (const auto i : c10::irange(begin, end)) {
214 int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
215 int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
216 scalar_t* output_ptr = output_data + i * output_width;
217 const scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width +
218 id * input_height * input_width + ih * input_width;
219
220 loop(output_ptr, input_ptr, p.is_padding_positive_width);
221 data_index_step(c, channels, od, output_depth, oh, output_height);
222 }
223 });
224 } else {
225 TORCH_INTERNAL_ASSERT(false, "expect input dim to be 1d, 2d or 3d.");
226 }
227
228 if (!output_.is_contiguous()) {
229 output_.copy_(output);
230 }
231 }
232
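// Channels Last forward kernel: parallel over output spatial positions
// (N,H,W for 2d, N,D,H,W for 3d); each output position copies its whole
// channel vector from the mapped input position with copy_stub, so the
// vectorization is on C.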
template <typename scalar_t, typename PaddingType>
void cpu_padding_channels_last(
    const Tensor& output_,
    const Tensor& input_,
    PaddingParams& p) {

  auto memory_format = p.ndim == 2
      ? at::MemoryFormat::ChannelsLast
      : at::MemoryFormat::ChannelsLast3d;

  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = p.nbatch;
  int64_t channels = p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  if (ndim == 2) {
    // parallel on N,H,W, vectorize on C
    at::parallel_for(0, nbatch * output_height * output_width, 1, [&](int64_t begin, int64_t end) {
      int64_t n{0}, oh{0}, ow{0};
      data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

      for (const auto i : c10::irange(begin, end)) {
        int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);

        scalar_t* output_ptr = output_data + i * channels;
        const scalar_t* input_ptr = input_data + (n * input_height * input_width + ih * input_width + iw) * channels;
        copy_stub(output_ptr, input_ptr, channels);

        data_index_step(n, nbatch, oh, output_height, ow, output_width);
      }
    });
  } else if (ndim == 3) {
    // parallel on N,D,H,W, vectorize on C
    at::parallel_for(0, nbatch * output_depth * output_height * output_width, 1, [&](int64_t begin, int64_t end) {
      int64_t n{0}, od{0}, oh{0}, ow{0};
      data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

      for (const auto i : c10::irange(begin, end)) {
        int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
        int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);

        scalar_t* output_ptr = output_data + i * channels;
        const scalar_t* input_ptr = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;
        copy_stub(output_ptr, input_ptr, channels);

        data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 2d or 3d.");
  }

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}

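// Channels First backward kernel: grad_output is accumulated back into
// grad_input. Several output positions can map to the same input position
// (the pad regions), so only N*C is parallelized and the spatial dims are
// walked sequentially within each channel, which keeps the += accumulation
// free of races.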
template <typename scalar_t, typename PaddingType>
void cpu_padding_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    PaddingParams& p) {

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.data_ptr<scalar_t>();

  // fold nbatch and channels into single dimension for channels first.
  int64_t channels = p.nbatch * p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  if (ndim == 1) {
    // parallel on N,C, sequential on W
    at::parallel_for(0, channels, 1, [&](int64_t begin, int64_t end) {
      for (const auto c : c10::irange(begin, end)) {
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
          grad_input_data[c * input_width + iw] += grad_output_data[c * output_width + ow];
        }
      }
    });
  } else if (ndim == 2) {
    // parallel on N,C, sequential on H,W
    at::parallel_for(0, channels, 1, [&](int64_t begin, int64_t end) {
      for (const auto c : c10::irange(begin, end)) {
        const scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width;
        scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width;

        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
            grad_input_ptr[ih * input_width + iw] += grad_output_ptr[oh * output_width + ow];
          }
        }
      }
    });
  } else if (p.ndim == 3) {
    // parallel on N,C, sequential on D,H,W
    at::parallel_for(0, channels, 1, [&](int64_t begin, int64_t end) {
      for (const auto c : c10::irange(begin, end)) {
        const scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
        scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;

        for (const auto od : c10::irange(output_depth)) {
          int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
          for (const auto oh : c10::irange(output_height)) {
            int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
            for (const auto ow : c10::irange(output_width)) {
              int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
              grad_input_ptr[id * input_height * input_width + ih * input_width + iw] +=
                  grad_output_ptr[od * output_height * output_width + oh * output_width + ow];
            }
          }
        }
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 1d, 2d, or 3d.");
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

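// Channels Last backward kernel: parallel on N only, sequential over the
// spatial dims within each batch, vectorized on C via add_stub; as above,
// the sequential spatial walk avoids concurrent accumulation into the same
// grad_input position.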
template <typename scalar_t, typename PaddingType>
void cpu_padding_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    PaddingParams& p) {

  auto memory_format = p.ndim == 2
      ? at::MemoryFormat::ChannelsLast
      : at::MemoryFormat::ChannelsLast3d;

  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);

  auto grad_input_data = grad_input.data_ptr<scalar_t>();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();

  int64_t nbatch = p.nbatch;
  int64_t channels = p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  if (ndim == 2) {
    // parallel on N, sequential on H,W, vectorize on C
    at::parallel_for(0, nbatch, 1, [&](int64_t begin, int64_t end) {
      for (const auto n : c10::irange(begin, end)) {
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
            scalar_t* grad_input_ptr = grad_input_data +
                (n * input_height * input_width + ih * input_width + iw) * channels;
            const scalar_t* grad_output_ptr = grad_output_data +
                (n * output_height * output_width + oh * output_width + ow) * channels;
            add_stub(grad_input_ptr, grad_output_ptr, channels);
          }
        }
      }
    });
  } else if (ndim == 3) {
    // parallel on N, sequential on D,H,W, vectorize on C
    at::parallel_for(0, nbatch, 1, [&](int64_t begin, int64_t end) {
      for (const auto n : c10::irange(begin, end)) {
        for (const auto od : c10::irange(output_depth)) {
          int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
          for (const auto oh : c10::irange(output_height)) {
            int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
            for (const auto ow : c10::irange(output_width)) {
              int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
              scalar_t* grad_input_ptr = grad_input_data +
                  (n * input_depth * input_height * input_width + id * input_height * input_width +
                   ih * input_width + iw) * channels;
              const scalar_t* grad_output_ptr = grad_output_data +
                  (n * output_depth * output_height * output_width + od * output_height * output_width +
                   oh * output_width + ow) * channels;
              add_stub(grad_input_ptr, grad_output_ptr, channels);
            }
          }
        }
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 2d or 3d.");
  }

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

// a non-batch mode 4d input is treated as Contiguous in CDHW format
at::MemoryFormat padding_memory_format_3d(const Tensor& input) {
  return input.dim() == 4 ? at::MemoryFormat::Contiguous : input.suggest_memory_format();
}

// reflection padding
void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  if (input.is_quantized()) {
    AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qreflection_pad1d", [&] {
      cpu_padding<scalar_t, ReflectionPad>(output, input, param);
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
        "reflection_pad1d", [&] {
      cpu_padding<scalar_t, ReflectionPad>(output, input, param);
    });
  }
}

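// note: the backward impls build PaddingParams as {grad_input, grad_output, padding},
// so ishape/oshape keep their forward meaning (input shape / output shape).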
void reflection_pad1d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
      "reflection_pad1d_backward", [&] {
    cpu_padding_backward<scalar_t, ReflectionPad>(grad_input, grad_output, param);
  });
}

void reflection_pad2d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  if (input.is_quantized()) {
    // the original quantized impl doesn't have channels-last support;
    // if that support is intended, add a switch here.
    AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qreflection_pad2d", [&] {
      cpu_padding<scalar_t, ReflectionPad>(output, input, param);
    });
  } else {
    switch (input.suggest_memory_format()) {
      case at::MemoryFormat::Contiguous: {
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
            "reflection_pad2d", [&] {
          cpu_padding<scalar_t, ReflectionPad>(output, input, param);
        });
        break;
      }
      case at::MemoryFormat::ChannelsLast: {
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
            "reflection_pad2d_channels_last", [&]{
          cpu_padding_channels_last<scalar_t, ReflectionPad>(output, input, param);
        });
        break;
      }
      default:
        TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
    }
  }
}

void reflection_pad2d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "reflection_pad2d_backward", [&] {
        cpu_padding_backward<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "reflection_pad2d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void reflection_pad3d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  switch (padding_memory_format_3d(input)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(),
          "reflection_pad3d", [&] {
        cpu_padding<scalar_t, ReflectionPad>(output, input, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(),
          "reflection_pad3d_channels_last", [&]{
        cpu_padding_channels_last<scalar_t, ReflectionPad>(output, input, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

void reflection_pad3d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (padding_memory_format_3d(grad_output)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
          "reflection_pad3d_backward", [&] {
        cpu_padding_backward<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
          "reflection_pad3d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

// replication padding
void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
      "replication_pad1d", [&] {
    cpu_padding<scalar_t, ReplicationPad>(output, input, param);
  });
}

void replication_pad1d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
      "replication_pad1d_backward", [&] {
    cpu_padding_backward<scalar_t, ReplicationPad>(grad_input, grad_output, param);
  });
}

void replication_pad2d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad2d", [&] {
        cpu_padding<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad2d_channels_last", [&]{
        cpu_padding_channels_last<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void replication_pad2d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad2d_backward", [&] {
        cpu_padding_backward<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad2d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void replication_pad3d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  switch (padding_memory_format_3d(input)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad3d", [&] {
        cpu_padding<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad3d_channels_last", [&]{
        cpu_padding_channels_last<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

void replication_pad3d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (padding_memory_format_3d(grad_output)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad3d_backward", [&] {
        cpu_padding_backward<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad3d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

// reflection padding
REGISTER_DISPATCH(reflection_pad1d_kernel, &reflection_pad1d_kernel_impl);
REGISTER_DISPATCH(reflection_pad1d_backward_kernel, &reflection_pad1d_backward_kernel_impl);
REGISTER_DISPATCH(reflection_pad2d_kernel, &reflection_pad2d_kernel_impl);
REGISTER_DISPATCH(reflection_pad2d_backward_kernel, &reflection_pad2d_backward_kernel_impl);
REGISTER_DISPATCH(reflection_pad3d_kernel, &reflection_pad3d_kernel_impl);
REGISTER_DISPATCH(reflection_pad3d_backward_kernel, &reflection_pad3d_backward_kernel_impl);

// replication padding
REGISTER_DISPATCH(replication_pad1d_kernel, &replication_pad1d_kernel_impl);
REGISTER_DISPATCH(replication_pad1d_backward_kernel, &replication_pad1d_backward_kernel_impl);
REGISTER_DISPATCH(replication_pad2d_kernel, &replication_pad2d_kernel_impl);
REGISTER_DISPATCH(replication_pad2d_backward_kernel, &replication_pad2d_backward_kernel_impl);
REGISTER_DISPATCH(replication_pad3d_kernel, &replication_pad3d_kernel_impl);
REGISTER_DISPATCH(replication_pad3d_backward_kernel, &replication_pad3d_backward_kernel_impl);

} // at::native