1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_BILINEAR_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_BILINEAR_H_
17
18 #include <algorithm>
19 #include <cmath>
20 #include <cstdint>
21 #include <limits>
22
23 #include "tensorflow/lite/kernels/internal/cppmath.h"
24 #include "tensorflow/lite/kernels/internal/types.h"
25
26 namespace tflite {
27 namespace reference_ops {
28
// Maps output coordinate `value` onto the input axis and finds the two input
// indices that bracket the mapped position.
//
// The mapping is (value + 0.5) * scale - 0.5 when `half_pixel_centers` is set,
// and value * scale otherwise; the result is stored unclamped in
// *scaled_value. *lower_bound / *upper_bound receive floor / ceil of that
// position, both clamped into [0, input_size - 1], so they are valid indices
// whenever input_size > 0. If the mapped position lies outside that range,
// both bounds collapse onto the nearest edge index, so weights derived from
// (*scaled_value - *lower_bound) still apply to a single in-range element.
inline void ComputeInterpolationValues(const float value, const float scale,
                                       const bool half_pixel_centers,
                                       int32_t input_size, float* scaled_value,
                                       int32_t* lower_bound,
                                       int32_t* upper_bound) {
  if (half_pixel_centers) {
    *scaled_value = (value + 0.5f) * scale - 0.5f;
  } else {
    *scaled_value = value * scale;
  }
  const int32_t max_index = input_size - 1;
  constexpr int32_t zero = 0;
  // Clamp to the top edge first, then to zero. For input_size == 0 this keeps
  // the historical lower_bound of 0 instead of producing -1.
  const int32_t floor_index =
      static_cast<int32_t>(std::floor(*scaled_value));
  const int32_t ceil_index = static_cast<int32_t>(std::ceil(*scaled_value));
  *lower_bound = std::max(std::min(floor_index, max_index), zero);
  *upper_bound = std::max(std::min(ceil_index, max_index), zero);
}
45
46 template <typename T>
ResizeBilinear(const tflite::ResizeBilinearParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_size_shape,const int32_t * output_size_data,const RuntimeShape & unextended_output_shape,T * output_data)47 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
48 const RuntimeShape& unextended_input_shape,
49 const T* input_data,
50 const RuntimeShape& unextended_output_size_shape,
51 const int32_t* output_size_data,
52 const RuntimeShape& unextended_output_shape,
53 T* output_data) {
54 // If half_pixel_centers is True, align_corners must be False.
55 TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
56 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
57 TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
58 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
59 const RuntimeShape input_shape =
60 RuntimeShape::ExtendedShape(4, unextended_input_shape);
61 const RuntimeShape output_size_shape =
62 RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
63 const RuntimeShape output_shape =
64 RuntimeShape::ExtendedShape(4, unextended_output_shape);
65
66 int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
67 int32_t input_height = input_shape.Dims(1);
68 int32_t input_width = input_shape.Dims(2);
69 int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
70
71 TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
72 TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
73 TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
74 TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
75 int32_t output_height =
76 output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
77 int32_t output_width =
78 output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
79
80 float height_scale = static_cast<float>(input_height) / output_height;
81 float width_scale = static_cast<float>(input_width) / output_width;
82 if (op_params.align_corners && output_height > 1) {
83 height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
84 }
85 if (op_params.align_corners && output_width > 1) {
86 width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
87 }
88 const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;
89
90 for (int b = 0; b < batches; ++b) {
91 for (int y = 0; y < output_height; ++y) {
92 float input_y;
93 int32_t y0, y1;
94 ComputeInterpolationValues(y, height_scale, op_params.half_pixel_centers,
95 input_height, &input_y, &y0, &y1);
96 for (int x = 0; x < output_width; ++x) {
97 float input_x;
98 int32_t x0, x1;
99 ComputeInterpolationValues(x, width_scale, op_params.half_pixel_centers,
100 input_width, &input_x, &x0, &x1);
101 for (int c = 0; c < depth; ++c) {
102 T interpolation =
103 static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
104 (1 - (input_y - y0)) * (1 - (input_x - x0)) +
105 input_data[Offset(input_shape, b, y1, x0, c)] *
106 (input_y - y0) * (1 - (input_x - x0)) +
107 input_data[Offset(input_shape, b, y0, x1, c)] *
108 (1 - (input_y - y0)) * (input_x - x0) +
109 input_data[Offset(input_shape, b, y1, x1, c)] *
110 (input_y - y0) * (input_x - x0) +
111 rounding_offset);
112 output_data[Offset(output_shape, b, y, x, c)] = interpolation;
113 }
114 }
115 }
116 }
117 }
118
// Fixed-point counterpart of ComputeInterpolationValues. `scale_10` and the
// resulting *scaled_value are Q10 values (1 << 10 represents 1.0).
//
// The mapping is value * scale_10 (+ scale_10 / 2 - 0.5 in Q10 when
// `half_pixel_centers` is set). *lower_bound / *upper_bound receive the
// truncating / rounded-up quotient of *scaled_value by 1 << 10, both clamped
// into [0, input_size - 1].
//
// The upper clamp on *lower_bound is required: the caller's Q10 scale is
// rounded to nearest, so for large upsampling factors the last output
// coordinate can land past the last input index. For example, input 1 ->
// output 70 gives scale_10 == 15 and 69 * 15 == 1035 > 1 << 10. Without the
// clamp the lower index would be 1 and read out of bounds.
inline void ComputeInterpolationValuesInteger(
    const int32_t value, const int32_t scale_10, const bool half_pixel_centers,
    int32_t input_size, int32_t* scaled_value, int32_t* lower_bound,
    int32_t* upper_bound) {
  if (half_pixel_centers) {
    *scaled_value = value * scale_10 + scale_10 / 2 - (1 << 9);
  } else {
    *scaled_value = value * scale_10;
  }
  constexpr int32_t zero = 0;
  const int32_t max_index = input_size - 1;
  // Clamp to the top edge first, then to zero, so input_size == 0 still
  // yields a lower bound of 0.
  *lower_bound =
      std::max(std::min(*scaled_value / (1 << 10), max_index), zero);
  *upper_bound = std::max(
      std::min((*scaled_value + (1 << 10) - 1) / (1 << 10), max_index), zero);
}
133
134 // Same as above but doesn't use any floating-point for the resize
135 template <typename T>
ResizeBilinearInteger(const tflite::ResizeBilinearParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & unextended_output_size_shape,const int32_t * output_size_data,const RuntimeShape & unextended_output_shape,T * output_data)136 inline void ResizeBilinearInteger(
137 const tflite::ResizeBilinearParams& op_params,
138 const RuntimeShape& unextended_input_shape, const T* input_data,
139 const RuntimeShape& unextended_output_size_shape,
140 const int32_t* output_size_data,
141 const RuntimeShape& unextended_output_shape, T* output_data) {
142 // If half_pixel_centers is True, align_corners must be False.
143 TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
144 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
145 TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
146 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
147 const RuntimeShape input_shape =
148 RuntimeShape::ExtendedShape(4, unextended_input_shape);
149 const RuntimeShape output_size_shape =
150 RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
151 const RuntimeShape output_shape =
152 RuntimeShape::ExtendedShape(4, unextended_output_shape);
153
154 const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
155 const int32_t input_height = input_shape.Dims(1);
156 const int32_t input_width = input_shape.Dims(2);
157 const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
158
159 TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
160 TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
161 TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
162 TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
163 const int32_t output_height =
164 output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
165 const int32_t output_width =
166 output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
167
168 int32_t height_scale_10 =
169 ((1 << 10) * input_height + output_height / 2) / output_height;
170 int32_t width_scale_10 =
171 ((1 << 10) * input_width + output_width / 2) / output_width;
172 if (op_params.align_corners && output_height > 1) {
173 height_scale_10 =
174 ((1 << 10) * (input_height - 1) + (output_height - 1) / 2) /
175 (output_height - 1);
176 }
177 if (op_params.align_corners && output_width > 1) {
178 width_scale_10 = ((1 << 10) * (input_width - 1) + (output_width - 1) / 2) /
179 (output_width - 1);
180 }
181
182 for (int b = 0; b < batches; ++b) {
183 for (int y = 0; y < output_height; ++y) {
184 int32_t input_y, y0, y1;
185 ComputeInterpolationValuesInteger(y, height_scale_10,
186 op_params.half_pixel_centers,
187 input_height, &input_y, &y0, &y1);
188 for (int x = 0; x < output_width; ++x) {
189 int32_t input_x, x0, x1;
190 ComputeInterpolationValuesInteger(x, width_scale_10,
191 op_params.half_pixel_centers,
192 input_width, &input_x, &x0, &x1);
193 for (int c = 0; c < depth; ++c) {
194 const int64_t output_20_ll =
195 static_cast<int64_t>(
196 input_data[Offset(input_shape, b, y0, x0, c)]) *
197 ((1 << 10) - (input_y - (1 << 10) * y0)) *
198 ((1 << 10) - (input_x - (1 << 10) * x0));
199 const int64_t output_20_lu =
200 static_cast<int64_t>(
201 input_data[Offset(input_shape, b, y1, x0, c)]) *
202 (input_y - (1 << 10) * y0) *
203 ((1 << 10) - (input_x - (1 << 10) * x0));
204 const int64_t output_20_rl =
205 static_cast<int64_t>(
206 input_data[Offset(input_shape, b, y0, x1, c)]) *
207 ((1 << 10) - (input_y - (1 << 10) * y0)) *
208 (input_x - (1 << 10) * x0);
209 const int64_t output_20_ru =
210 static_cast<int64_t>(
211 input_data[Offset(input_shape, b, y1, x1, c)]) *
212 (input_y - (1 << 10) * y0) * (input_x - (1 << 10) * x0);
213 const int64_t output_20 =
214 output_20_ll + output_20_lu + output_20_rl + output_20_ru;
215 const int64_t round = (output_20 > 0) ? (1 << 19) : -(1 << 19);
216 const T interpolation =
217 static_cast<T>((output_20 + round) / (1 << 20));
218 output_data[Offset(output_shape, b, y, x, c)] = interpolation;
219 }
220 }
221 }
222 }
223 }
224
225 } // namespace reference_ops
226 } // namespace tflite
227
228 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_BILINEAR_H_
229