xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/resize_bilinear.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_BILINEAR_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_BILINEAR_H_
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <cstdint>
21 #include <limits>
22 
23 #include "tensorflow/lite/kernels/internal/cppmath.h"
24 #include "tensorflow/lite/kernels/internal/types.h"
25 
26 namespace tflite {
27 namespace reference_ops {
28 
// Maps the output coordinate `value` back onto the input axis and returns the
// pair of input indices that bracket it.
//
// `scale` is the input/output size ratio chosen by the caller. With
// `half_pixel_centers` sample centers sit at (i + 0.5), so the mapping is
// (value + 0.5) * scale - 0.5; otherwise it is value * scale.
//
// Outputs:
//   scaled_value - the (possibly fractional, possibly negative) source
//                  coordinate; callers derive interpolation weights from it.
//   lower_bound  - floor(scaled_value), clamped into [0, input_size - 1].
//   upper_bound  - ceil(scaled_value), clamped to at most input_size - 1.
//
// lower_bound is clamped from above as well as below. If scaled_value ever
// reaches input_size (through float rounding or an out-of-range `value`),
// the caller still never indexes past the end of the axis. The clamp is a
// no-op whenever floor(scaled_value) is already a valid index.
inline void ComputeInterpolationValues(const float value, const float scale,
                                       const bool half_pixel_centers,
                                       int32_t input_size, float* scaled_value,
                                       int32_t* lower_bound,
                                       int32_t* upper_bound) {
  if (half_pixel_centers) {
    *scaled_value = (value + 0.5f) * scale - 0.5f;
  } else {
    *scaled_value = value * scale;
  }
  const int32_t max_index = input_size - 1;
  const float scaled_value_floor = std::floor(*scaled_value);
  // Clamp from above first, then from below, so the result is never negative
  // (even for a degenerate input_size of 0).
  *lower_bound =
      std::max(std::min(static_cast<int32_t>(scaled_value_floor), max_index),
               static_cast<int32_t>(0));
  *upper_bound =
      std::min(static_cast<int32_t>(std::ceil(*scaled_value)), max_index);
}
45 
46 template <typename T>
ResizeBilinear(const tflite::ResizeBilinearParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_size_shape,const int32_t * output_size_data,const RuntimeShape & unextended_output_shape,T * output_data)47 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
48                            const RuntimeShape& unextended_input_shape,
49                            const T* input_data,
50                            const RuntimeShape& unextended_output_size_shape,
51                            const int32_t* output_size_data,
52                            const RuntimeShape& unextended_output_shape,
53                            T* output_data) {
54   // If half_pixel_centers is True, align_corners must be False.
55   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
56   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
57   TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
58   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
59   const RuntimeShape input_shape =
60       RuntimeShape::ExtendedShape(4, unextended_input_shape);
61   const RuntimeShape output_size_shape =
62       RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
63   const RuntimeShape output_shape =
64       RuntimeShape::ExtendedShape(4, unextended_output_shape);
65 
66   int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
67   int32_t input_height = input_shape.Dims(1);
68   int32_t input_width = input_shape.Dims(2);
69   int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
70 
71   TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
72   TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
73   TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
74   TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
75   int32_t output_height =
76       output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
77   int32_t output_width =
78       output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
79 
80   float height_scale = static_cast<float>(input_height) / output_height;
81   float width_scale = static_cast<float>(input_width) / output_width;
82   if (op_params.align_corners && output_height > 1) {
83     height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
84   }
85   if (op_params.align_corners && output_width > 1) {
86     width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
87   }
88   const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;
89 
90   for (int b = 0; b < batches; ++b) {
91     for (int y = 0; y < output_height; ++y) {
92       float input_y;
93       int32_t y0, y1;
94       ComputeInterpolationValues(y, height_scale, op_params.half_pixel_centers,
95                                  input_height, &input_y, &y0, &y1);
96       for (int x = 0; x < output_width; ++x) {
97         float input_x;
98         int32_t x0, x1;
99         ComputeInterpolationValues(x, width_scale, op_params.half_pixel_centers,
100                                    input_width, &input_x, &x0, &x1);
101         for (int c = 0; c < depth; ++c) {
102           T interpolation =
103               static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
104                                  (1 - (input_y - y0)) * (1 - (input_x - x0)) +
105                              input_data[Offset(input_shape, b, y1, x0, c)] *
106                                  (input_y - y0) * (1 - (input_x - x0)) +
107                              input_data[Offset(input_shape, b, y0, x1, c)] *
108                                  (1 - (input_y - y0)) * (input_x - x0) +
109                              input_data[Offset(input_shape, b, y1, x1, c)] *
110                                  (input_y - y0) * (input_x - x0) +
111                              rounding_offset);
112           output_data[Offset(output_shape, b, y, x, c)] = interpolation;
113         }
114       }
115     }
116   }
117 }
118 
// Fixed-point counterpart of ComputeInterpolationValues.
//
// `scale_10` is the input/output size ratio in Q10 (ratio * 2^10), and
// `scaled_value` is returned in the same Q10 format. With half_pixel_centers
// the mapping (value + 0.5) * scale - 0.5 becomes
// value * scale_10 + scale_10 / 2 - 2^9.
//
// Outputs:
//   scaled_value - source coordinate in Q10 (may be negative).
//   lower_bound  - integer part of scaled_value, clamped into
//                  [0, input_size - 1].
//   upper_bound  - scaled_value rounded up to an integer (for non-negative
//                  scaled_value), clamped to at most input_size - 1.
//
// Why lower_bound is also clamped from above: the Q10 scale factors built by
// ResizeBilinearInteger are rounded to nearest, so they can overestimate the
// true ratio. For the last output rows/columns, value * scale_10 can then
// reach input_size in Q10. Example: input_size = 1 and output size 89 give
// scale_10 = 12. Then value = 88 maps to 1056 (> 1024), whose integer part 1
// would index past the only input row.
inline void ComputeInterpolationValuesInteger(
    const int32_t value, const int32_t scale_10, const bool half_pixel_centers,
    int32_t input_size, int32_t* scaled_value, int32_t* lower_bound,
    int32_t* upper_bound) {
  if (half_pixel_centers) {
    *scaled_value = value * scale_10 + scale_10 / 2 - (1 << 9);
  } else {
    *scaled_value = value * scale_10;
  }
  constexpr int32_t zero = 0;
  const int32_t max_index = input_size - 1;
  // Clamp from above first, then from below. Any negative quotient (integer
  // division truncates toward zero) therefore ends up at 0, and the result
  // is never negative.
  *lower_bound =
      std::max(std::min(*scaled_value / (1 << 10), max_index), zero);
  *upper_bound =
      std::min((*scaled_value + (1 << 10) - 1) / (1 << 10), max_index);
}
133 
134 // Same as above but doesn't use any floating-point for the resize
135 template <typename T>
ResizeBilinearInteger(const tflite::ResizeBilinearParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_size_shape,const int32_t * output_size_data,const RuntimeShape & unextended_output_shape,T * output_data)136 inline void ResizeBilinearInteger(
137     const tflite::ResizeBilinearParams& op_params,
138     const RuntimeShape& unextended_input_shape, const T* input_data,
139     const RuntimeShape& unextended_output_size_shape,
140     const int32_t* output_size_data,
141     const RuntimeShape& unextended_output_shape, T* output_data) {
142   // If half_pixel_centers is True, align_corners must be False.
143   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
144   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
145   TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
146   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
147   const RuntimeShape input_shape =
148       RuntimeShape::ExtendedShape(4, unextended_input_shape);
149   const RuntimeShape output_size_shape =
150       RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
151   const RuntimeShape output_shape =
152       RuntimeShape::ExtendedShape(4, unextended_output_shape);
153 
154   const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
155   const int32_t input_height = input_shape.Dims(1);
156   const int32_t input_width = input_shape.Dims(2);
157   const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
158 
159   TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
160   TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
161   TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
162   TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
163   const int32_t output_height =
164       output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
165   const int32_t output_width =
166       output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
167 
168   int32_t height_scale_10 =
169       ((1 << 10) * input_height + output_height / 2) / output_height;
170   int32_t width_scale_10 =
171       ((1 << 10) * input_width + output_width / 2) / output_width;
172   if (op_params.align_corners && output_height > 1) {
173     height_scale_10 =
174         ((1 << 10) * (input_height - 1) + (output_height - 1) / 2) /
175         (output_height - 1);
176   }
177   if (op_params.align_corners && output_width > 1) {
178     width_scale_10 = ((1 << 10) * (input_width - 1) + (output_width - 1) / 2) /
179                      (output_width - 1);
180   }
181 
182   for (int b = 0; b < batches; ++b) {
183     for (int y = 0; y < output_height; ++y) {
184       int32_t input_y, y0, y1;
185       ComputeInterpolationValuesInteger(y, height_scale_10,
186                                         op_params.half_pixel_centers,
187                                         input_height, &input_y, &y0, &y1);
188       for (int x = 0; x < output_width; ++x) {
189         int32_t input_x, x0, x1;
190         ComputeInterpolationValuesInteger(x, width_scale_10,
191                                           op_params.half_pixel_centers,
192                                           input_width, &input_x, &x0, &x1);
193         for (int c = 0; c < depth; ++c) {
194           const int64_t output_20_ll =
195               static_cast<int64_t>(
196                   input_data[Offset(input_shape, b, y0, x0, c)]) *
197               ((1 << 10) - (input_y - (1 << 10) * y0)) *
198               ((1 << 10) - (input_x - (1 << 10) * x0));
199           const int64_t output_20_lu =
200               static_cast<int64_t>(
201                   input_data[Offset(input_shape, b, y1, x0, c)]) *
202               (input_y - (1 << 10) * y0) *
203               ((1 << 10) - (input_x - (1 << 10) * x0));
204           const int64_t output_20_rl =
205               static_cast<int64_t>(
206                   input_data[Offset(input_shape, b, y0, x1, c)]) *
207               ((1 << 10) - (input_y - (1 << 10) * y0)) *
208               (input_x - (1 << 10) * x0);
209           const int64_t output_20_ru =
210               static_cast<int64_t>(
211                   input_data[Offset(input_shape, b, y1, x1, c)]) *
212               (input_y - (1 << 10) * y0) * (input_x - (1 << 10) * x0);
213           const int64_t output_20 =
214               output_20_ll + output_20_lu + output_20_rl + output_20_ru;
215           const int64_t round = (output_20 > 0) ? (1 << 19) : -(1 << 19);
216           const T interpolation =
217               static_cast<T>((output_20 + round) / (1 << 20));
218           output_data[Offset(output_shape, b, y, x, c)] = interpolation;
219         }
220       }
221     }
222   }
223 }
224 
225 }  // namespace reference_ops
226 }  // namespace tflite
227 
228 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_BILINEAR_H_
229