#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/UpSample.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/upsample_nearest1d_native.h>
#include <ATen/ops/upsample_nearest1d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact1d_native.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
#endif

namespace at::native {
namespace {

#define MAX_THREADS 512

// Define a typedef to dispatch to nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);

// Define a typedef to dispatch to nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_out_frame(
    const scalar_t* input,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_w,
    size_t dst_dim_w,
    scalar_t* output,
    float scale_factor) {
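  // Each thread owns one flattened (channel, output_x) position of a single
  // batch slice; the stride loop at the bottom reuses it for every batch.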
  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_w)
    return;

  int c = (dst_idx / dst_dim_w) % dim_c;

  int dst_x = dst_idx % dst_dim_w;
  int src_x = nn_compute_source_index_fn(scale_factor, dst_x, src_dim_w);
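  // Illustration (plain, non-exact mapping): for a 2x upsample,
  // scale_factor == 0.5, so dst_x reads input element floor(dst_x * 0.5).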

  int src_idx = c * src_dim_w + src_x;
  int src_stride = dim_c * src_dim_w;
  int dst_stride = dim_c * dst_dim_w;

  for (int b = 0; b < dim_b; b++) {
    output[dst_idx] = input[src_idx];
    src_idx += src_stride;
    dst_idx += dst_stride;
  }
}

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
// Backward operation
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_w,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float scale_factor) {

  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_w)
    return;

  int c = (dst_idx / dst_dim_w) % dim_c;

  int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_x = nn_bw_compute_source_index_fn(scale_factor, dst_x, src_dim_w);
  int src_x_up = nn_bw_compute_source_index_fn(scale_factor, dst_x + 1, src_dim_w);
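  // Illustration (plain, non-exact mapping): for a 2x upsample,
  // scale_factor == 2 here, so dst_x accumulates the grad_output range
  // [2 * dst_x, 2 * dst_x + 2), i.e. two values per grad_input element.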

  for (int b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    int src_idx = b * dim_c * src_dim_w + c * src_dim_w + src_x;
    for (int x = src_x; x < src_x_up; x++) {
      grad += grad_o[src_idx++];
    }
    grad_i[dst_idx] = grad;
    dst_idx += dim_c * dst_dim_w;
  }
}

template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest1d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    std::optional<double> scales) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU("upsample_nearest1d_out_cuda", {input_arg, output_arg});

  int output_width = output_size[0];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_width = input_.size(2);

  Tensor input = input_.contiguous();

  if (input.numel() == 0) {
    return;
  }

  // upsample_nearest1d meta call makes sure `nbatch != 0`
  unsigned int n = output.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
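  // One thread per output element of a single batch slice; the batch loop
  // lives inside the kernel, so the grid only covers numel() / nbatch.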
  // safety check for int32 indexing; implicitly restricts the launch config
  TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest1d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.const_data_ptr<scalar_t>();
        auto odata = output.mutable_data_ptr<scalar_t>();

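        // scale_factor converts an output index into an input coordinate:
        // input_width / output_width when no explicit scale is given.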
        const float scale_factor = compute_scales_value<float>(scales, input_width, output_width);

        upsample_nearest1d_out_frame<scalar_t, nn_compute_source_index_fn><<<gdim, bdim, 0, stream>>>(
            idata, nbatch, channels, input_width, output_width, odata, scale_factor);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

template<nn_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest1d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      "upsample_nearest1d_backward_out_cuda_template",
      {grad_output_arg, grad_input_arg});

  int output_width = output_size[0];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_width = input_size[2];

  Tensor grad_output = grad_output_.contiguous();

  if (grad_input.numel() == 0) {
    return;
  }

  // upsample_nearest1d meta call makes sure `nbatch != 0`
  unsigned int n = grad_input.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
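  // Same launch shape as the forward pass: one thread per grad_input element
  // of a single batch slice.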
  // safety check for int32 indexing; implicitly restricts the launch config
  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest1d_backward_out_frame", [&] {
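        // Accumulate in accscalar_t (float for Half/BFloat16) since each
        // grad_input element may sum several grad_output values.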
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.mutable_data_ptr<scalar_t>();
        auto odata = grad_output.const_data_ptr<scalar_t>();

        const float scale_factor = compute_scales_value_backwards<float>(scales, output_width, input_width);

        upsample_nearest1d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
            <<<gdim, bdim, 0, stream>>>(
                odata, nbatch, channels, output_width, input_width, idata, scale_factor);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

} // namespace

TORCH_IMPL_FUNC(upsample_nearest1d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales,
    const Tensor& output
) {
  upsample_nearest1d_out_cuda_template<nearest_neighbor_compute_source_index>(
      output, input, output_size, scales);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales,
    const Tensor& output
) {
  upsample_nearest1d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(
      output, input, output_size, scales);
}

TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales,
    const Tensor& grad_input
) {
  upsample_nearest1d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales,
    const Tensor& grad_input
) {
  upsample_nearest1d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales);
}

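// A minimal usage sketch (hypothetical shapes; any CUDA build of ATen):
//   auto t  = at::rand({2, 3, 4}, at::kCUDA);
//   auto up = at::upsample_nearest1d(t, {8});         // dispatches to the forward kernel above
//   auto ex = at::_upsample_nearest_exact1d(t, {8});  // exact variant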
} // namespace at::native