#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/UpSample.cuh>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/upsample_nearest3d.h>
#include <ATen/ops/upsample_nearest3d_native.h>
#include <ATen/ops/upsample_nearest3d_backward.h>
#include <ATen/ops/upsample_nearest3d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact3d.h>
#include <ATen/ops/_upsample_nearest_exact3d_native.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward_native.h>
#endif

namespace at::native {
namespace {

#define MAX_THREADS 512

// Define a typedef to dispatch to nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);

// Define a typedef to dispatch to nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);
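
// A rough sketch of the two forward mapping families (the real implementations
// live in UpSample.cuh; the formulas below are paraphrased, not copied):
//
//   // "nearest" (legacy) mapping: floor(dst_index * scale), clamped to the input
//   int src = min(static_cast<int>(floorf(dst_index * scale)), input_size - 1);
//
//   // "nearest-exact" mapping: shift by half a pixel before flooring
//   int src_exact = min(static_cast<int>(floorf((dst_index + 0.5f) * scale)), input_size - 1);
//
// Illustration: downsampling width 5 -> 3 gives scale ~ 5/3; for dst_index = 1
// the legacy mapping picks floor(1.67) = 1 while the exact mapping picks
// floor(2.5) = 2, which is where the two variants differ.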

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest3d_out_frame(
    const scalar_t* input,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_d,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_d,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* output,
    float depth_scale,
    float height_scale,
    float width_scale) {

  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
    return;

  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;

  int c = (dst_idx / (dst_c_stride)) % dim_c;

  int dst_z = (dst_idx / dst_dim_h / dst_dim_w) % dst_dim_d;
  int src_z = nn_compute_source_index_fn(depth_scale, dst_z, src_dim_d);
  int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  int src_y = nn_compute_source_index_fn(height_scale, dst_y, src_dim_h);

  int dst_x = dst_idx % dst_dim_w;
  int src_x = nn_compute_source_index_fn(width_scale, dst_x, src_dim_w);

  int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
      src_y * src_dim_w + src_x;
  for (int b = 0; b < dim_b; b++) {
    output[dst_idx] = input[src_idx];
    src_idx += dim_c * src_c_stride;
    dst_idx += dim_c * dst_c_stride;
  }
}
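
// A worked example of the index decomposition above (illustrative numbers only):
// with dim_c = 2, dst_dim_d = 4, dst_dim_h = 8, dst_dim_w = 8 (dst_c_stride = 256)
// and dst_idx = 300:
//   c     = (300 / 256) % 2   = 1
//   dst_z = (300 / 8 / 8) % 4 = 0
//   dst_y = (300 / 8) % 8     = 5
//   dst_x = 300 % 8           = 4
// Each thread then copies one output element per batch, advancing both src_idx and
// dst_idx by a full (channels x spatial) stride so the same (c, z, y, x) is reused
// for every b.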

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
// Backward operation
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest3d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_d,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_d,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float depth_scale,
    float height_scale,
    float width_scale) {

  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
    return;

  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;

  int c = (dst_idx / (dst_c_stride)) % dim_c;

  int dst_z = (dst_idx / dst_dim_h / dst_dim_w) % dst_dim_d;
  // note that we do not want to clamp src_z to src_dim_d, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_z = nn_bw_compute_source_index_fn(depth_scale, dst_z, src_dim_d);
  int src_z_up = nn_bw_compute_source_index_fn(depth_scale, dst_z+1, src_dim_d);

  int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  // note that we do not want to clamp src_y to src_dim_h, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_y = nn_bw_compute_source_index_fn(height_scale, dst_y, src_dim_h);
  int src_y_up = nn_bw_compute_source_index_fn(height_scale, dst_y+1, src_dim_h);

  int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_x = nn_bw_compute_source_index_fn(width_scale, dst_x, src_dim_w);
  int src_x_up = nn_bw_compute_source_index_fn(width_scale, dst_x+1, src_dim_w);

  for (int b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    for (int z = src_z; z < src_z_up; z++) {
      for (int y = src_y; y < src_y_up; y++) {
        for (int x = src_x; x < src_x_up; x++) {
          int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
              z * src_dim_h * src_dim_w + y * src_dim_w + x;
          grad += grad_o[src_idx];
        }
      }
    }
    grad_i[dst_idx] = grad;
    dst_idx += dim_c * dst_c_stride;
  }
}
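
// A rough sketch of the backward reduction above: each grad_input element (dst)
// sums every grad_output element (src) that read it in the forward pass. Per
// dimension, that set is the half-open range
//   [ nn_bw_compute_source_index_fn(scale, dst, src_dim),
//     nn_bw_compute_source_index_fn(scale, dst + 1, src_dim) )
// Illustration (values assumed, not taken from this file): for 2x nearest
// upsampling in width, width_scale = 2 and dst_x = 3 gathers src_x in [6, 8),
// i.e. grad_output columns 6 and 7. Because each (b, c, z, y, x) of grad_input is
// owned by exactly one thread, no atomics are needed.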

template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest3d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  // TODO: remove this when the cuda kernel is updated to support the channels_last memory format.
  // This is a temporary hack to prevent a silent correctness issue when calling this kernel
  // with tensors in channels_last format.
  auto output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());

  int output_depth = output_size[0];
  int output_height = output_size[1];
  int output_width = output_size[2];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_depth = input_.size(2);
  int input_height = input_.size(3);
  int input_width = input_.size(4);

  Tensor input = input_.contiguous();

  if (input.numel() == 0) {
    return;
  }

  // upsample_nearest3d meta call makes sure `nbatch != 0`
  unsigned int n = output.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
  // safe check for int32 indexing; implicitly restrict launch config for kernel
  TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
        "upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());
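
  // A quick numeric sketch of the launch configuration (illustrative shapes only):
  // for an output sized [2, 3, 8, 16, 16], n = numel / nbatch = 3 * 8 * 16 * 16 = 6144.
  // With MAX_THREADS = 512 (and maxThreadsPerBlock >= 512), bdim.x = 512 and
  // gdim.x = ceil_div(6144, 512) = 12, so there is one thread per (c, d, h, w)
  // output element; the batch dimension is handled by the loop inside the kernel.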

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest3d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.const_data_ptr<scalar_t>();
        auto odata = output_c.mutable_data_ptr<scalar_t>();

        const float depth_scale = compute_scales_value<float>(scales_d, input_depth, output_depth);
        const float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
        const float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
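
        // Roughly, compute_scales_value prefers the reciprocal of a user-provided
        // scale factor and otherwise falls back to the size ratio (see UpSample.h
        // for the exact definition); as an illustration:
        //   scales_w = 2.0, input_width = 8, output_width = 16   ->  width_scale = 1 / 2.0 = 0.5
        //   scales_w = nullopt, input_width = 8, output_width = 16 -> width_scale = 8 / 16 = 0.5
        // i.e. the kernel always receives an input/output ratio that it can
        // multiply destination indices by.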

        upsample_nearest3d_out_frame<scalar_t, nn_compute_source_index_fn>
          <<<gdim, bdim, 0, stream>>>(
            idata,
            nbatch,
            channels,
            input_depth,
            input_height,
            input_width,
            output_depth,
            output_height,
            output_width,
            odata,
            depth_scale,
            height_scale,
            width_scale);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  if (!output.is_contiguous()) {
      output.copy_(output_c);
  }
}

template<nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest3d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      __func__,
      {grad_output_arg, grad_input_arg});

  int output_depth = output_size[0];
  int output_height = output_size[1];
  int output_width = output_size[2];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_depth = input_size[2];
  int input_height = input_size[3];
  int input_width = input_size[4];

  Tensor grad_output = grad_output_.contiguous();

  if (grad_input.numel() == 0) {
    return;
  }

  // upsample_nearest3d meta call makes sure `nbatch != 0`
  unsigned int n = grad_input.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
  // safe check for int32 indexing; implicitly restrict launch config for kernel
  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
    "upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
    "upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.mutable_data_ptr<scalar_t>();
        auto odata = grad_output.const_data_ptr<scalar_t>();

        float depth_scale = compute_scales_value_backwards<float>(scales_d, output_depth, input_depth);
        float height_scale = compute_scales_value_backwards<float>(scales_h, output_height, input_height);
        float width_scale = compute_scales_value_backwards<float>(scales_w, output_width, input_width);
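
        // Roughly the inverse of the forward scale: here the scale handed to the
        // kernel maps grad_input (dst) indices onto grad_output (src) ranges, so it
        // prefers a user-provided scale factor directly and otherwise falls back to
        // output_size / input_size (see UpSample.h for the exact definition).
        // Illustration: scales_w = nullopt, output_width = 16, input_width = 8
        // gives width_scale = 16 / 8 = 2, so each grad_input column gathers a
        // window of two grad_output columns, matching 2x forward upsampling.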

        upsample_nearest3d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
            <<<gdim, bdim, 0, stream>>>(
                odata,
                nbatch,
                channels,
                output_depth,
                output_height,
                output_width,
                input_depth,
                input_height,
                input_width,
                idata,
                depth_scale,
                height_scale,
                width_scale);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

} // namespace

TORCH_IMPL_FUNC(upsample_nearest3d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest3d_out_cuda_template<nearest_neighbor_compute_source_index>(
      output, input, output_size, scales_d, scales_h, scales_w);
}
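
// For orientation, a minimal host-side usage sketch that would reach the kernel
// above (shapes and values are illustrative; assumes a CUDA build):
//   at::Tensor in = at::rand({2, 3, 4, 8, 8}, at::device(at::kCUDA).dtype(at::kFloat));
//   // NCDHW input upsampled to depth 8, height 16, width 16 with nearest neighbor
//   at::Tensor out = at::upsample_nearest3d(in, {8, 16, 16}, std::nullopt, std::nullopt, std::nullopt);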

TORCH_IMPL_FUNC(_upsample_nearest_exact3d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest3d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(output, input, output_size, scales_d, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_nearest3d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest3d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_d, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest3d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_d, scales_h, scales_w);
}

using at::native::upsample::compute_output_size;
using at::native::upsample_cuda::get_scale_value;

} // namespace at::native