#pragma once

#include <math.h>

#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cpu/utils.h>

/**
 * Note [compute_scales_value]
 * Note [area_pixel_compute_scale]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Interpolate with scale_factor can have different behaviors
 * depending on the value of recompute_scale_factor:
 *
 * - With recompute_scale_factor = True (current default behavior):
 *   the scale_factor, when provided by the user, is used to calculate
 *   the output size. The input size and the computed output_size
 *   are then used to infer new values for the scales which are
 *   used in the interpolation. Because floating-point math is not exact,
 *   these may differ from the user-supplied scales.
 *
 * - With recompute_scale_factor = False (which will be the default
 *   behavior starting 1.5.0):
 *   the behavior follows OpenCV logic, and the scales provided by
 *   the user are the ones used in the interpolation calculations.
 *
 * If the scales are not provided, or if they are provided but
 * recompute_scale_factor is set to True (default behavior), the scales
 * are computed from the input and the output size.
 *
 *
 * When the scales are inferred from the input and output sizes,
 * we view each pixel as an area, with idx + 0.5 as its center index.
 * Here is an example formula in the 1D case.
 * if align_corners: the centers of the two corner pixel areas are preserved,
 *     0.5 -> 0.5,
 *     input_size - 0.5 -> output_size - 0.5
 *     scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
 *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
 * if not align_corners: the whole range is scaled accordingly
 *     scale = input_size / output_size
 *     src_index + 0.5 = scale * (dst_index + 0.5)
 */
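
// Worked example (illustration only): upsampling a 1D signal from
// input_size = 4 to output_size = 8.
//   align_corners = true:  scale = (4 - 1) / (8 - 1) = 3/7,
//                          src_index = scale * dst_index, so dst 0 -> src 0
//                          and dst 7 -> src 3 (both corner centers preserved).
//   align_corners = false: scale = 4 / 8 = 0.5,
//                          src_index = 0.5 * (dst_index + 0.5) - 0.5, so
//                          dst 0 -> src -0.25 (clamped to 0 for linear modes)
//                          and dst 7 -> src 3.25.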

namespace at::native {

namespace upsample {

TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size,  // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    std::optional<c10::ArrayRef<double>> scale_factors);

inline std::optional<double> get_scale_value(std::optional<c10::ArrayRef<double>> scales, int idx) {
  if (!scales) {
    return std::nullopt;
  }
  return scales->at(idx);
}

} // namespace upsample

using scale_t = std::optional<double>;
using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);

inline C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  int64_t output_width = output_size[0];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_width = input_size[2];

  TORCH_CHECK(
      input_width > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (W: ",
      input_width,
      ") and output (W: ",
      output_width,
      ")");

  return {nbatch, channels, output_width};
}

inline C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 4,
      "It is expected input_size equals to 4, but got size ",
      input_size.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_height = input_size[2];
  int64_t input_width = input_size[3];

  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_height, output_width};
}

inline C10_UNUSED
std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 3,
      "It is expected output_size equals to 3, but got size ",
      output_size.size());

  TORCH_CHECK(
      input_size.size() == 5,
      "It is expected input_size equals to 5, but got size ",
      input_size.size());

  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_depth = input_size[2];
  int64_t input_height = input_size[3];
  int64_t input_width = input_size[4];

  TORCH_CHECK(
      input_depth > 0 && input_height > 0 && input_width > 0 &&
          output_depth > 0 && output_height > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (D: ",
      input_depth,
      ", H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (D: ",
      output_depth,
      ", H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_depth, output_height, output_width};
}

inline void upsample_2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    int64_t nbatch,
    int64_t nchannels,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  if (input.defined()) {
    // Allow for empty batch size but not other dimensions
    TORCH_CHECK(
        (input.numel() != 0 ||
         (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)) &&
            input.dim() == 4,
        "Non-empty 4D data tensor expected but got a tensor with sizes ",
        input.sizes());
  } else if (grad_output.defined()) {
    check_dim_size(grad_output, 4, 0, nbatch);
    check_dim_size(grad_output, 4, 1, nchannels);
    check_dim_size(grad_output, 4, 2, output_height);
    check_dim_size(grad_output, 4, 3, output_width);
  }
}

template <typename scalar_t>
inline scalar_t compute_scales_value(
    const std::optional<double> scale,
    int64_t input_size,
    int64_t output_size) {
  // see Note [compute_scales_value]
  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
  return (scale.has_value() && scale.value() > 0.)
      ? static_cast<scalar_t>(1.0 / scale.value())
      : (static_cast<scalar_t>(input_size) / output_size);
}

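// Illustration (not part of the API contract): with a user-provided scale_factor
// of 2.0 for a 4 -> 8 upsample, compute_scales_value<float>(2.0, 4, 8) returns
// 1.0 / 2.0 = 0.5; with no scale provided, compute_scales_value<float>(std::nullopt, 4, 8)
// falls back to input_size / output_size = 4.0 / 8 = 0.5. A serialized legacy
// default of -1 also takes the size-based fallback because of the `> 0` guard above.
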
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const std::optional<double> scale) {
  // see Note [area_pixel_compute_scale]
  if (align_corners) {
    if (output_size > 1) {
      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
    } else {
      return static_cast<scalar_t>(0);
    }
  } else {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
}

template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
        static_cast<scalar_t>(0.5);
    // [Note] Follow OpenCV resize logic:
    // We allow a negative src_idx here and later use
    //   dx = src_idx - floorf(src_idx)
    // to compute the "distance" (which affects the weights).
    // For linear modes the weight distribution doesn't matter
    // for negative indices, as they interpolate between 2 pixels.
    // For example, taps [-1, 0] both read the value of pixel 0,
    // so it makes no difference whether we clamp src_idx to 0 or not.
    // TODO: our current linear-mode implementations should use unbounded
    // indices as well, after which this cubic flag can be removed.
    // This matters in cubic mode, as we might need taps [-1, 0, 1, 2]
    // to interpolate, and the weights can be affected.
    return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
                                                          : src_idx;
  }
}

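// Illustration: with scale = 0.5 and dst_index = 0 (align_corners = false),
// src_idx = 0.5 * 0.5 - 0.5 = -0.25. For linear modes (cubic = false) the
// function returns 0, which reads pixel 0 either way. For cubic mode the
// -0.25 is kept so that dx = src_idx - floorf(src_idx) = 0.75 produces the
// intended weights; out-of-range taps are clamped when the values are read
// (e.g. via upsample_get_value_bounded below).
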
inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // Index computation matching OpenCV INTER_NEAREST
  // which is buggy and kept for BC
  const int64_t src_index =
      std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
  return src_index;
}

inline int64_t nearest_neighbor_exact_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // index_f32 = (output_index + 0.5) * scale - 0.5
  // input_index = round(index_f32)
  // Same as Pillow and Scikit-Image/Scipy ndi.zoom
  const int64_t src_index =
      std::min(static_cast<int64_t>(floorf((dst_index + 0.5) * scale)), input_size - 1);
  return src_index;
}

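// Illustration of the difference between the two mappings above: downsampling
// input_size = 5 to output_size = 2 gives scale = 5 / 2 = 2.5.
//   legacy ("nearest"):      src = floor(dst * 2.5)        -> {0, 2}
//   exact ("nearest-exact"): src = floor((dst + 0.5) * 2.5) -> {1, 3}
// The exact variant samples at the pixel-area centers, matching Pillow / SciPy.
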
inline int64_t nearest_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    std::optional<double> scales) {
  // This method specifically treats the cases output_size == input_size and
  // output_size == 2 * input_size, which we would like to get rid of.
  // We keep this method for BC and consider it deprecated.
  // See nearest_exact_idx as a replacement.
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    return output_index;
  } else if (output_size == 2 * input_size) {
    // scale_factor = 2, shift input index
    return output_index >> 1;
  } else {
    float scale = compute_scales_value<float>(scales, input_size, output_size);
    return nearest_neighbor_compute_source_index(scale, output_index, input_size);
  }
}

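// For the fast paths above: output_size == input_size copies indices directly,
// and output_size == 2 * input_size maps {0, 1, 2, 3, ...} -> {0, 0, 1, 1, ...}
// via the shift, which agrees with floor(output_index * input_size / output_size).
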
inline int64_t nearest_exact_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    std::optional<double> scales) {
  float scale = compute_scales_value<float>(scales, input_size, output_size);
  return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
}

// Define a typedef to dispatch to nearest_idx or nearest_exact_idx
typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional<double>);

template <typename scalar_t>
scalar_t upsample_get_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y) {
  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  return data[access_y * width + access_x];
}

template <typename scalar_t>
void upsample_increment_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y,
    scalar_t value) {
  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  data[access_y * width + access_x] += value;
}

// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  return ((A + 2) * x - (A + 3)) * x * x + 1;
}

template <typename scalar_t>
scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}

template <typename scalar_t>
void get_cubic_upsample_coefficients(
    scalar_t coeffs[4],
    scalar_t t) {
  scalar_t A = -0.75;

  scalar_t x1 = t;
  coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
  coeffs[1] = cubic_convolution1<scalar_t>(x1, A);

  // opposite coefficients
  scalar_t x2 = 1.0 - t;
  coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
  coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
}

template <typename scalar_t>
inline scalar_t cubic_interp1d(
    scalar_t x0,
    scalar_t x1,
    scalar_t x2,
    scalar_t x3,
    scalar_t t) {
  scalar_t coeffs[4];
  get_cubic_upsample_coefficients<scalar_t>(coeffs, t);

  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
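
// Illustration of the coefficients (for A = -0.75, the value used above):
//   t = 0   -> {0, 1, 0, 0}        (cubic_interp1d returns x1 exactly)
//   t = 1   -> {0, 0, 1, 0}        (returns x2 exactly)
//   t = 0.5 -> {-0.09375, 0.59375, 0.59375, -0.09375}
// For any t in [0, 1] the four coefficients sum to 1.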

// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template<typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
  input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
  lambda = std::min(
      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
      static_cast<opmath_t>(1)
  );
}
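// Illustration of the std::min guard above: with opmath_t = float, integers
// above 2^24 are not all exactly representable, so for a very large input the
// computed real_input_index for the last output elements can round up to
// input_size; the std::min keeps input_index at input_size - 1 and the clamp
// keeps lambda within [0, 1].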

template<typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
    int64_t& input_index0,
    int64_t& input_index1,
    scalar_t& lambda0,
    scalar_t& lambda1,
    opmath_t ratio,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    input_index0 = output_index;
    input_index1 = output_index;
    lambda0 = static_cast<scalar_t>(1);
    lambda1 = static_cast<scalar_t>(0);
  } else {
    const auto real_input_index =
        area_pixel_compute_source_index<opmath_t>(
            ratio, output_index, align_corners, /*cubic=*/false);
    guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
    int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
    input_index1 = input_index0 + offset;
    lambda0 = static_cast<scalar_t>(1.) - lambda1;
  }
}

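// Worked example (linear case): input_size = 2, output_size = 4,
// align_corners = false, ratio = 0.5.
//   output_index = 0: real_input_index = 0 (clamped from -0.25)
//                     -> index0 = 0, index1 = 1, lambda0 = 1, lambda1 = 0,
//                     i.e. out[0] = in[0].
//   output_index = 1: real_input_index = 0.25
//                     -> index0 = 0, index1 = 1, lambda0 = 0.75, lambda1 = 0.25,
//                     i.e. out[1] = 0.75 * in[0] + 0.25 * in[1].
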
// apply_grad_input is only used for BFloat16 and Half on CPU; this overload
// catches any other type combination and reports an error.
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_out> || !std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  TORCH_CHECK((is_reduced_floating_point_v<scalar_out>),
      "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
  TORCH_CHECK((std::is_same<scalar_in, float>::value),
      "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
  return;
}

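// The overload below handles the supported combination (float accumulation
// buffer, BFloat16/Half grad input): for each chunk of bVec::size() elements it
// widens the reduced-precision grad to two float vectors, adds the float buffer,
// zeroes the buffer, and stores the narrowed sum back into gin; the scalar tail
// loop does the same element by element.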
template <typename scalar_in, typename scalar_out,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_out> && std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  using bVec = Vectorized<scalar_out>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gin_bvec = bVec::loadu(gin + d);
    auto [gin_fvec0, gin_fvec1] = convert_to_float<scalar_out>(gin_bvec);
    gin_fvec0 += fVec::loadu(buffer_ptr + d);
    gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
    fVec(0).store(buffer_ptr + d);
    fVec(0).store(buffer_ptr + d + fVec::size());
    convert_from_float<scalar_out>(gin_fvec0, gin_fvec1).store(gin + d);
  }
  for (; d < size; d++) {
    gin[d] += buffer_ptr[d];
    buffer_ptr[d] = 0;
  }
}

} // namespace at::native