xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/PaddingKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/Padding.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

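// CPU kernels for reflection and replication padding (1d/2d/3d, forward and
// backward), registered at the bottom of this file via REGISTER_DISPATCH.
// PaddingParams below normalizes the input/output shapes and the raw padding
// list into a {depth, height, width} view shared by all kernels.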
struct PaddingParams {
  int ndim;
  int64_t nbatch;
  int64_t channels;

  // use vectorized logic on width when the output index is in [pad, input_width + pad);
  // this applies only to the Channels First format, and only when pad_l and pad_r are both non-negative.
  bool is_padding_positive_width;

  c10::SmallVector<int64_t, 3u> ishape;
  c10::SmallVector<int64_t, 3u> oshape;
  c10::SmallVector<int64_t, 3u> pads;
  c10::SmallVector<int64_t, 3u> offsets;

  PaddingParams(const Tensor& input, const Tensor& output, IntArrayRef padding) {
    ndim = padding.size() / 2;

    bool is_batch = input.dim() == ndim + 2;
    nbatch = is_batch ? input.size(0) : 1;
    channels = is_batch ? input.size(1) : input.size(0);

    is_padding_positive_width = padding[0] >= 0 && padding[1] >= 0;

    // handle sizes with batch-mode and non-batch-mode
    int ind = is_batch ? 2 : 1;
    for (const auto d : c10::irange(ndim)) {
      ishape.emplace_back(input.size(ind + d));
      oshape.emplace_back(output.size(ind + d));
    }

    // padding is organized in order of:
    //   { left, right, top, bottom, front, back }
    //
    // re-organize into order of:
    //   { depth, height, width }
    //
    if (ndim == 1) {
      pads.emplace_back(padding[0]);
    } else if (ndim == 2) {
      pads.emplace_back(padding[2]);
      pads.emplace_back(padding[0]);
    } else {
      pads.emplace_back(padding[4]);
      pads.emplace_back(padding[2]);
      pads.emplace_back(padding[0]);
    }
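    // For each spatial dim, record the shift that maps an output index back to
    // an input index: with negative padding the copy starts at -pad inside the
    // input, with non-negative padding it starts at pad inside the output, so
    // offset = i_start - o_start lets the index() helpers below simply add it.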
    for (const auto d : c10::irange(ndim)) {
      int64_t pad = pads[d];
      auto i_start = std::max(int64_t(0), -pad);
      auto o_start = std::max(int64_t(0), pad);
      offsets.emplace_back(i_start - o_start);
    }
  }
};

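// Maps output index j along one spatial dim to the corresponding input index
// for reflection padding: the border mirrors the input without repeating the
// edge element. E.g. with size = 4, pad = 2 (offset = -2), output indices 0..7
// map to input indices {2, 1, 0, 1, 2, 3, 2, 1}.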
struct ReflectionPad {
  static int64_t index(int64_t j, int64_t size, int64_t pad, int64_t offset) {
    int64_t i;
    if (j < pad) {
      i = pad * 2 - j;
    } else if (j >= pad && j < size + pad) {
      i = j;
    } else {
      i = (size + pad - 1) * 2 - j;
    }
    return i + offset;
  }
};

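// Maps output index j to the input index for replication padding: the border
// repeats the nearest edge element. E.g. with size = 4, pad = 2 (offset = -2),
// output indices 0..7 map to input indices {0, 0, 0, 1, 2, 3, 3, 3}.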
struct ReplicationPad {
  static int64_t index(int64_t j, int64_t size, int64_t pad, int64_t offset) {
    int64_t i;
    if (j < pad) {
      i = pad;
    } else if (j >= pad && j < size + pad) {
      i = j;
    } else {
      i = size + pad - 1;
    }
    return i + offset;
  }
};

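// copy `size` contiguous elements from `in` to `out`: a vectorized main loop
// over full Vectorized<scalar_t> lanes plus a scalar tail for the remainder.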
template <typename scalar_t>
static inline void copy_stub(scalar_t* out, const scalar_t* in, int64_t size) {
  using Vec = Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec in_vec = Vec::loadu(in + d);
    in_vec.store(out + d);
  }
  #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
  # pragma unroll
  #endif
  for (; d < size; d++) {
    out[d] = in[d];
  }
}

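// accumulate `size` contiguous elements of `grad_out` into `grad_in`
// (grad_in += grad_out), vectorized with a scalar tail, mirroring copy_stub.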
template <typename scalar_t>
static inline void add_stub(scalar_t* grad_in, const scalar_t* grad_out, int64_t size) {
  using Vec = Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec grad_vec = Vec::loadu(grad_in + d) + Vec::loadu(grad_out + d);
    grad_vec.store(grad_in + d);
  }
  #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
  # pragma unroll
  #endif
  for (; d < size; d++) {
    grad_in[d] += grad_out[d];
  }
}

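// Forward kernel for the channels-first (contiguous) layout. Parallelism is
// over the flattened outer dimensions (N*C plus any leading spatial dims).
// For 2d/3d inputs each iteration fills one output row along W: when both
// width pads are non-negative the central [pad_w, input_width + pad_w) span is
// a single vectorized copy of the input row and only the borders go through
// PaddingType::index(); the 1d case is handled element-wise.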
template <typename scalar_t, typename PaddingType>
void cpu_padding(
    const Tensor& output_,
    const Tensor& input_,
    PaddingParams& p) {

  auto input = input_.contiguous();
  auto output = output_.contiguous();

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  // fold nbatch and channels into single dimension for channels first.
  int64_t channels = p.nbatch * p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  // do a vectorized copy where the output overlaps the input on W;
  // this only applies to non-negative padding.
  auto loop = [=](scalar_t* out, const scalar_t* in, bool positive_padding) {
    if (positive_padding) {
      for (const auto ow : c10::irange(pad_w)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        out[ow] = in[iw];
      }
      copy_stub(out + pad_w, in, input_width);
      for (const auto ow : c10::irange(input_width + pad_w, output_width)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        out[ow] = in[iw];
      }
    } else {
      for (const auto ow : c10::irange(output_width)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        out[ow] = in[iw];
      }
    }
  };

  if (ndim == 1) {
    // parallel on N,C,W
    at::parallel_for(0, channels * output_width, 1, [&](int64_t begin, int64_t end) {
      int64_t c{0}, ow{0};
      data_index_init(begin, c, channels, ow, output_width);

      for (const auto i : c10::irange(begin, end)) {
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
        output_data[i] = input_data[c * input_width + iw];
        data_index_step(c, channels, ow, output_width);
      }
    });
  } else if (ndim == 2) {
    // parallel on N,C,H, vectorize on W
    at::parallel_for(0, channels * output_height, 1, [&](int64_t begin, int64_t end) {
      int64_t c{0}, oh{0};
      data_index_init(begin, c, channels, oh, output_height);

      for (const auto i : c10::irange(begin, end)) {
        int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
        scalar_t* output_ptr = output_data + i * output_width;
        const scalar_t* input_ptr = input_data + c * input_height * input_width + ih * input_width;

        loop(output_ptr, input_ptr, p.is_padding_positive_width);
        data_index_step(c, channels, oh, output_height);
      }
    });
  } else if (ndim == 3) {
    // parallel on N,C,D,H, vectorize on W
    at::parallel_for(0, channels * output_depth * output_height, 1, [&](int64_t begin, int64_t end) {
      int64_t c{0}, od{0}, oh{0};
      data_index_init(begin, c, channels, od, output_depth, oh, output_height);

      for (const auto i : c10::irange(begin, end)) {
        int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
        int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
        scalar_t* output_ptr = output_data + i * output_width;
        const scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width;

        loop(output_ptr, input_ptr, p.is_padding_positive_width);
        data_index_step(c, channels, od, output_depth, oh, output_height);
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 1d, 2d or 3d.");
  }

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
}

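// Forward kernel for the channels-last (NHWC / NDHWC) layout. Parallelism is
// over N and the output spatial positions; for each output location the whole
// channel vector is copied with copy_stub() from the mapped input location.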
template <typename scalar_t, typename PaddingType>
void cpu_padding_channels_last(
    const Tensor& output_,
    const Tensor& input_,
    PaddingParams& p) {

  auto memory_format = p.ndim == 2
      ? at::MemoryFormat::ChannelsLast
      : at::MemoryFormat::ChannelsLast3d;

  auto input = input_.contiguous(memory_format);
  auto output = output_.contiguous(memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = p.nbatch;
  int64_t channels = p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  if (ndim == 2) {
    // parallel on N,H,W, vectorize on C
    at::parallel_for(0, nbatch * output_height * output_width, 1, [&](int64_t begin, int64_t end) {
      int64_t n{0}, oh{0}, ow{0};
      data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

      for (const auto i : c10::irange(begin, end)) {
        int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);

        scalar_t* output_ptr = output_data + i * channels;
        const scalar_t* input_ptr = input_data + (n * input_height * input_width + ih * input_width + iw) * channels;
        copy_stub(output_ptr, input_ptr, channels);

        data_index_step(n, nbatch, oh, output_height, ow, output_width);
      }
    });
  } else if (ndim == 3) {
    // parallel on N,D,H,W, vectorize on C
    at::parallel_for(0, nbatch * output_depth * output_height * output_width, 1, [&](int64_t begin, int64_t end) {
      int64_t n{0}, od{0}, oh{0}, ow{0};
      data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);

      for (const auto i : c10::irange(begin, end)) {
        int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
        int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
        int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);

        scalar_t* output_ptr = output_data + i * channels;
        const scalar_t* input_ptr = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;
        copy_stub(output_ptr, input_ptr, channels);

        data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 2d or 3d.");
  }

  if (!output_.is_contiguous(memory_format)) {
    output_.copy_(output);
  }
}

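// Backward kernel for the channels-first layout. Several output positions in
// the padded border can map to the same input position, so gradients are
// accumulated (+=) into grad_input and parallelism stays on N*C only; the
// spatial loops within one (n, c) slice run sequentially to keep the
// accumulation race-free.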
template <typename scalar_t, typename PaddingType>
void cpu_padding_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    PaddingParams& p) {

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.data_ptr<scalar_t>();

  // fold nbatch and channels into single dimension for channels first.
  int64_t channels = p.nbatch * p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  if (ndim == 1) {
    // parallel on N,C, sequential on W
    at::parallel_for(0, channels, 1, [&](int64_t begin, int64_t end) {
      for (const auto c : c10::irange(begin, end)) {
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
          grad_input_data[c * input_width + iw] += grad_output_data[c * output_width + ow];
        }
      }
    });
  } else if (ndim == 2) {
    // parallel on N,C, sequential on H,W
    at::parallel_for(0, channels, 1, [&](int64_t begin, int64_t end) {
      for (const auto c : c10::irange(begin, end)) {
        const scalar_t* grad_output_ptr = grad_output_data + c * output_height * output_width;
        scalar_t* grad_input_ptr = grad_input_data + c * input_height * input_width;

        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
            grad_input_ptr[ih * input_width + iw] += grad_output_ptr[oh * output_width + ow];
          }
        }
      }
    });
  } else if (p.ndim == 3) {
    // parallel on N,C, sequential on D,H,W
    at::parallel_for(0, channels, 1, [&](int64_t begin, int64_t end) {
      for (const auto c : c10::irange(begin, end)) {
        const scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
        scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;

        for (const auto od : c10::irange(output_depth)) {
          int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
          for (const auto oh : c10::irange(output_height)) {
            int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
            for (const auto ow : c10::irange(output_width)) {
              int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
              grad_input_ptr[id * input_height * input_width + ih * input_width + iw] +=
                  grad_output_ptr[od * output_height * output_width + oh * output_width + ow];
            }
          }
        }
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 1d, 2d, or 3d.");
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

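// Backward kernel for the channels-last layout. As above, border contributions
// overlap, so gradients are accumulated with add_stub() over the channel
// vector and parallelism stays on the batch dimension only.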
template <typename scalar_t, typename PaddingType>
void cpu_padding_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    PaddingParams& p) {

  auto memory_format = p.ndim == 2
      ? at::MemoryFormat::ChannelsLast
      : at::MemoryFormat::ChannelsLast3d;

  auto grad_input = grad_input_.contiguous(memory_format);
  auto grad_output = grad_output_.contiguous(memory_format);

  auto grad_input_data = grad_input.data_ptr<scalar_t>();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();

  int64_t nbatch = p.nbatch;
  int64_t channels = p.channels;

  int ndim = p.ndim;
  int64_t input_depth = ndim == 3 ? p.ishape[ndim - 3] : 1;
  int64_t input_height = ndim >= 2 ? p.ishape[ndim - 2] : 1;
  int64_t input_width = p.ishape[ndim - 1];
  int64_t output_depth = ndim == 3 ? p.oshape[ndim - 3] : 1;
  int64_t output_height = ndim >= 2 ? p.oshape[ndim - 2] : 1;
  int64_t output_width = p.oshape[ndim - 1];
  int64_t pad_d = ndim == 3 ? p.pads[ndim - 3] : 0;
  int64_t pad_h = ndim >= 2 ? p.pads[ndim - 2] : 0;
  int64_t pad_w = p.pads[ndim - 1];
  int64_t offset_d = ndim == 3 ? p.offsets[ndim - 3] : 0;
  int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0;
  int64_t offset_w = p.offsets[ndim - 1];

  if (ndim == 2) {
    // parallel on N, sequential on H,W, vectorize on C
    at::parallel_for(0, nbatch, 1, [&](int64_t begin, int64_t end) {
      for (const auto n : c10::irange(begin, end)) {
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
            scalar_t* grad_input_ptr = grad_input_data +
                (n * input_height * input_width + ih * input_width + iw) * channels;
            const scalar_t* grad_output_ptr = grad_output_data +
                (n * output_height * output_width + oh * output_width + ow) * channels;
            add_stub(grad_input_ptr, grad_output_ptr, channels);
          }
        }
      }
    });
  } else if (ndim == 3) {
    // parallel on N, sequential on D,H,W, vectorize on C
    at::parallel_for(0, nbatch, 1, [&](int64_t begin, int64_t end) {
      for (const auto n : c10::irange(begin, end)) {
        for (const auto od : c10::irange(output_depth)) {
          int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
          for (const auto oh : c10::irange(output_height)) {
            int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
            for (const auto ow : c10::irange(output_width)) {
              int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
              scalar_t* grad_input_ptr = grad_input_data +
                  (n * input_depth * input_height * input_width + id * input_height * input_width +
                   ih * input_width + iw) * channels;
              const scalar_t* grad_output_ptr = grad_output_data +
                  (n * output_depth * output_height * output_width + od * output_height * output_width +
                   oh * output_width + ow) * channels;
              add_stub(grad_input_ptr, grad_output_ptr, channels);
            }
          }
        }
      }
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "expect input dim to be 2d or 3d.");
  }

  if (!grad_input_.is_contiguous(memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

// a non-batch-mode 4d input is treated as Contiguous in CDHW format
at::MemoryFormat padding_memory_format_3d(const Tensor& input) {
  return input.dim() == 4 ? at::MemoryFormat::Contiguous : input.suggest_memory_format();
}

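// The *_kernel_impl wrappers below build a PaddingParams from the tensors and
// the raw padding list, then dispatch on dtype and memory format to the
// channels-first or channels-last kernels above.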
// reflection padding
void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  if (input.is_quantized()) {
    AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qreflection_pad1d", [&] {
      cpu_padding<scalar_t, ReflectionPad>(output, input, param);
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
        "reflection_pad1d", [&] {
      cpu_padding<scalar_t, ReflectionPad>(output, input, param);
    });
  }
}

void reflection_pad1d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
      "reflection_pad1d_backward", [&] {
    cpu_padding_backward<scalar_t, ReflectionPad>(grad_input, grad_output, param);
  });
}

void reflection_pad2d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  if (input.is_quantized()) {
    // the original quantized impl doesn't have channels-last support;
    // if that is intended, add a memory-format switch here.
    AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qreflection_pad2d", [&] {
      cpu_padding<scalar_t, ReflectionPad>(output, input, param);
    });
  } else {
    switch (input.suggest_memory_format()) {
      case at::MemoryFormat::Contiguous: {
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
            "reflection_pad2d", [&] {
          cpu_padding<scalar_t, ReflectionPad>(output, input, param);
        });
        break;
      }
      case at::MemoryFormat::ChannelsLast: {
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
            "reflection_pad2d_channels_last", [&]{
          cpu_padding_channels_last<scalar_t, ReflectionPad>(output, input, param);
        });
        break;
      }
      default:
        TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
    }
  }
}

void reflection_pad2d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "reflection_pad2d_backward", [&] {
        cpu_padding_backward<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "reflection_pad2d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void reflection_pad3d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  switch (padding_memory_format_3d(input)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(),
          "reflection_pad3d", [&] {
        cpu_padding<scalar_t, ReflectionPad>(output, input, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(),
          "reflection_pad3d_channels_last", [&]{
        cpu_padding_channels_last<scalar_t, ReflectionPad>(output, input, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

void reflection_pad3d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (padding_memory_format_3d(grad_output)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
          "reflection_pad3d_backward", [&] {
        cpu_padding_backward<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(),
          "reflection_pad3d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReflectionPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

// replication padding
void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
      "replication_pad1d", [&] {
    cpu_padding<scalar_t, ReplicationPad>(output, input, param);
  });
}

void replication_pad1d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
      "replication_pad1d_backward", [&] {
    cpu_padding_backward<scalar_t, ReplicationPad>(grad_input, grad_output, param);
  });
}

void replication_pad2d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  switch (input.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad2d", [&] {
        cpu_padding<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad2d_channels_last", [&]{
        cpu_padding_channels_last<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void replication_pad2d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (grad_output.suggest_memory_format()) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad2d_backward", [&] {
        cpu_padding_backward<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad2d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
  }
}

void replication_pad3d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) {
  PaddingParams param{input, output, padding};
  switch (padding_memory_format_3d(input)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad3d", [&] {
        cpu_padding<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(),
          "replication_pad3d_channels_last", [&]{
        cpu_padding_channels_last<scalar_t, ReplicationPad>(output, input, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

void replication_pad3d_backward_kernel_impl(
    const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) {
  PaddingParams param{grad_input, grad_output, padding};
  switch (padding_memory_format_3d(grad_output)) {
    case at::MemoryFormat::Contiguous: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad3d_backward", [&] {
        cpu_padding_backward<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    case at::MemoryFormat::ChannelsLast3d: {
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(),
          "replication_pad3d_backward_channels_last", [&]{
        cpu_padding_backward_channels_last<scalar_t, ReplicationPad>(grad_input, grad_output, param);
      });
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous");
  }
}

} // anonymous namespace

// reflection padding
REGISTER_DISPATCH(reflection_pad1d_kernel, &reflection_pad1d_kernel_impl);
REGISTER_DISPATCH(reflection_pad1d_backward_kernel, &reflection_pad1d_backward_kernel_impl);
REGISTER_DISPATCH(reflection_pad2d_kernel, &reflection_pad2d_kernel_impl);
REGISTER_DISPATCH(reflection_pad2d_backward_kernel, &reflection_pad2d_backward_kernel_impl);
REGISTER_DISPATCH(reflection_pad3d_kernel, &reflection_pad3d_kernel_impl);
REGISTER_DISPATCH(reflection_pad3d_backward_kernel, &reflection_pad3d_backward_kernel_impl);

// replication padding
REGISTER_DISPATCH(replication_pad1d_kernel, &replication_pad1d_kernel_impl);
REGISTER_DISPATCH(replication_pad1d_backward_kernel, &replication_pad1d_backward_kernel_impl);
REGISTER_DISPATCH(replication_pad2d_kernel, &replication_pad2d_kernel_impl);
REGISTER_DISPATCH(replication_pad2d_backward_kernel, &replication_pad2d_backward_kernel_impl);
REGISTER_DISPATCH(replication_pad3d_kernel, &replication_pad3d_kernel_impl);
REGISTER_DISPATCH(replication_pad3d_backward_kernel, &replication_pad3d_backward_kernel_impl);

} // at::native