#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/LaunchUtils.h>
#include <ATen/native/cuda/UpSample.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact2d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/upsample_nearest2d_backward_native.h>
#include <ATen/ops/upsample_nearest2d_native.h>
#endif

namespace at::native {
namespace {

#define MAX_THREADS 512

// Define a typedef to dispatch to nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);

// Define a typedef to dispatch to nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);

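// Rough sketch of the mappings these function pointers dispatch to (the exact
// definitions live in ATen/native/cuda/UpSample.cuh): the forward "nearest"
// variant maps an output coordinate roughly as
//   src = min((int)floorf(dst * scale), input_size - 1)
// while the "nearest-exact" variant treats samples as pixel centers, roughly
//   src = min((int)floorf((dst + 0.5f) * scale), input_size - 1)
// The backward variants return the bounds of the (half-open) range of output
// coordinates that map back onto a given input coordinate, which the backward
// kernels below iterate over as [src, src_up).
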
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t nc,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale) {
  size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
  int w2 = threadIdx.x + blockIdx.x * blockDim.x;
  int h2 = threadIdx.y + blockIdx.y * blockDim.y;

  if (w2 >= width2 || h2 >= height2) {
    return;
  }

  int nc_stride = blockDim.z * gridDim.z;

  const size_t h1 = height1 == height2
      ? h2
      : nn_compute_source_index_fn(height_scale, h2, height1);
  const size_t w1 = width1 == width2
      ? w2
      : nn_compute_source_index_fn(width_scale, w2, width1);

  size_t src_index = (nc_iter * height1 + h1) * width1 + w1;
  size_t src_index_stride = nc_stride * width1 * height1;
  size_t dst_index = (nc_iter * height2 + h2) * width2 + w2;
  size_t dst_index_stride = nc_stride * width2 * height2;

  // Iterate over the remaining N*C planes with a grid-stride loop along z.
  while (nc_iter < nc) {
    odata[dst_index] = idata[src_index];
    dst_index += dst_index_stride;
    src_index += src_index_stride;
    nc_iter += nc_stride;
  }
}

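// The channels-last path below launches one thread per element of the NHWC
// output and decomposes the flat index into (n, h2, w2, c); idx_cl() rebuilds
// the corresponding channels-last offset into the input.
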
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_nhwc_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t channels,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale,
    const size_t out_numel) {

  const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < out_numel) {
    const auto c = index % channels;
    const auto w2 = (index / channels) % width2;
    const auto h2 = (index / channels / width2) % height2;
    const auto n = index / channels / width2 / height2;

    const size_t h1 = height1 == height2 ? h2 : nn_compute_source_index_fn(height_scale, h2, height1);
    const size_t w1 = width1 == width2 ? w2 : nn_compute_source_index_fn(width_scale, w2, width1);

    odata[index] = idata[idx_cl(n, h1, w1, c, height1, width1, channels)];
  }
}

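// The backward kernel below assigns one thread per grad_input element over
// C * dst_dim_h * dst_dim_w and loops over the batch dimension inside the
// kernel. Each grad_input cell sums grad_output over the half-open source
// ranges [src_y, src_y_up) x [src_x, src_x_up); e.g. for an exact 2x upsample
// each grad_input cell accumulates a 2x2 block of grad_output.
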
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float height_scale,
    float width_scale) {
  int64_t dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_h * dst_dim_w)
    return;

  int dst_c_stride = dst_dim_h * dst_dim_w;
  int src_c_stride = src_dim_h * src_dim_w;

  int c = (dst_idx / (dst_c_stride)) % dim_c;

  int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  // note that we do not want to clamp src_y to src_dim_h, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_y =
      nn_bw_compute_source_index_fn(height_scale, dst_y, src_dim_h);
  int src_y_up = nn_bw_compute_source_index_fn(
      height_scale, dst_y + 1, src_dim_h);

  int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_x =
      nn_bw_compute_source_index_fn(width_scale, dst_x, src_dim_w);
  int src_x_up = nn_bw_compute_source_index_fn(
      width_scale, dst_x + 1, src_dim_w);

  for (int b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    for (int y = src_y; y < src_y_up; y++) {
      for (int x = src_x; x < src_x_up; x++) {
        int64_t src_idx =
            b * dim_c * src_c_stride + c * src_c_stride + y * src_dim_w + x;
        grad += grad_o[src_idx];
      }
    }
    grad_i[dst_idx] = grad;
    dst_idx += dim_c * dst_c_stride;
  }
}

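// Channels-last variant of the backward kernel: one thread per grad_input
// element in NHWC order, accumulating the same [h1, h1_up) x [w1, w1_up)
// grad_output window in accscalar_t before casting back to scalar_t.
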
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_backward_nhwc_out_frame(
    const scalar_t* go,
    scalar_t* gi,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    const size_t channels,
    const float height_scale,
    const float width_scale,
    const size_t gi_numel) {

  // 1 is for grad_output (src)
  // 2 is for grad_input (dst)

  const int index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < gi_numel) {
    const int c = index % channels;
    const int w2 = (index / channels) % width2;
    const int h2 = (index / channels / width2) % height2;
    const int n = index / channels / width2 / height2;

    int h1 = nn_bw_compute_source_index_fn(height_scale, h2, height1);
    int h1_up = nn_bw_compute_source_index_fn(height_scale, h2 + 1, height1);

    int w1 = nn_bw_compute_source_index_fn(width_scale, w2, width1);
    int w1_up = nn_bw_compute_source_index_fn(width_scale, w2 + 1, width1);

    accscalar_t grad = 0;
    for (int ih = h1; ih < h1_up; ih++) {
      for (int iw = w1; iw < w1_up; iw++) {
        grad += go[idx_cl(n, ih, iw, c, height1, width1, channels)];
      }
    }
    gi[index] = static_cast<scalar_t>(grad);
  }
}

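// Host-side entry for the forward op. It picks between two launches: a
// channels-last path (when the suggested memory format is ChannelsLast, the
// channel count is at least 4, and the output is contiguous in that format)
// that runs the flat NHWC kernel, and a contiguous path that uses a 3D launch
// (x -> output width, y -> output height, z -> N*C with a grid-stride loop).
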
template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  if (input_.numel() == 0) {
    return;
  }

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_height = input_.size(2);
  int input_width = input_.size(3);

  const float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
  const float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);

  const auto memory_format = input_.suggest_memory_format();

  if (input_.sizes() == output.sizes()) {
    output.copy_(input_);
    return;
  }

  // heuristic: only use channels_last path when it's faster than the contiguous path
  if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 && \
        output.is_contiguous(memory_format)) {
    at::Tensor input = input_.contiguous(at::MemoryFormat::ChannelsLast);

    TORCH_CHECK(input.numel() < std::numeric_limits<int64_t>::max(),
      "upsample_nearest_nhwc only supports input tensors with less than 2^63 - 1 elements, but got ", input.sizes());
    TORCH_CHECK(output.numel() < std::numeric_limits<int64_t>::max(),
      "upsample_nearest_nhwc only supports output tensors with less than 2^63 - 1 elements, but got ", output.sizes());

    const int64_t num_kernels = output.numel();
    const int64_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
      const scalar_t* idata = input.const_data_ptr<scalar_t>();
      scalar_t* odata = output.mutable_data_ptr<scalar_t>();

      upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          idata,
          odata,
          channels,
          input_height,
          input_width,
          output_height,
          output_width,
          height_scale,
          width_scale,
          output.numel()
      );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
  else {
    // This is needed for non-contiguous tensors.
    Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
    Tensor input = input_.contiguous();

    int nc = nbatch * channels;

    const int max_threads = std::min<int>(
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS);

    int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
    int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;

    // upsample_nearest2d meta call makes sure input/output tensor is not empty;
    int block_x = std::min<int>(
        maxThreadsDim[0], std::min<int>(lastPow2(output_width), max_threads));
    int block_y = std::min<int>(
        maxThreadsDim[1],
        std::min<int>(lastPow2(output_height), max_threads / block_x));
    int block_z = std::min<int>(
        maxThreadsDim[2], std::min<int>(nc, max_threads / block_x / block_y));
    const dim3 block(block_x, block_y, block_z);

    int grid_x = ceil_div(output_width, block_x);
    int grid_y = ceil_div(output_height, block_y);
    int grid_z = std::min<int>(
        maxGridSize[2], ceil_div(nc, block_z * 4));
    const dim3 grid(grid_x, grid_y, grid_z);
    // Error out when grid_x or grid_y exceeds the launch-config limit, as the
    // current kernel implementation doesn't loop over these two dimensions.
    // This is unlikely to happen.
    // TODO: the kernel implementation could stride over the spatial dimensions.
    //       We probably need to overhaul the kernel.
    TORCH_CHECK(
        grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
        "input tensor has spatial dimension larger than the kernel capacity");

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
          using accscalar_t = at::acc_type<scalar_t, true>;

          auto idata = input.const_data_ptr<scalar_t>();
          auto odata = output_c.mutable_data_ptr<scalar_t>();

          upsample_nearest2d_out_frame<scalar_t, nn_compute_source_index_fn>
              <<<grid, block, 0, stream>>>(
                  idata,
                  odata,
                  nc,
                  input_height,
                  input_width,
                  output_height,
                  output_width,
                  height_scale,
                  width_scale);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });

    if (!output.is_contiguous()) {
        output.copy_(output_c);
    }
  }
}

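// Host-side entry for the backward op, mirroring the forward path selection.
// Note the NHWC fast path indexes with 32-bit ints (hence the INT_MAX checks),
// while the contiguous path launches a 1D grid over C * H * W of grad_input
// and loops over the batch inside the kernel.
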
template<nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest2d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  if (grad_input.numel() == 0) {
    return;
  }

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_height = input_size[2];
  int input_width = input_size[3];

  const float height_scale = compute_scales_value_backwards<float>(scales_h, output_height, input_height);
  const float width_scale = compute_scales_value_backwards<float>(scales_w, output_width, input_width);

  auto memory_format = grad_output_.suggest_memory_format();

  if (grad_output_.sizes() == grad_input.sizes()) {
    grad_input.copy_(grad_output_);
    return;
  }

  if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 && \
        grad_input.is_contiguous(memory_format)) {
    Tensor grad_output = grad_output_.contiguous(at::MemoryFormat::ChannelsLast);

    TORCH_CHECK(grad_input.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports grad_input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
    TORCH_CHECK(grad_output.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports grad_output tensors with less than INT_MAX elements, but got ", grad_output.sizes());

    const int num_kernels = grad_input.numel();
    const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_nhwc_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      const scalar_t* go = grad_output.const_data_ptr<scalar_t>();
      scalar_t* gi = grad_input.mutable_data_ptr<scalar_t>();

      upsample_nearest2d_backward_nhwc_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          go,
          gi,
          output_height,
          output_width,
          input_height,
          input_width,
          channels,
          height_scale,
          width_scale,
          grad_input.numel()
      );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    // This is needed for non-contiguous tensors.
    Tensor grad_input_c = grad_input.is_contiguous() ? grad_input : at::empty(grad_input.sizes(), grad_input.options());
    Tensor grad_output = grad_output_.contiguous();

    // upsample_nearest2d meta call makes sure `nbatch != 0`
    unsigned int n = grad_input.numel() / nbatch;
    dim3 bdim{std::min<unsigned int>(
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
    dim3 gdim{ceil_div(n, bdim.x)};
    // safe check for int64 indexing; implicitly restrict launch config for kernel
    TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_input.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_input.sizes());
    TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_output.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_output.sizes());

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      auto idata = grad_input_c.mutable_data_ptr<scalar_t>();
      auto odata = grad_output.const_data_ptr<scalar_t>();

      upsample_nearest2d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
          <<<gdim, bdim, 0, stream>>>(
              odata,
              nbatch,
              channels,
              output_height,
              output_width,
              input_height,
              input_width,
              idata,
              height_scale,
              width_scale);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    if (!grad_input.is_contiguous()) {
        grad_input.copy_(grad_input_c);
    }
  }
}

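// Illustrative usage sketch only (not part of this file): the structured-kernel
// wrappers below are reached via the dispatcher, e.g. from C++ roughly as
//
//   auto x = at::rand({1, 3, 8, 8}, at::TensorOptions().device(at::kCUDA));
//   auto y = at::upsample_nearest2d(x, {16, 16});
//   // with autograd enabled, the backward of y routes through
//   // upsample_nearest2d_backward_out_cuda below.
//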
} // namespace

TORCH_IMPL_FUNC(upsample_nearest2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest2d_out_cuda_template<nearest_neighbor_compute_source_index>(
      output, input, output_size, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest2d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(
      output, input, output_size, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest2d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest2d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
}

} // namespace at::native