#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/UpSample.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/upsample_bicubic2d_native.h>
#include <ATen/ops/upsample_bicubic2d_backward_native.h>
#endif

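// CUDA implementation of 2-D bicubic upsampling and its backward pass. This is
// what torch.nn.functional.interpolate(..., mode="bicubic") ultimately
// dispatches to for CUDA tensors.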
namespace at::native {
namespace {

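// Forward kernel: each thread handles one output spatial location
// (output_y, output_x) and loops over batch and channel. The output value is a
// bicubic interpolation over a 4x4 input neighborhood: four horizontal cubic
// interpolations followed by one vertical one, with out-of-bounds reads
// clamped by upsample_get_value_bounded.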
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bicubic2d_out_frame(
    const int num_elements,
    const accscalar_t height_scale,
    const accscalar_t width_scale,
    const bool align_corners,
    const PackedTensorAccessor64<const scalar_t, 4> idata,
    PackedTensorAccessor64<scalar_t, 4> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int input_height = idata.size(2);
  const int input_width = idata.size(3);
  const int output_height = odata.size(2);
  const int output_width = odata.size(3);

  if (index >= num_elements) {
    return;
  }

  const int output_x = index % output_width;
  const int output_y = index / output_width;

  // Special case: input and output are the same size, just copy
  if (input_height == output_height && input_width == output_width) {
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; c++) {
        const scalar_t val = idata[n][c][output_y][output_x];
        odata[n][c][output_y][output_x] = val;
      }
    }
    return;
  }

  // Interpolation kernel
  accscalar_t real_x = area_pixel_compute_source_index(
      width_scale, output_x, align_corners, /*cubic=*/true);
  int in_x = floorf(real_x);
  accscalar_t t_x = real_x - in_x;

  accscalar_t real_y = area_pixel_compute_source_index(
      height_scale, output_y, align_corners, /*cubic=*/true);
  int in_y = floorf(real_y);
  accscalar_t t_y = real_y - in_y;

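  // The source coordinates are shared by every (n, c) plane: interpolate four
  // rows of the 4x4 neighborhood along x, then combine the results along y.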
  for (int n = 0; n < batchsize; n++) {
    for (int c = 0; c < channels; c++) {
      accscalar_t coefficients[4];

      for (int k = 0; k < 4; k++) {
        coefficients[k] = cubic_interp1d(
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x - 1),
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 0),
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 1),
            upsample_get_value_bounded<scalar_t>(
                idata, n, c, input_height, input_width, in_y - 1 + k, in_x + 2),
            t_x);
      }

      odata[n][c][output_y][output_x] = static_cast<scalar_t>(cubic_interp1d(
          coefficients[0],
          coefficients[1],
          coefficients[2],
          coefficients[3],
          t_y));
    }
  }
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
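// Each thread takes one grad_output element and scatters it back into the 4x4
// input neighborhood it was interpolated from, weighted by the separable cubic
// coefficients. Accumulation into grad_input goes through atomic adds inside
// upsample_increment_value_bounded, since the neighborhoods of adjacent output
// pixels overlap.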
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bicubic2d_backward_out_frame(
    const int num_elements,
    const accscalar_t height_scale,
    const accscalar_t width_scale,
    const bool align_corners,
    PackedTensorAccessor64<scalar_t, 4> idata,
    const PackedTensorAccessor64<const scalar_t, 4> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int input_height = idata.size(2);
  const int input_width = idata.size(3);
  const int output_height = odata.size(2);
  const int output_width = odata.size(3);

  if (index >= num_elements) {
    return;
  }

  const int output_x = index % output_width;
  const int output_y = index / output_width;
  // Special case: input and output are the same size, just copy
  if (input_height == output_height && input_width == output_width) {
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const scalar_t val = odata[n][c][output_y][output_x];
        idata[n][c][output_y][output_x] = val;
      }
    }
    return;
  }

  accscalar_t real_x = area_pixel_compute_source_index(
      width_scale, output_x, align_corners, /*cubic=*/true);
  int input_x = floorf(real_x);
  accscalar_t t_x = real_x - input_x;

  accscalar_t real_y = area_pixel_compute_source_index(
      height_scale, output_y, align_corners, /*cubic=*/true);
  int input_y = floorf(real_y);
  accscalar_t t_y = real_y - input_y;

  accscalar_t x_coeffs[4];
  accscalar_t y_coeffs[4];

  get_cubic_upsampling_coefficients(x_coeffs, t_x);
  get_cubic_upsampling_coefficients(y_coeffs, t_y);

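  // The coefficient vectors depend only on the fractional offsets, so they are
  // computed once per output pixel and shared across all (n, c) planes below.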
  for (int n = 0; n < batchsize; n++) {
    for (int c = 0; c < channels; ++c) {
      scalar_t out_value = odata[n][c][output_y][output_x];
      for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
          upsample_increment_value_bounded<scalar_t, accscalar_t>(
              idata,
              n,
              c,
              input_height,
              input_width,
              input_y - 1 + i,
              input_x - 1 + j,
              out_value * y_coeffs[i] * x_coeffs[j]);
        }
      }
    }
  }
}

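// Host-side launcher for the forward kernel: checks that input and output are
// on the same GPU, computes the height/width scales, and launches one thread
// per output spatial location in blocks of at most 1024 threads. Batch and
// channel dimensions are traversed inside the kernel.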
static void upsample_bicubic2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int input_height = input.size(2);
  int input_width = input.size(3);

  output.zero_();

  const int num_output_elements = output_height * output_width;
  const int max_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

  // Launch kernel
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      input.scalar_type(), "upsample_bicubic2d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.packed_accessor64<const scalar_t, 4>();
        auto odata = output.packed_accessor64<scalar_t, 4>();

        // Get scaling factors
        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        upsample_bicubic2d_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_output_elements, max_threads),
               max_threads,
               0,
               stream>>>(
                num_output_elements,
                rheight,
                rwidth,
                align_corners,
                idata,
                odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

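// Host-side launcher for the backward kernel: makes grad_output contiguous,
// zeroes grad_input, and launches one thread per grad_output spatial location,
// mirroring the forward launch configuration.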
static void upsample_bicubic2d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int input_height = input_size[2];
  int input_width = input_size[3];

  Tensor grad_output = grad_output_.contiguous();

  grad_input.zero_();

  const int num_kernels = output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output.scalar_type(), "upsample_bicubic2d_backward_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
        auto odata = grad_output.packed_accessor64<const scalar_t, 4>();

        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        upsample_bicubic2d_backward_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(
                num_kernels, rheight, rwidth, align_corners, idata, odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

} // namespace

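// Structured-kernel entry points. Shape validation and output allocation are
// handled by the op's shared meta function, so these just forward to the
// templates above.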
TORCH_IMPL_FUNC(upsample_bicubic2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_bicubic2d_out_cuda_template(output, input, output_size, align_corners, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_bicubic2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_bicubic2d_backward_out_cuda");
  upsample_bicubic2d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
}

} // namespace at::native