#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/LaunchUtils.h>
#include <ATen/native/cuda/UpSample.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact2d_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/upsample_nearest2d_backward_native.h>
#include <ATen/ops/upsample_nearest2d_native.h>
#endif

namespace at::native {
namespace {

#define MAX_THREADS 512

// Define a typedef to dispatch to nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);

// Define a typedef to dispatch to nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);
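
// Both function-pointer types share the same signature. A forward function
// maps an output index to the input index it samples from; a backward
// function, evaluated at a grad_input index d and again at d + 1, brackets
// the half-open range of grad_output indices that contribute to d.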

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
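// Thread layout: blockIdx/threadIdx x and y cover the output width and height,
// while the z dimension walks the flattened N*C slices with a grid stride, so
// each thread writes its (h2, w2) element for several slices.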
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t nc,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale) {
  size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
  int w2 = threadIdx.x + blockIdx.x * blockDim.x;
  int h2 = threadIdx.y + blockIdx.y * blockDim.y;

  if (w2 >= width2 || h2 >= height2) {
    return;
  }

  int nc_stride = blockDim.z * gridDim.z;

  const size_t h1 = height1 == height2
      ? h2
      : nn_compute_source_index_fn(height_scale, h2, height1);
  const size_t w1 = width1 == width2
      ? w2
      : nn_compute_source_index_fn(width_scale, w2, width1);

  size_t src_index = (nc_iter * height1 + h1) * width1 + w1;
  size_t src_index_stride = nc_stride * width1 * height1;
  size_t dst_index = (nc_iter * height2 + h2) * width2 + w2;
  size_t dst_index_stride = nc_stride * width2 * height2;
  // iterate over the remaining N*C slices assigned to this thread
  while (nc_iter < nc) {
    odata[dst_index] = idata[src_index];
    dst_index += dst_index_stride;
    src_index += src_index_stride;
    nc_iter += nc_stride;
  }
}

template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_nhwc_out_frame(
    const scalar_t* idata,
    scalar_t* odata,
    const size_t channels,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    float height_scale,
    float width_scale,
    const size_t out_numel) {

  const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < out_numel) {
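    // Channels-last layout: the flat output index decomposes with c fastest,
    //   index = ((n * height2 + h2) * width2 + w2) * channels + c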
    const auto c = index % channels;
    const auto w2 = (index / channels) % width2;
    const auto h2 = (index / channels / width2) % height2;
    const auto n = index / channels / width2 / height2;

    const size_t h1 = height1 == height2 ? h2 : nn_compute_source_index_fn(height_scale, h2, height1);
    const size_t w1 = width1 == width2 ? w2 : nn_compute_source_index_fn(width_scale, w2, width1);

    odata[index] = idata[idx_cl(n, h1, w1, c, height1, width1, channels)];
  }
}

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
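// The backward pass is a gather: each thread owns one grad_input element and
// sums the grad_output values in the half-open window
// [src_y, src_y_up) x [src_x, src_x_up) that maps back onto it, so no atomics
// are needed.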
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float height_scale,
    float width_scale) {
  int64_t dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_h * dst_dim_w)
    return;

  int dst_c_stride = dst_dim_h * dst_dim_w;
  int src_c_stride = src_dim_h * src_dim_w;

  int c = (dst_idx / (dst_c_stride)) % dim_c;

  int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  // note that we do not want to clamp src_y to src_dim_h, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_y =
      nn_bw_compute_source_index_fn(height_scale, dst_y, src_dim_h);
  int src_y_up = nn_bw_compute_source_index_fn(
      height_scale, dst_y + 1, src_dim_h);

  int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_x =
      nn_bw_compute_source_index_fn(width_scale, dst_x, src_dim_w);
  int src_x_up = nn_bw_compute_source_index_fn(
      width_scale, dst_x + 1, src_dim_w);
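
  // Reuse the same spatial window for every image in the batch; accumulate in
  // accscalar_t (float when scalar_t is half or bfloat16).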
  for (int b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    for (int y = src_y; y < src_y_up; y++) {
      for (int x = src_x; x < src_x_up; x++) {
        int64_t src_idx =
            b * dim_c * src_c_stride + c * src_c_stride + y * src_dim_w + x;
        grad += grad_o[src_idx];
      }
    }
    grad_i[dst_idx] = grad;
    dst_idx += dim_c * dst_c_stride;
  }
}

template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest2d_backward_nhwc_out_frame(
    const scalar_t* go,
    scalar_t* gi,
    const size_t height1,
    const size_t width1,
    const size_t height2,
    const size_t width2,
    const size_t channels,
    const float height_scale,
    const float width_scale,
    const size_t gi_numel) {

  // 1 is for grad_output (src)
  // 2 is for grad_input (dst)

  const int index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < gi_numel) {
    const int c = index % channels;
    const int w2 = (index / channels) % width2;
    const int h2 = (index / channels / width2) % height2;
    const int n = index / channels / width2 / height2;

    int h1 = nn_bw_compute_source_index_fn(height_scale, h2, height1);
    int h1_up = nn_bw_compute_source_index_fn(height_scale, h2 + 1, height1);

    int w1 = nn_bw_compute_source_index_fn(width_scale, w2, width1);
    int w1_up = nn_bw_compute_source_index_fn(width_scale, w2 + 1, width1);

    accscalar_t grad = 0;
    for (int ih = h1; ih < h1_up; ih++) {
      for (int iw = w1; iw < w1_up; iw++) {
        grad += go[idx_cl(n, ih, iw, c, height1, width1, channels)];
      }
    }
    gi[index] = static_cast<scalar_t>(grad);
  }
}

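// Host-side driver for the forward kernels: checks devices, computes the
// effective scales, and dispatches either the channels-last kernel (when the
// input suggests NHWC and the output is NHWC-contiguous) or the NCHW kernel,
// staging through a contiguous temporary if the output tensor itself is not
// contiguous.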
template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  if (input_.numel() == 0) {
    return;
  }

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_height = input_.size(2);
  int input_width = input_.size(3);

  const float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
  const float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);

  const auto memory_format = input_.suggest_memory_format();

  if (input_.sizes() == output.sizes()) {
    output.copy_(input_);
    return;
  }

  // heuristic: only use channels_last path when it's faster than the contiguous path
  if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 &&
      output.is_contiguous(memory_format)) {
    at::Tensor input = input_.contiguous(at::MemoryFormat::ChannelsLast);

    TORCH_CHECK(input.numel() < std::numeric_limits<int64_t>::max(),
      "upsample_nearest_nhwc only supports input tensors with less than 2^63 - 1 elements, but got ", input.sizes());
    TORCH_CHECK(output.numel() < std::numeric_limits<int64_t>::max(),
      "upsample_nearest_nhwc only supports output tensors with less than 2^63 - 1 elements, but got ", output.sizes());

    const int64_t num_kernels = output.numel();
    const int64_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
      const scalar_t* idata = input.const_data_ptr<scalar_t>();
      scalar_t* odata = output.mutable_data_ptr<scalar_t>();

      upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          idata,
          odata,
          channels,
          input_height,
          input_width,
          output_height,
          output_width,
          height_scale,
          width_scale,
          output.numel()
      );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
  else {
    // This is needed for non-contiguous tensors.
    Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
    Tensor input = input_.contiguous();

    int nc = nbatch * channels;

    const int max_threads = std::min<int>(
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS);

    int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
    int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;

    // upsample_nearest2d meta call makes sure input/output tensor is not empty;
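    // Block heuristic: give x/y power-of-two extents covering the output
    // width/height (capped at max_threads), hand the remaining threads to the
    // flattened N*C dimension, and size grid.z so that each z-thread strides
    // over roughly four N*C slices.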
    int block_x = std::min<int>(
        maxThreadsDim[0], std::min<int>(lastPow2(output_width), max_threads));
    int block_y = std::min<int>(
        maxThreadsDim[1],
        std::min<int>(lastPow2(output_height), max_threads / block_x));
    int block_z = std::min<int>(
        maxThreadsDim[2], std::min<int>(nc, max_threads / block_x / block_y));
    const dim3 block(block_x, block_y, block_z);

    int grid_x = ceil_div(output_width, block_x);
    int grid_y = ceil_div(output_height, block_y);
    int grid_z = std::min<int>(
        maxGridSize[2], ceil_div(nc, block_z * 4));
    const dim3 grid(grid_x, grid_y, grid_z);
    // Error out when grid_x or grid_y exceeds the limit of the launch config,
    // as the current kernel implementation doesn't loop over the two spatial
    // dimensions. This is unlikely to happen.
    // TODO: kernel implementation could stride on spatial dimension. We probably
    // need to overhaul the kernel.
    TORCH_CHECK(
        grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
        "input tensor has spatial dimension larger than the kernel capacity");

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      auto idata = input.const_data_ptr<scalar_t>();
      auto odata = output_c.mutable_data_ptr<scalar_t>();

      upsample_nearest2d_out_frame<scalar_t, nn_compute_source_index_fn>
          <<<grid, block, 0, stream>>>(
              idata,
              odata,
              nc,
              input_height,
              input_width,
              output_height,
              output_width,
              height_scale,
              width_scale);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    if (!output.is_contiguous()) {
      output.copy_(output_c);
    }
  }
}

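// Host-side driver for the backward kernels; mirrors the forward template:
// channels-last fast path when grad_input is NHWC-contiguous, otherwise the
// NCHW kernel with a contiguous temporary for a non-contiguous grad_input.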
template<nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest2d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  if (grad_input.numel() == 0) {
    return;
  }

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_height = input_size[2];
  int input_width = input_size[3];

  const float height_scale = compute_scales_value_backwards<float>(scales_h, output_height, input_height);
  const float width_scale = compute_scales_value_backwards<float>(scales_w, output_width, input_width);
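  // Note that these scales run in the opposite direction of the forward pass:
  // they map grad_input (dst) coordinates onto grad_output (src) coordinates.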

  auto memory_format = grad_output_.suggest_memory_format();

  if (grad_output_.sizes() == grad_input.sizes()) {
    grad_input.copy_(grad_output_);
    return;
  }

  if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 &&
      grad_input.is_contiguous(memory_format)) {
    Tensor grad_output = grad_output_.contiguous(at::MemoryFormat::ChannelsLast);

    TORCH_CHECK(grad_input.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports grad_input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
    TORCH_CHECK(grad_output.numel() < std::numeric_limits<int>::max(),
      "upsample_nearest_nhwc only supports grad_output tensors with less than INT_MAX elements, but got ", grad_output.sizes());

    const int num_kernels = grad_input.numel();
    const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_nhwc_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      const scalar_t* go = grad_output.const_data_ptr<scalar_t>();
      scalar_t* gi = grad_input.mutable_data_ptr<scalar_t>();

      upsample_nearest2d_backward_nhwc_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          go,
          gi,
          output_height,
          output_width,
          input_height,
          input_width,
          channels,
          height_scale,
          width_scale,
          grad_input.numel()
      );
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    // This is needed for non-contiguous tensors.
    Tensor grad_input_c = grad_input.is_contiguous() ? grad_input : at::empty(grad_input.sizes(), grad_input.options());
    Tensor grad_output = grad_output_.contiguous();

    // upsample_nearest2d meta call makes sure `nbatch != 0`
    unsigned int n = grad_input.numel() / nbatch;
    dim3 bdim{std::min<unsigned int>(
        at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
    dim3 gdim{ceil_div(n, bdim.x)};
    // safe check for int64 indexing; implicitly restrict launch config for kernel
    TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_input.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_input.sizes());
    TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_output.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_output.sizes());

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_out_frame", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;

      auto idata = grad_input_c.mutable_data_ptr<scalar_t>();
      auto odata = grad_output.const_data_ptr<scalar_t>();

      upsample_nearest2d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
          <<<gdim, bdim, 0, stream>>>(
              odata,
              nbatch,
              channels,
              output_height,
              output_width,
              input_height,
              input_width,
              idata,
              height_scale,
              width_scale);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    if (!grad_input.is_contiguous()) {
      grad_input.copy_(grad_input_c);
    }
  }
}

} // namespace

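// The four structured-kernel entry points below differ only in which
// source-index function the shared templates are instantiated with; see
// UpSample.cuh for the rounding difference between the nearest and
// nearest-exact variants.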
TORCH_IMPL_FUNC(upsample_nearest2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest2d_out_cuda_template<nearest_neighbor_compute_source_index>(
      output, input, output_size, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest2d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(
      output, input, output_size, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest2d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest2d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
}

} // namespace at::native