// aten/src/ATen/native/cuda/UpSampleBicubic2d.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/UpSample.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/upsample_bicubic2d_native.h>
#include <ATen/ops/upsample_bicubic2d_backward_native.h>
#endif

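// CUDA kernels for 2d bicubic upsampling (forward) and its backward pass.
// The op is normally reached through the ATen dispatcher; a minimal sketch
// (the shapes, target size, and nullopt scales below are just illustrative):
//
//   at::Tensor in = at::rand({1, 3, 8, 8}, at::kCUDA);
//   at::Tensor out = at::upsample_bicubic2d(
//       in, {16, 16}, /*align_corners=*/false,
//       /*scales_h=*/std::nullopt, /*scales_w=*/std::nullopt);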
namespace at::native {
namespace {

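// Forward kernel. Each thread handles one (output_y, output_x) spatial
// location and loops over the batch and channel dimensions, so the launch
// only needs one thread per output pixel of a single plane.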
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bicubic2d_out_frame(
    const int num_elements,
    const accscalar_t height_scale,
    const accscalar_t width_scale,
    const bool align_corners,
    const PackedTensorAccessor64<const scalar_t, 4> idata,
    PackedTensorAccessor64<scalar_t, 4> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int input_height = idata.size(2);
  const int input_width = idata.size(3);
  const int output_height = odata.size(2);
  const int output_width = odata.size(3);

  if (index >= num_elements) {
    return;
  }

  const int output_x = index % output_width;
  const int output_y = index / output_width;

  // Special case: input and output are the same size, just copy
  if (input_height == output_height && input_width == output_width) {
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; c++) {
        const scalar_t val = idata[n][c][output_y][output_x];
        odata[n][c][output_y][output_x] = val;
      }
    }
    return;
  }

  // General case: map this output pixel to a real-valued source coordinate,
  // then split it into an integer base index and a fractional offset t that
  // parameterizes the cubic interpolation.
  accscalar_t real_x = area_pixel_compute_source_index(
      width_scale, output_x, align_corners, /*cubic=*/true);
  int in_x = floorf(real_x);
  accscalar_t t_x = real_x - in_x;

  accscalar_t real_y = area_pixel_compute_source_index(
      height_scale, output_y, align_corners, /*cubic=*/true);
  int in_y = floorf(real_y);
  accscalar_t t_y = real_y - in_y;

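  // Separable bicubic interpolation over a 4x4 input neighborhood: each of the
  // four rows around in_y is interpolated along x with cubic_interp1d, then the
  // four row results are interpolated along y. Out-of-range taps are clamped to
  // valid indices by upsample_get_value_bounded.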
  for (int n = 0; n < batchsize; n++) {
    for (int c = 0; c < channels; c++) {
      accscalar_t coefficients[4];

      for (int k = 0; k < 4; k++) {
        coefficients[k] = cubic_interp1d(
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x - 1),
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 0),
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 1),
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 2),
            t_x);
      }

      odata[n][c][output_y][output_x] = static_cast<scalar_t>(cubic_interp1d(
          coefficients[0],
          coefficients[1],
          coefficients[2],
          coefficients[3],
          t_y));
    }
  }
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
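// Each thread reads one grad_output element and scatters it, weighted by the
// separable cubic coefficients, into a 4x4 neighborhood of grad_input.
// upsample_increment_value_bounded accumulates with atomic adds, which is why
// the backward op is flagged as nondeterministic below.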
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bicubic2d_backward_out_frame(
    const int num_elements,
    const accscalar_t height_scale,
    const accscalar_t width_scale,
    const bool align_corners,
    PackedTensorAccessor64<scalar_t, 4> idata,
    const PackedTensorAccessor64<const scalar_t, 4> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int input_height = idata.size(2);
  const int input_width = idata.size(3);
  const int output_height = odata.size(2);
  const int output_width = odata.size(3);

  if (index >= num_elements) {
    return;
  }

  const int output_x = index % output_width;
  const int output_y = index / output_width;
  // Special case: input and output are the same size, just copy the gradient
  if (input_height == output_height && input_width == output_width) {
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const scalar_t val = odata[n][c][output_y][output_x];
        idata[n][c][output_y][output_x] = val;
      }
    }
    return;
  }

  accscalar_t real_x = area_pixel_compute_source_index(
      width_scale, output_x, align_corners, /*cubic=*/true);
  int input_x = floorf(real_x);
  accscalar_t t_x = real_x - input_x;

  accscalar_t real_y = area_pixel_compute_source_index(
      height_scale, output_y, align_corners, /*cubic=*/true);
  int input_y = floorf(real_y);
  accscalar_t t_y = real_y - input_y;

  accscalar_t x_coeffs[4];
  accscalar_t y_coeffs[4];

  get_cubic_upsampling_coefficients(x_coeffs, t_x);
  get_cubic_upsampling_coefficients(y_coeffs, t_y);

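  // The weight of tap (i, j) factors as y_coeffs[i] * x_coeffs[j], so the
  // coefficients are computed once here and reused for every (n, c) pair.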
  for (int n = 0; n < batchsize; n++) {
    for (int c = 0; c < channels; ++c) {
      scalar_t out_value = odata[n][c][output_y][output_x];
      for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
          upsample_increment_value_bounded<scalar_t, accscalar_t>(
              idata,
              n,
              c,
              input_height,
              input_width,
              input_y - 1 + i,
              input_x - 1 + j,
              out_value * y_coeffs[i] * x_coeffs[j]);
        }
      }
    }
  }
}

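// Host-side launcher for the forward kernel: one thread per output spatial
// location (the batch and channel loops live inside the kernel), split into
// blocks of at most 1024 threads with ceil_div.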
static void upsample_bicubic2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int input_height = input.size(2);
  int input_width = input.size(3);

  output.zero_();

  const int num_output_elements = output_height * output_width;
  const int max_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

  // Launch kernel
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      input.scalar_type(), "upsample_bicubic2d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.packed_accessor64<const scalar_t, 4>();
        auto odata = output.packed_accessor64<scalar_t, 4>();

        // Get scaling factors
        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        upsample_bicubic2d_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_output_elements, max_threads),
               max_threads,
               0,
               stream>>>(
                num_output_elements,
                rheight,
                rwidth,
                align_corners,
                idata,
                odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

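// Host-side launcher for the backward kernel. Grid sizing mirrors the forward
// pass: one thread per grad_output spatial location. grad_input is zeroed
// first because the kernel accumulates into it.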
static void upsample_bicubic2d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int input_height = input_size[2];
  int input_width = input_size[3];

  Tensor grad_output = grad_output_.contiguous();

  grad_input.zero_();

  const int num_kernels = output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output.scalar_type(), "upsample_bicubic2d_backward_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
        auto odata = grad_output.packed_accessor64<const scalar_t, 4>();

        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        upsample_bicubic2d_backward_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(
                num_kernels, rheight, rwidth, align_corners, idata, odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

} // namespace

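// The entry points below are the structured-kernel implementations; the
// corresponding meta functions (not shown in this file) validate shapes and
// allocate the output/grad_input tensors before these bodies run.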
TORCH_IMPL_FUNC(upsample_bicubic2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_bicubic2d_out_cuda_template(output, input, output_size, align_corners, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_bicubic2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_bicubic2d_backward_out_cuda");
  upsample_bicubic2d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
}

} // namespace at::native