1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_
17
18 #include <stdint.h>
19 #include <sys/types.h>
20
21 #include <algorithm>
22 #include <cmath>
#include <cstring>
23 #include <limits>
24 #include <memory>
25 #include <type_traits>
26
27 #include "ruy/profiler/instrumentation.h" // from @ruy
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/kernels/internal/common.h"
30 #include "tensorflow/lite/kernels/internal/compatibility.h"
31 #include "tensorflow/lite/kernels/internal/quantization_util.h"
32 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
33 #include "tensorflow/lite/kernels/internal/tensor.h"
34 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
35 #include "tensorflow/lite/kernels/internal/types.h"
36
37 namespace tflite {
38 namespace optimized_ops {
39 namespace resize_bilinear {
40
41 #ifdef USE_NEON
42 // These utility functions are split off not just for convenience. Most
43 // incorporate packing or unpacking of data.
44 //
45 // (a) Optimizations can be tried experimentally.
46 // (b) Optimizations can be specialized for architectures, eg Intel vs ARM.
47
48 inline int16x8_t Load8IntoLowerS16(const uint8* data_ptr) {
49 return vreinterpretq_s16_u16(vmovl_u8(vld1_u8(data_ptr)));
50 }
51
52 inline uint16x8_t Move8IntoUpperU16(const uint8x8_t vec_val) {
53 // Alternatively one could zip with a zero vector.
54 return vshlq_n_u16(vmovl_u8(vec_val), 8);
55 }
56
57 inline uint16x8_t Load8IntoUpperU16(const uint8* data_ptr) {
58 return Move8IntoUpperU16(vld1_u8(data_ptr));
59 }
60
61 // Extract upper 8 bits from each 16-bit integer in vector registers. This is
62 // performed for a pair, because instructions often work on pairs.
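// For example, a 16-bit lane holding 0x1A80 (26.5 in 8.8 fixed point) yields
// the byte 0x1A (26) in the corresponding 8-bit result lane.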
63 inline void PairExtractUpper(const uint16x8_t accum_0, const uint16x8_t accum_1,
64 uint8x8_t* res_0, uint8x8_t* res_1) {
65 uint8x16x2_t unzipped =
66 vuzpq_u8(vreinterpretq_u8_u16(accum_0), vreinterpretq_u8_u16(accum_1));
67 *res_0 = vget_low_u8(unzipped.val[1]);
68 *res_1 = vget_high_u8(unzipped.val[1]);
69 }
70
71 // This is an exceptional definition.
72 //
73 // Wraps int16x8_t, adding operators.
74 //
75 // There are exceptional circumstances that make it reasonable to write code
76 // directly on vector types for quantized resize bilinear in *some* cases.
77 //
78 // (a) In exact quant resize bilinear, it should be possible to guarantee that
79 // arithmetic never overflows.
80 // (b) When the resize scaling is 2 or 4 or 8 it is possible to guarantee
81 // exact accumulation and exact incrementation.
82 // (c) In quant resize bilinear the choice of unsigned vs signed accumulation
83 // and saturated vs unsaturated arithmetic is often unimportant.
84 //
85 // This pattern simplifies the code considerably. This pattern should not be
86 // used more widely in code since it can hide important numerical detail.
87 //
88 // DO NOT add "class-like" methods to this struct: only methods that do no
89 // more than redirect operators to specific intrinsic functions.
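//
// For illustration, typical use later in this file looks like
//   op_int16x8_t wdelta = (tr_val - tl_val) << 4;
//   accum += wdelta;
// where each operator redirects to a single NEON intrinsic.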
90 struct op_int16x8_t {
91 inline op_int16x8_t() = default;
92 inline explicit op_int16x8_t(const int16x8_t& initial_val) {
93 val = initial_val;
94 }
95 inline op_int16x8_t& operator=(const int16x8_t& new_val) {
96 val = new_val;
97 return *this;
98 }
99 inline op_int16x8_t operator+=(const op_int16x8_t& add_val) {
100 val = vaddq_s16(val, add_val.val);
101 return *this;
102 }
103 inline op_int16x8_t operator-=(const op_int16x8_t& sub_val) {
104 val = vsubq_s16(val, sub_val.val);
105 return *this;
106 }
107 // This really selects vshlq_n_s16, but requires a longer implementation to
108 // convert the shift argument back to a constant. In some compilers these
109 // intrinsics are macros that require constant arguments.
110 inline op_int16x8_t operator<<=(int32 left_shift) {
111 switch (left_shift) {
112 case 1:
113 val = vshlq_n_s16(val, 1);
114 break;
115 case 4:
116 val = vshlq_n_s16(val, 4);
117 break;
118 case 8:
119 val = vshlq_n_s16(val, 8);
120 break;
121 default:
122 TFLITE_CHECK(false);
123 break;
124 }
125 return *this;
126 }
127 // This really selects vshrq_n_s16, but requires a longer implementation to
128 // convert the shift argument back to a constant. In some compilers these
129 // intrinsics are macros that require constant arguments.
130 inline op_int16x8_t operator>>=(int32 right_shift) {
131 switch (right_shift) {
132 case 1:
133 val = vshrq_n_s16(val, 1);
134 break;
135 case 4:
136 val = vshrq_n_s16(val, 4);
137 break;
138 case 8:
139 val = vshrq_n_s16(val, 8);
140 break;
141 default:
142 TFLITE_CHECK(false);
143 break;
144 }
145 return *this;
146 }
147 friend inline op_int16x8_t operator+(op_int16x8_t lhs,
148 const op_int16x8_t& rhs) {
149 lhs += rhs;
150 return lhs;
151 }
152 friend inline op_int16x8_t operator-(op_int16x8_t lhs,
153 const op_int16x8_t& rhs) {
154 lhs -= rhs;
155 return lhs;
156 }
157 friend inline op_int16x8_t operator<<(op_int16x8_t lhs, int32 left_shift) {
158 lhs <<= left_shift;
159 return lhs;
160 }
161 friend inline op_int16x8_t operator>>(op_int16x8_t lhs, int32 right_shift) {
162 lhs >>= right_shift;
163 return lhs;
164 }
165
166 int16x8_t val;
167 };
168
169 // This is an exceptional definition.
170 //
171 // Wraps uint16x8_t, adding operators.
172 //
173 // Important: See above notes on op_int16x8_t.
174 struct op_uint16x8_t {
175 inline op_uint16x8_t() = default;
176 inline explicit op_uint16x8_t(const uint16x8_t initial_val) {
177 val = initial_val;
178 }
179 inline op_uint16x8_t& operator=(const uint16x8_t& new_val) {
180 val = new_val;
181 return *this;
182 }
183 inline op_uint16x8_t operator+=(const op_int16x8_t& add_val) {
184 val = vaddq_u16(val, vreinterpretq_u16_s16(add_val.val));
185 return *this;
186 }
187 inline op_uint16x8_t operator-=(const op_int16x8_t& sub_val) {
188 val = vsubq_u16(val, vreinterpretq_u16_s16(sub_val.val));
189 return *this;
190 }
191 // This really selects vshlq_n_u16, but requires a longer implementation to
192 // convert the shift argument back to a constant. In some compilers these
193 // intrinsics are macros that require constant arguments.
194 inline op_uint16x8_t operator<<=(int32 left_shift) {
195 switch (left_shift) {
196 case 1:
197 val = vshlq_n_u16(val, 1);
198 break;
199 case 4:
200 val = vshlq_n_u16(val, 4);
201 break;
202 case 8:
203 val = vshlq_n_u16(val, 8);
204 break;
205 default:
206 TFLITE_CHECK(false);
207 break;
208 }
209 return *this;
210 }
211 // This really selects vshrq_n_u16, but requires a longer implementation to
212 // convert the shift argument back to a constant. In some compilers these
213 // intrinsics are macros that require constant arguments.
214 inline op_uint16x8_t operator>>=(int32 right_shift) {
215 switch (right_shift) {
216 case 1:
217 val = vshrq_n_u16(val, 1);
218 break;
219 case 4:
220 val = vshrq_n_u16(val, 4);
221 break;
222 case 8:
223 val = vshrq_n_u16(val, 8);
224 break;
225 default:
226 TFLITE_CHECK(false);
227 break;
228 }
229 return *this;
230 }
231 friend inline op_uint16x8_t operator+(op_uint16x8_t lhs,
232 const op_int16x8_t& rhs) {
233 lhs += rhs;
234 return lhs;
235 }
236 friend inline op_uint16x8_t operator-(op_uint16x8_t lhs,
237 const op_int16x8_t& rhs) {
238 lhs -= rhs;
239 return lhs;
240 }
241 friend inline op_uint16x8_t operator<<(op_uint16x8_t lhs, int32 left_shift) {
242 lhs <<= left_shift;
243 return lhs;
244 }
245 friend inline op_uint16x8_t operator>>(op_uint16x8_t lhs, int32 right_shift) {
246 lhs >>= right_shift;
247 return lhs;
248 }
249
250 uint16x8_t val;
251 };
252
253 inline op_uint16x8_t VReinterpretQU16S16(const op_int16x8_t& other) {
254 op_uint16x8_t ret_val(vreinterpretq_u16_s16(other.val));
255 return ret_val;
256 }
257 #endif // USE_NEON
258
259 // Optimized resize-bilinear for the special case where the scaling is x8 in
260 // width and height, and where we can operate on depth-8 blocks at a time. So
261 // the output blocks are 8x8x8 in width-height-depth.
262 //
263 // This optimization is for the half_pixel_centers == true version, for uint8.
264 // There are versions for NEON and non-NEON compilation.
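//
// Arithmetic sketch: accumulators hold 8.8 fixed-point values, so 256
// represents 1.0 and the pre-added 128 is a 0.5 rounding bias. With
// half_pixel_centers and a scale of exactly 8, the first output inside an
// input interval lies 1/16 of the way across it and each subsequent output
// advances by a further 2/16. Hence wdelta = (right - left) << 4 is the 1/16
// step expressed in 8.8, and wdelta_twice is the regular per-output step.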
265 inline void ResizeBilinear888Uint8(int32 batches, int32 input_height,
266 int32 input_width, int32 depth,
267 const uint8* input_data,
268 uint8* output_data) {
269 TFLITE_DCHECK_GE(input_height, 1);
270 TFLITE_DCHECK_GE(input_width, 1);
271 TFLITE_DCHECK_EQ(depth % 8, 0);
272
273 const int32 input_row_stride = input_width * depth;
274 const int32 output_row_stride = input_row_stride * 8;
275 for (int b = 0; b < batches; ++b) {
276 const uint8* input_base_ptr =
277 input_data + b * input_row_stride * input_height;
278 uint8* output_base_ptr =
279 output_data + b * output_row_stride * input_height * 8;
280
281 #ifdef USE_NEON
282 for (int c_block = 0; c_block < depth; c_block += 8) {
283 op_uint16x8_t accum_c_v;
284 // Top-left margin corner.
285 {
286 uint8x8_t output_data = vld1_u8(&input_base_ptr[c_block]);
287 vst1_u8(&output_base_ptr[c_block], output_data);
288 vst1_u8(&output_base_ptr[c_block + depth], output_data);
289 vst1_u8(&output_base_ptr[c_block + depth * 2], output_data);
290 vst1_u8(&output_base_ptr[c_block + depth * 3], output_data);
291
292 // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
293 accum_c_v = vaddq_u16(Move8IntoUpperU16(output_data), vdupq_n_u16(128));
294 }
295
296 // Top-centre margin.
297 op_int16x8_t wdelta_c_v;
298 op_int16x8_t wdelta_twice_c_v;
299 for (int j = 0; j < (input_width - 1); ++j) {
300 {
301 uint8x8_t output_data_alt;
302 uint8x8_t output_data;
303
304 const op_int16x8_t tl_val(
305 Load8IntoLowerS16(&input_base_ptr[c_block + depth * j]));
306 const op_int16x8_t tr_val(
307 Load8IntoLowerS16(&input_base_ptr[c_block + depth * (j + 1)]));
308 wdelta_c_v = (tr_val - tl_val) << 4;
309 wdelta_twice_c_v = wdelta_c_v << 1;
310
311 op_uint16x8_t accum_c_v_alt = accum_c_v + wdelta_c_v;
312 accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
313 PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
314 &output_data);
315
316 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * 4],
317 output_data_alt);
318 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth + depth * 4],
319 output_data);
320
321 for (int p = 2; p < 8; p += 2) {
322 accum_c_v_alt = accum_c_v + wdelta_twice_c_v;
323 accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
324 PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
325 &output_data);
326
327 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * p +
328 depth * 4],
329 output_data_alt);
330 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * (p + 1) +
331 depth * 4],
332 output_data);
333 }
334 accum_c_v += wdelta_c_v;
335 }
336 }
337
338 // Top-right margin corner.
339 {
340 uint8x8_t output_data_discard;
341 uint8x8_t output_data;
342
343 // Accumulations have pre-added 0.5 for rounding, but that is just
344 // discarded and this just avoids re-loading.
345 PairExtractUpper(accum_c_v.val, accum_c_v.val, &output_data,
346 &output_data_discard);
347
348 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
349 depth * 4],
350 output_data);
351 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
352 depth * 4 + depth],
353 output_data);
354 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
355 depth * 4 + depth * 2],
356 output_data);
357 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
358 depth * 4 + depth * 3],
359 output_data);
360 }
361 }
362 // Fill out remainder of top margin.
363 std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
364 output_row_stride * sizeof(uint8));
365 std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
366 output_row_stride * sizeof(uint8));
367 std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
368 output_row_stride * sizeof(uint8));
369 output_base_ptr += output_row_stride * 4;
370
371 // Main rows.
372 for (int k = 0; k < (input_height - 1); ++k) {
373 for (int c_block = 0; c_block < depth; c_block += 8) {
374 uint8* output_base_ptr_0 = output_base_ptr;
375 uint8* output_base_ptr_1;
376 uint8* output_base_ptr_2;
377 uint8* output_base_ptr_3;
378 uint8* output_base_ptr_4;
379 uint8* output_base_ptr_5;
380 uint8* output_base_ptr_6;
381 uint8* output_base_ptr_7;
382
383 op_uint16x8_t accum_0_c_v;
384 op_uint16x8_t accum_1_c_v;
385 op_uint16x8_t accum_2_c_v;
386 op_uint16x8_t accum_3_c_v;
387 op_uint16x8_t accum_4_c_v;
388 op_uint16x8_t accum_5_c_v;
389 op_uint16x8_t accum_6_c_v;
390 op_uint16x8_t accum_7_c_v;
391
392 op_int16x8_t hdelta_c_v;
393 op_int16x8_t hdelta_twice_c_v;
394
395 // Left margin for 8 rows.
396 {
397 uint8x8_t output_data_0_c;
398 uint8x8_t output_data_1_c;
399 uint8x8_t output_data_2_c;
400 uint8x8_t output_data_3_c;
401 uint8x8_t output_data_4_c;
402 uint8x8_t output_data_5_c;
403 uint8x8_t output_data_6_c;
404 uint8x8_t output_data_7_c;
405
406 const op_int16x8_t tl_val(
407 Load8IntoLowerS16(&input_base_ptr[c_block]));
408 const op_int16x8_t bl_val(
409 Load8IntoLowerS16(&input_base_ptr[c_block + input_row_stride]));
410 hdelta_c_v = (bl_val - tl_val) << 4;
411
412 // Accumulate in 8.8 representation, pre-adding 0.5 for later
413 // rounding.
414 accum_0_c_v = VReinterpretQU16S16(tl_val << 8);
415 accum_0_c_v = vaddq_u16(accum_0_c_v.val, vdupq_n_u16(128));
416
417 hdelta_twice_c_v = hdelta_c_v << 1;
418
419 accum_0_c_v += hdelta_c_v;
420 accum_1_c_v = accum_0_c_v + hdelta_twice_c_v;
421 PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val, &output_data_0_c,
422 &output_data_1_c);
423
424 vst1_u8(&output_base_ptr_0[c_block], output_data_0_c);
425 vst1_u8(&output_base_ptr_0[c_block + depth], output_data_0_c);
426 vst1_u8(&output_base_ptr_0[c_block + depth * 2], output_data_0_c);
427 vst1_u8(&output_base_ptr_0[c_block + depth * 3], output_data_0_c);
428
429 output_base_ptr_1 = output_base_ptr_0 + output_row_stride;
430 vst1_u8(&output_base_ptr_1[c_block], output_data_1_c);
431 vst1_u8(&output_base_ptr_1[c_block + depth], output_data_1_c);
432 vst1_u8(&output_base_ptr_1[c_block + depth * 2], output_data_1_c);
433 vst1_u8(&output_base_ptr_1[c_block + depth * 3], output_data_1_c);
434
435 //
436
437 output_base_ptr_2 = output_base_ptr_1 + output_row_stride;
438 accum_2_c_v = accum_1_c_v + hdelta_twice_c_v;
439 accum_3_c_v = accum_2_c_v + hdelta_twice_c_v;
440 PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val, &output_data_2_c,
441 &output_data_3_c);
442
443 vst1_u8(&output_base_ptr_2[c_block], output_data_2_c);
444 vst1_u8(&output_base_ptr_2[c_block + depth], output_data_2_c);
445 vst1_u8(&output_base_ptr_2[c_block + depth * 2], output_data_2_c);
446 vst1_u8(&output_base_ptr_2[c_block + depth * 3], output_data_2_c);
447
448 output_base_ptr_3 = output_base_ptr_2 + output_row_stride;
449 vst1_u8(&output_base_ptr_3[c_block], output_data_3_c);
450 vst1_u8(&output_base_ptr_3[c_block + depth], output_data_3_c);
451 vst1_u8(&output_base_ptr_3[c_block + depth * 2], output_data_3_c);
452 vst1_u8(&output_base_ptr_3[c_block + depth * 3], output_data_3_c);
453
454 //
455
456 output_base_ptr_4 = output_base_ptr_3 + output_row_stride;
457 accum_4_c_v = accum_3_c_v + hdelta_twice_c_v;
458 accum_5_c_v = accum_4_c_v + hdelta_twice_c_v;
459 PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val, &output_data_4_c,
460 &output_data_5_c);
461
462 vst1_u8(&output_base_ptr_4[c_block], output_data_4_c);
463 vst1_u8(&output_base_ptr_4[c_block + depth], output_data_4_c);
464 vst1_u8(&output_base_ptr_4[c_block + depth * 2], output_data_4_c);
465 vst1_u8(&output_base_ptr_4[c_block + depth * 3], output_data_4_c);
466
467 output_base_ptr_5 = output_base_ptr_4 + output_row_stride;
468 vst1_u8(&output_base_ptr_5[c_block], output_data_5_c);
469 vst1_u8(&output_base_ptr_5[c_block + depth], output_data_5_c);
470 vst1_u8(&output_base_ptr_5[c_block + depth * 2], output_data_5_c);
471 vst1_u8(&output_base_ptr_5[c_block + depth * 3], output_data_5_c);
472
473 //
474
475 output_base_ptr_6 = output_base_ptr_5 + output_row_stride;
476 accum_6_c_v = accum_5_c_v + hdelta_twice_c_v;
477 accum_7_c_v = accum_6_c_v + hdelta_twice_c_v;
478 PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val, &output_data_6_c,
479 &output_data_7_c);
480
481 vst1_u8(&output_base_ptr_6[c_block], output_data_6_c);
482 vst1_u8(&output_base_ptr_6[c_block + depth], output_data_6_c);
483 vst1_u8(&output_base_ptr_6[c_block + depth * 2], output_data_6_c);
484 vst1_u8(&output_base_ptr_6[c_block + depth * 3], output_data_6_c);
485
486 output_base_ptr_7 = output_base_ptr_6 + output_row_stride;
487 vst1_u8(&output_base_ptr_7[c_block], output_data_7_c);
488 vst1_u8(&output_base_ptr_7[c_block + depth], output_data_7_c);
489 vst1_u8(&output_base_ptr_7[c_block + depth * 2], output_data_7_c);
490 vst1_u8(&output_base_ptr_7[c_block + depth * 3], output_data_7_c);
491 }
492
493 // Main central body.
494 op_int16x8_t wdelta_c;
495 op_int16x8_t wdelta_twice_c;
496 op_int16x8_t hwdelta_c;
497 op_int16x8_t hwdelta_twice_c;
498
499 op_int16x8_t incr_0_c;
500 op_int16x8_t incr_1_c;
501 op_int16x8_t incr_2_c;
502 op_int16x8_t incr_3_c;
503 op_int16x8_t incr_4_c;
504 op_int16x8_t incr_5_c;
505 op_int16x8_t incr_6_c;
506 op_int16x8_t incr_7_c;
507
508 uint8x8_t output_data_0_c;
509 uint8x8_t output_data_1_c;
510 uint8x8_t output_data_2_c;
511 uint8x8_t output_data_3_c;
512 uint8x8_t output_data_4_c;
513 uint8x8_t output_data_5_c;
514 uint8x8_t output_data_6_c;
515 uint8x8_t output_data_7_c;
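// In the loop below, wdelta_c is the 1/16 horizontal step along the top input
// row (in 8.8), and hwdelta_c = (br - bl) - (tr - tl) is the per-row
// correction: output row r sits at vertical fraction (2r + 1)/16 between the
// two input rows, so its first (half) step is wdelta_c + (2r + 1) * hwdelta_c
// (incr_base) and its full per-output step is twice that (incr_r_c).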
516 for (int j = 0; j < (input_width - 1); ++j) {
517 // output_base_ptr_0 = output_base_ptr;
518 // output_base_ptr_1 = output_base_ptr_0 + output_row_stride; ETC
519 {
520 const op_int16x8_t tl_val(
521 Load8IntoLowerS16(&input_base_ptr[c_block + depth * j]));
522 const op_int16x8_t bl_val(Load8IntoLowerS16(
523 &input_base_ptr[c_block + depth * j + input_row_stride]));
524 const op_int16x8_t tr_val(
525 Load8IntoLowerS16(&input_base_ptr[c_block + depth * (j + 1)]));
526 const op_int16x8_t br_val(Load8IntoLowerS16(
527 &input_base_ptr[c_block + depth * (j + 1) + input_row_stride]));
528
529 const op_int16x8_t tmp_diff = tr_val - tl_val;
530 wdelta_c = tmp_diff << 4;
531 wdelta_twice_c = wdelta_c << 1;
532 hwdelta_c = (br_val - bl_val) - tmp_diff;
533 hwdelta_twice_c = hwdelta_c << 1;
534
535 op_int16x8_t incr_base = wdelta_c + hwdelta_c;
536 accum_0_c_v += incr_base;
537 incr_0_c = incr_base << 1;
538 incr_base += hwdelta_twice_c;
539 accum_1_c_v += incr_base;
540 incr_1_c = incr_base << 1;
541
542 PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val, &output_data_0_c,
543 &output_data_1_c);
544 vst1_u8(&output_base_ptr_0[c_block + depth * j * 8 + depth * 4],
545 output_data_0_c);
546 vst1_u8(&output_base_ptr_1[c_block + depth * j * 8 + depth * 4],
547 output_data_1_c);
548
549 incr_base += hwdelta_twice_c;
550 accum_2_c_v += incr_base;
551 incr_2_c = incr_base << 1;
552 incr_base += hwdelta_twice_c;
553 accum_3_c_v += incr_base;
554 incr_3_c = incr_base << 1;
555
556 PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val, &output_data_2_c,
557 &output_data_3_c);
558 vst1_u8(&output_base_ptr_2[c_block + depth * j * 8 + depth * 4],
559 output_data_2_c);
560 vst1_u8(&output_base_ptr_3[c_block + depth * j * 8 + depth * 4],
561 output_data_3_c);
562
563 incr_base += hwdelta_twice_c;
564 accum_4_c_v += incr_base;
565 incr_4_c = incr_base << 1;
566 incr_base += hwdelta_twice_c;
567 accum_5_c_v += incr_base;
568 incr_5_c = incr_base << 1;
569
570 PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val, &output_data_4_c,
571 &output_data_5_c);
572 vst1_u8(&output_base_ptr_4[c_block + depth * j * 8 + depth * 4],
573 output_data_4_c);
574 vst1_u8(&output_base_ptr_5[c_block + depth * j * 8 + depth * 4],
575 output_data_5_c);
576
577 incr_base += hwdelta_twice_c;
578 accum_6_c_v += incr_base;
579 incr_6_c = incr_base << 1;
580 incr_base += hwdelta_twice_c;
581 accum_7_c_v += incr_base;
582 incr_7_c = incr_base << 1;
583
584 PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val, &output_data_6_c,
585 &output_data_7_c);
586 vst1_u8(&output_base_ptr_6[c_block + depth * j * 8 + depth * 4],
587 output_data_6_c);
588 vst1_u8(&output_base_ptr_7[c_block + depth * j * 8 + depth * 4],
589 output_data_7_c);
590
591 for (int p = 1; p < 8; ++p) {
592 accum_0_c_v += incr_0_c;
593 accum_1_c_v += incr_1_c;
594 PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val,
595 &output_data_0_c, &output_data_1_c);
596 vst1_u8(&output_base_ptr_0[c_block + depth * j * 8 + depth * p +
597 depth * 4],
598 output_data_0_c);
599 vst1_u8(&output_base_ptr_1[c_block + depth * j * 8 + depth * p +
600 depth * 4],
601 output_data_1_c);
602
603 accum_2_c_v += incr_2_c;
604 accum_3_c_v += incr_3_c;
605 PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val,
606 &output_data_2_c, &output_data_3_c);
607 vst1_u8(&output_base_ptr_2[c_block + depth * j * 8 + depth * p +
608 depth * 4],
609 output_data_2_c);
610 vst1_u8(&output_base_ptr_3[c_block + depth * j * 8 + depth * p +
611 depth * 4],
612 output_data_3_c);
613
614 accum_4_c_v += incr_4_c;
615 accum_5_c_v += incr_5_c;
616 PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val,
617 &output_data_4_c, &output_data_5_c);
618 vst1_u8(&output_base_ptr_4[c_block + depth * j * 8 + depth * p +
619 depth * 4],
620 output_data_4_c);
621 vst1_u8(&output_base_ptr_5[c_block + depth * j * 8 + depth * p +
622 depth * 4],
623 output_data_5_c);
624
625 accum_6_c_v += incr_6_c;
626 accum_7_c_v += incr_7_c;
627 PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val,
628 &output_data_6_c, &output_data_7_c);
629 vst1_u8(&output_base_ptr_6[c_block + depth * j * 8 + depth * p +
630 depth * 4],
631 output_data_6_c);
632 vst1_u8(&output_base_ptr_7[c_block + depth * j * 8 + depth * p +
633 depth * 4],
634 output_data_7_c);
635 }
636
637 accum_0_c_v += (incr_0_c >> 1);
638 accum_1_c_v += (incr_1_c >> 1);
639 accum_2_c_v += (incr_2_c >> 1);
640 accum_3_c_v += (incr_3_c >> 1);
641 accum_4_c_v += (incr_4_c >> 1);
642 accum_5_c_v += (incr_5_c >> 1);
643 accum_6_c_v += (incr_6_c >> 1);
644 accum_7_c_v += (incr_7_c >> 1);
645 }
646 }
647
648 // Right margin.
649 {
650 // Accumulations have pre-added 0.5 for rounding, but that is just
651 // discarded and this just avoids re-loading.
652 PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val, &output_data_0_c,
653 &output_data_1_c);
654 PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val, &output_data_2_c,
655 &output_data_3_c);
656 PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val, &output_data_4_c,
657 &output_data_5_c);
658 PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val, &output_data_6_c,
659 &output_data_7_c);
660 for (int p = 0; p < 4; ++p) {
661 vst1_u8(&output_base_ptr_0[c_block + depth * (input_width - 1) * 8 +
662 depth * 4 + depth * p],
663 output_data_0_c);
664 vst1_u8(&output_base_ptr_1[c_block + depth * (input_width - 1) * 8 +
665 depth * 4 + depth * p],
666 output_data_1_c);
667 vst1_u8(&output_base_ptr_2[c_block + depth * (input_width - 1) * 8 +
668 depth * 4 + depth * p],
669 output_data_2_c);
670 vst1_u8(&output_base_ptr_3[c_block + depth * (input_width - 1) * 8 +
671 depth * 4 + depth * p],
672 output_data_3_c);
673 vst1_u8(&output_base_ptr_4[c_block + depth * (input_width - 1) * 8 +
674 depth * 4 + depth * p],
675 output_data_4_c);
676 vst1_u8(&output_base_ptr_5[c_block + depth * (input_width - 1) * 8 +
677 depth * 4 + depth * p],
678 output_data_5_c);
679 vst1_u8(&output_base_ptr_6[c_block + depth * (input_width - 1) * 8 +
680 depth * 4 + depth * p],
681 output_data_6_c);
682 vst1_u8(&output_base_ptr_7[c_block + depth * (input_width - 1) * 8 +
683 depth * 4 + depth * p],
684 output_data_7_c);
685 }
686 }
687 }
688
689 output_base_ptr += output_row_stride * 8;
690 input_base_ptr += input_row_stride;
691 }
692
693 //
694
695 for (int c_block = 0; c_block < depth; c_block += 8) {
696 op_uint16x8_t accum_c_v;
697 // Bottom-left margin corner.
698 {
699 uint8x8_t output_data = vld1_u8(&input_base_ptr[c_block]);
700 vst1_u8(&output_base_ptr[c_block], output_data);
701 vst1_u8(&output_base_ptr[c_block + depth], output_data);
702 vst1_u8(&output_base_ptr[c_block + depth * 2], output_data);
703 vst1_u8(&output_base_ptr[c_block + depth * 3], output_data);
704
705 // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
706 accum_c_v = vaddq_u16(Move8IntoUpperU16(output_data), vdupq_n_u16(128));
707 }
708
709 // Bottom-centre margin.
710 op_int16x8_t wdelta_c_v;
711 op_int16x8_t wdelta_twice_c_v;
712 for (int j = 0; j < (input_width - 1); ++j) {
713 {
714 uint8x8_t output_data_alt;
715 uint8x8_t output_data;
716
717 const op_int16x8_t tl_val(
718 Load8IntoLowerS16(&input_base_ptr[c_block + depth * j]));
719 const op_int16x8_t tr_val(
720 Load8IntoLowerS16(&input_base_ptr[c_block + depth * (j + 1)]));
721 wdelta_c_v = (tr_val - tl_val) << 4;
722 wdelta_twice_c_v = wdelta_c_v << 1;
723
724 op_uint16x8_t accum_c_v_alt = accum_c_v + wdelta_c_v;
725 accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
726 PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
727 &output_data);
728
729 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * 4],
730 output_data_alt);
731 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth + depth * 4],
732 output_data);
733
734 for (int p = 2; p < 8; p += 2) {
735 accum_c_v_alt = accum_c_v + wdelta_twice_c_v;
736 accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
737 PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
738 &output_data);
739
740 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * p +
741 depth * 4],
742 output_data_alt);
743 vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * (p + 1) +
744 depth * 4],
745 output_data);
746 }
747 accum_c_v += wdelta_c_v;
748 }
749 }
750
751 // Bottom-right margin corner.
752 {
753 uint8x8_t output_data_discard;
754 uint8x8_t output_data;
755
756 // Accumulations have pre-added 0.5 for rounding, but that is just
757 // discarded and this just avoids re-loading.
758 PairExtractUpper(accum_c_v.val, accum_c_v.val, &output_data,
759 &output_data_discard);
760
761 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
762 depth * 4],
763 output_data);
764 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
765 depth * 4 + depth],
766 output_data);
767 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
768 depth * 4 + depth * 2],
769 output_data);
770 vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
771 depth * 4 + depth * 3],
772 output_data);
773 }
774 }
775 // Fill out remainder of bottom margin.
776 std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
777 output_row_stride * sizeof(uint8));
778 std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
779 output_row_stride * sizeof(uint8));
780 std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
781 output_row_stride * sizeof(uint8));
782
783 #else // USE_NEON
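// Scalar fallback: mirrors the NEON path above channel by channel, using the
// same 8.8 fixed-point accumulation and the same traversal order.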
784 for (int c_block = 0; c_block < depth; c_block += 8) {
785 uint8 output_data[8];
786 uint16 accum[8];
787 // Top-left margin corner.
788 for (int c = 0; c < 8; ++c) {
789 output_data[c] = input_base_ptr[c_block + c];
790 output_base_ptr[c_block + c] = output_data[c];
791 output_base_ptr[c_block + c + depth] = output_data[c];
792 output_base_ptr[c_block + c + depth * 2] = output_data[c];
793 output_base_ptr[c_block + c + depth * 3] = output_data[c];
794
795 // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
796 accum[c] =
797 (output_data[c] << 8) + 128; // 128 = 0.5 in 8.8 representation.
798 }
799
800 // Top-centre margin.
801 uint16 wdelta[8];
802 uint16 wdelta_twice[8];
803 for (int j = 0; j < (input_width - 1); ++j) {
804 for (int c = 0; c < 8; ++c) {
805 wdelta[c] = static_cast<uint16>(
806 input_base_ptr[c_block + c + depth * (j + 1)] -
807 input_base_ptr[c_block + c + depth * j])
808 << 4;
809 wdelta_twice[c] = wdelta[c] << 1;
810
811 accum[c] += wdelta[c];
812 output_base_ptr[c_block + c + depth * j * 8 + depth * 4] =
813 accum[c] >> 8;
814 for (int p = 1; p < 8; ++p) {
815 accum[c] += wdelta_twice[c];
816 output_base_ptr[c_block + c + depth * j * 8 + depth * p +
817 depth * 4] = accum[c] >> 8;
818 }
819 accum[c] += wdelta[c];
820 }
821 }
822
823 // Top-right margin corner.
824 for (int c = 0; c < 8; ++c) {
825 // Accumulations have pre-added 0.5 for rounding, but that is just
826 // discarded and this just avoids re-loading.
827 output_data[c] = accum[c] >> 8;
828 TFLITE_DCHECK_EQ(
829 output_data[c],
830 input_base_ptr[c_block + c + depth * (input_width - 1)]);
831 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
832 depth * 4] = output_data[c];
833 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
834 depth * 4 + depth] = output_data[c];
835 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
836 depth * 4 + depth * 2] = output_data[c];
837 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
838 depth * 4 + depth * 3] = output_data[c];
839 }
840 }
841 // Fill out remainder of top margin.
842 std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
843 output_row_stride * sizeof(uint8));
844 std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
845 output_row_stride * sizeof(uint8));
846 std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
847 output_row_stride * sizeof(uint8));
848 output_base_ptr += output_row_stride * 4;
849
850 // Main rows.
851 for (int k = 0; k < (input_height - 1); ++k) {
852 for (int c_block = 0; c_block < depth; c_block += 8) {
853 uint8* output_base_ptr_0 = output_base_ptr;
854 uint8* output_base_ptr_1;
855 uint8* output_base_ptr_2;
856 uint8* output_base_ptr_3;
857 uint8* output_base_ptr_4;
858 uint8* output_base_ptr_5;
859 uint8* output_base_ptr_6;
860 uint8* output_base_ptr_7;
861 uint16 accum_0[8];
862 uint16 accum_1[8];
863 uint16 accum_2[8];
864 uint16 accum_3[8];
865 uint16 accum_4[8];
866 uint16 accum_5[8];
867 uint16 accum_6[8];
868 uint16 accum_7[8];
869
870 // We would prefer to treat accum_0[c], etc, as packed-data arrays standing
871 // in for registers. However the compiler will not reliably optimize array
872 // accesses that way, so we do most of the work in plain scalar variables.
873 uint16 accum_0_c;
874 uint16 accum_1_c;
875 uint16 accum_2_c;
876 uint16 accum_3_c;
877 uint16 accum_4_c;
878 uint16 accum_5_c;
879 uint16 accum_6_c;
880 uint16 accum_7_c;
881
882 int16 hdelta_c;
883 int16 hdelta_twice_c;
884
885 // Left margin for 8 rows.
886 for (int c = 0; c < 8; ++c) {
887 hdelta_c = static_cast<uint16>(
888 input_base_ptr[c_block + c + input_row_stride] -
889 input_base_ptr[c_block + c])
890 << 4;
891
892 // Accumulate in 8.8 representation, pre-adding 0.5 for later
893 // rounding.
894 accum_0_c = (input_base_ptr[c_block + c] << 8) + 128;
895
896 accum_0_c += hdelta_c;
897 output_base_ptr_0[c_block + c] = accum_0_c >> 8;
898 output_base_ptr_0[c_block + c + depth] = accum_0_c >> 8;
899 output_base_ptr_0[c_block + c + depth * 2] = accum_0_c >> 8;
900 output_base_ptr_0[c_block + c + depth * 3] = accum_0_c >> 8;
901
902 hdelta_twice_c = hdelta_c << 1;
903
904 output_base_ptr_1 = output_base_ptr_0 + output_row_stride;
905 accum_1_c = accum_0_c + hdelta_twice_c;
906 output_base_ptr_1[c_block + c] = accum_1_c >> 8;
907 output_base_ptr_1[c_block + c + depth] = accum_1_c >> 8;
908 output_base_ptr_1[c_block + c + depth * 2] = accum_1_c >> 8;
909 output_base_ptr_1[c_block + c + depth * 3] = accum_1_c >> 8;
910
911 output_base_ptr_2 = output_base_ptr_1 + output_row_stride;
912 accum_2_c = accum_1_c + hdelta_twice_c;
913 output_base_ptr_2[c_block + c] = accum_2_c >> 8;
914 output_base_ptr_2[c_block + c + depth] = accum_2_c >> 8;
915 output_base_ptr_2[c_block + c + depth * 2] = accum_2_c >> 8;
916 output_base_ptr_2[c_block + c + depth * 3] = accum_2_c >> 8;
917
918 output_base_ptr_3 = output_base_ptr_2 + output_row_stride;
919 accum_3_c = accum_2_c + hdelta_twice_c;
920 output_base_ptr_3[c_block + c] = accum_3_c >> 8;
921 output_base_ptr_3[c_block + c + depth] = accum_3_c >> 8;
922 output_base_ptr_3[c_block + c + depth * 2] = accum_3_c >> 8;
923 output_base_ptr_3[c_block + c + depth * 3] = accum_3_c >> 8;
924
925 output_base_ptr_4 = output_base_ptr_3 + output_row_stride;
926 accum_4_c = accum_3_c + hdelta_twice_c;
927 output_base_ptr_4[c_block + c] = accum_4_c >> 8;
928 output_base_ptr_4[c_block + c + depth] = accum_4_c >> 8;
929 output_base_ptr_4[c_block + c + depth * 2] = accum_4_c >> 8;
930 output_base_ptr_4[c_block + c + depth * 3] = accum_4_c >> 8;
931
932 output_base_ptr_5 = output_base_ptr_4 + output_row_stride;
933 accum_5_c = accum_4_c + hdelta_twice_c;
934 output_base_ptr_5[c_block + c] = accum_5_c >> 8;
935 output_base_ptr_5[c_block + c + depth] = accum_5_c >> 8;
936 output_base_ptr_5[c_block + c + depth * 2] = accum_5_c >> 8;
937 output_base_ptr_5[c_block + c + depth * 3] = accum_5_c >> 8;
938
939 output_base_ptr_6 = output_base_ptr_5 + output_row_stride;
940 accum_6_c = accum_5_c + hdelta_twice_c;
941 output_base_ptr_6[c_block + c] = accum_6_c >> 8;
942 output_base_ptr_6[c_block + c + depth] = accum_6_c >> 8;
943 output_base_ptr_6[c_block + c + depth * 2] = accum_6_c >> 8;
944 output_base_ptr_6[c_block + c + depth * 3] = accum_6_c >> 8;
945
946 output_base_ptr_7 = output_base_ptr_6 + output_row_stride;
947 accum_7_c = accum_6_c + hdelta_twice_c;
948 output_base_ptr_7[c_block + c] = accum_7_c >> 8;
949 output_base_ptr_7[c_block + c + depth] = accum_7_c >> 8;
950 output_base_ptr_7[c_block + c + depth * 2] = accum_7_c >> 8;
951 output_base_ptr_7[c_block + c + depth * 3] = accum_7_c >> 8;
952
953 accum_0[c] = accum_0_c;
954 accum_1[c] = accum_1_c;
955 accum_2[c] = accum_2_c;
956 accum_3[c] = accum_3_c;
957 accum_4[c] = accum_4_c;
958 accum_5[c] = accum_5_c;
959 accum_6[c] = accum_6_c;
960 accum_7[c] = accum_7_c;
961 }
962
963 // Main central body.
964 int16 wdelta_c;
965 int16 wdelta_twice_c;
966 int16 hwdelta_c;
967 int16 hwdelta_twice_c;
968
969 int16 incr_0_c;
970 int16 incr_1_c;
971 int16 incr_2_c;
972 int16 incr_3_c;
973 int16 incr_4_c;
974 int16 incr_5_c;
975 int16 incr_6_c;
976 int16 incr_7_c;
977 for (int j = 0; j < (input_width - 1); ++j) {
978 for (int c = 0; c < 8; ++c) {
979 accum_0_c = accum_0[c];
980 accum_1_c = accum_1[c];
981 accum_2_c = accum_2[c];
982 accum_3_c = accum_3[c];
983 accum_4_c = accum_4[c];
984 accum_5_c = accum_5[c];
985 accum_6_c = accum_6[c];
986 accum_7_c = accum_7[c];
987
988 wdelta_c = static_cast<uint16>(
989 input_base_ptr[c_block + c + depth * (j + 1)] -
990 input_base_ptr[c_block + c + depth * j])
991 << 4;
992 wdelta_twice_c = wdelta_c << 1;
993 hwdelta_c = static_cast<uint16>(
994 input_base_ptr[c_block + c + depth * (j + 1) +
995 input_row_stride] -
996 input_base_ptr[c_block + c + depth * (j + 1)] -
997 input_base_ptr[c_block + c + depth * j + input_row_stride] +
998 input_base_ptr[c_block + c + depth * j]);
999 hwdelta_twice_c = hwdelta_c << 1;
1000
1001 uint16 incr_base = wdelta_c + hwdelta_c;
1002 accum_0_c += incr_base;
1003 output_base_ptr_0[c_block + c + depth * j * 8 + depth * 4] =
1004 accum_0_c >> 8;
1005 incr_0_c = incr_base << 1;
1006
1007 incr_base += hwdelta_twice_c;
1008 accum_1_c += incr_base;
1009 output_base_ptr_1[c_block + c + depth * j * 8 + depth * 4] =
1010 accum_1_c >> 8;
1011 incr_1_c = incr_base << 1;
1012
1013 incr_base += hwdelta_twice_c;
1014 accum_2_c += incr_base;
1015 output_base_ptr_2[c_block + c + depth * j * 8 + depth * 4] =
1016 accum_2_c >> 8;
1017 incr_2_c = incr_base << 1;
1018
1019 incr_base += hwdelta_twice_c;
1020 accum_3_c += incr_base;
1021 output_base_ptr_3[c_block + c + depth * j * 8 + depth * 4] =
1022 accum_3_c >> 8;
1023 incr_3_c = incr_base << 1;
1024
1025 incr_base += hwdelta_twice_c;
1026 accum_4_c += incr_base;
1027 output_base_ptr_4[c_block + c + depth * j * 8 + depth * 4] =
1028 accum_4_c >> 8;
1029 incr_4_c = incr_base << 1;
1030
1031 incr_base += hwdelta_twice_c;
1032 accum_5_c += incr_base;
1033 output_base_ptr_5[c_block + c + depth * j * 8 + depth * 4] =
1034 accum_5_c >> 8;
1035 incr_5_c = incr_base << 1;
1036
1037 incr_base += hwdelta_twice_c;
1038 accum_6_c += incr_base;
1039 output_base_ptr_6[c_block + c + depth * j * 8 + depth * 4] =
1040 accum_6_c >> 8;
1041 incr_6_c = incr_base << 1;
1042
1043 incr_base += hwdelta_twice_c;
1044 accum_7_c += incr_base;
1045 output_base_ptr_7[c_block + c + depth * j * 8 + depth * 4] =
1046 accum_7_c >> 8;
1047 incr_7_c = incr_base << 1;
1048
1049 for (int p = 1; p < 8; ++p) {
1050 accum_0_c += incr_0_c;
1051 output_base_ptr_0[c_block + c + depth * j * 8 + depth * p +
1052 depth * 4] = accum_0_c >> 8;
1053 accum_1_c += incr_1_c;
1054 output_base_ptr_1[c_block + c + depth * j * 8 + depth * p +
1055 depth * 4] = accum_1_c >> 8;
1056 accum_2_c += incr_2_c;
1057 output_base_ptr_2[c_block + c + depth * j * 8 + depth * p +
1058 depth * 4] = accum_2_c >> 8;
1059 accum_3_c += incr_3_c;
1060 output_base_ptr_3[c_block + c + depth * j * 8 + depth * p +
1061 depth * 4] = accum_3_c >> 8;
1062 accum_4_c += incr_4_c;
1063 output_base_ptr_4[c_block + c + depth * j * 8 + depth * p +
1064 depth * 4] = accum_4_c >> 8;
1065 accum_5_c += incr_5_c;
1066 output_base_ptr_5[c_block + c + depth * j * 8 + depth * p +
1067 depth * 4] = accum_5_c >> 8;
1068 accum_6_c += incr_6_c;
1069 output_base_ptr_6[c_block + c + depth * j * 8 + depth * p +
1070 depth * 4] = accum_6_c >> 8;
1071 accum_7_c += incr_7_c;
1072 output_base_ptr_7[c_block + c + depth * j * 8 + depth * p +
1073 depth * 4] = accum_7_c >> 8;
1074 }
1075 accum_0_c += incr_0_c / 2;
1076 accum_1_c += incr_1_c / 2;
1077 accum_2_c += incr_2_c / 2;
1078 accum_3_c += incr_3_c / 2;
1079 accum_4_c += incr_4_c / 2;
1080 accum_5_c += incr_5_c / 2;
1081 accum_6_c += incr_6_c / 2;
1082 accum_7_c += incr_7_c / 2;
1083
1084 accum_0[c] = accum_0_c;
1085 accum_1[c] = accum_1_c;
1086 accum_2[c] = accum_2_c;
1087 accum_3[c] = accum_3_c;
1088 accum_4[c] = accum_4_c;
1089 accum_5[c] = accum_5_c;
1090 accum_6[c] = accum_6_c;
1091 accum_7[c] = accum_7_c;
1092 }
1093 }
1094
1095 // Right margin.
1096 uint8 output_data_0_c;
1097 uint8 output_data_1_c;
1098 uint8 output_data_2_c;
1099 uint8 output_data_3_c;
1100 uint8 output_data_4_c;
1101 uint8 output_data_5_c;
1102 uint8 output_data_6_c;
1103 uint8 output_data_7_c;
1104 for (int c = 0; c < 8; ++c) {
1105 accum_0_c = accum_0[c];
1106 accum_1_c = accum_1[c];
1107 accum_2_c = accum_2[c];
1108 accum_3_c = accum_3[c];
1109 accum_4_c = accum_4[c];
1110 accum_5_c = accum_5[c];
1111 accum_6_c = accum_6[c];
1112 accum_7_c = accum_7[c];
1113
1114 // Accumulations have pre-added 0.5 for rounding, but that is just
1115 // discarded and this just avoids re-loading.
1116 output_data_0_c = accum_0_c >> 8;
1117 output_data_1_c = accum_1_c >> 8;
1118 output_data_2_c = accum_2_c >> 8;
1119 output_data_3_c = accum_3_c >> 8;
1120 output_data_4_c = accum_4_c >> 8;
1121 output_data_5_c = accum_5_c >> 8;
1122 output_data_6_c = accum_6_c >> 8;
1123 output_data_7_c = accum_7_c >> 8;
1124 for (int p = 0; p < 4; ++p) {
1125 output_base_ptr_0[c_block + c + depth * (input_width - 1) * 8 +
1126 depth * 4 + depth * p] = output_data_0_c;
1127 output_base_ptr_1[c_block + c + depth * (input_width - 1) * 8 +
1128 depth * 4 + depth * p] = output_data_1_c;
1129 output_base_ptr_2[c_block + c + depth * (input_width - 1) * 8 +
1130 depth * 4 + depth * p] = output_data_2_c;
1131 output_base_ptr_3[c_block + c + depth * (input_width - 1) * 8 +
1132 depth * 4 + depth * p] = output_data_3_c;
1133 output_base_ptr_4[c_block + c + depth * (input_width - 1) * 8 +
1134 depth * 4 + depth * p] = output_data_4_c;
1135 output_base_ptr_5[c_block + c + depth * (input_width - 1) * 8 +
1136 depth * 4 + depth * p] = output_data_5_c;
1137 output_base_ptr_6[c_block + c + depth * (input_width - 1) * 8 +
1138 depth * 4 + depth * p] = output_data_6_c;
1139 output_base_ptr_7[c_block + c + depth * (input_width - 1) * 8 +
1140 depth * 4 + depth * p] = output_data_7_c;
1141 }
1142
1143 accum_0[c] = accum_0_c;
1144 accum_1[c] = accum_1_c;
1145 accum_2[c] = accum_2_c;
1146 accum_3[c] = accum_3_c;
1147 accum_4[c] = accum_4_c;
1148 accum_5[c] = accum_5_c;
1149 accum_6[c] = accum_6_c;
1150 accum_7[c] = accum_7_c;
1151 }
1152 }
1153
1154 output_base_ptr += output_row_stride * 8;
1155 input_base_ptr += input_row_stride;
1156 }
1157
1158 for (int c_block = 0; c_block < depth; c_block += 8) {
1159 uint8 output_data[8];
1160 uint16 accum[8];
1161 // Bottom-left margin corner.
1162 for (int c = 0; c < 8; ++c) {
1163 output_data[c] = input_base_ptr[c_block + c];
1164 output_base_ptr[c_block + c] = output_data[c];
1165 output_base_ptr[c_block + c + depth] = output_data[c];
1166 output_base_ptr[c_block + c + depth * 2] = output_data[c];
1167 output_base_ptr[c_block + c + depth * 3] = output_data[c];
1168
1169 // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
1170 accum[c] =
1171 (output_data[c] << 8) + 128; // 128 = 0.5 in 8.8 representation.
1172 }
1173
1174 // Bottom-centre margin.
1175 uint16 wdelta[8];
1176 uint16 wdelta_twice[8];
1177 for (int j = 0; j < (input_width - 1); ++j) {
1178 for (int c = 0; c < 8; ++c) {
1179 wdelta[c] = static_cast<uint16>(
1180 input_base_ptr[c_block + c + depth * (j + 1)] -
1181 input_base_ptr[c_block + c + depth * j])
1182 << 4;
1183 wdelta_twice[c] = wdelta[c] << 1;
1184
1185 accum[c] += wdelta[c];
1186 output_base_ptr[c_block + c + depth * j * 8 + depth * 4] =
1187 accum[c] >> 8;
1188 for (int p = 1; p < 8; ++p) {
1189 accum[c] += wdelta_twice[c];
1190 output_base_ptr[c_block + c + depth * j * 8 + depth * p +
1191 depth * 4] = accum[c] >> 8;
1192 }
1193 accum[c] += wdelta[c];
1194 }
1195 }
1196
1197 // Bottom-right margin corner.
1198 for (int c = 0; c < 8; ++c) {
1199 // Accumulations have pre-added 0.5 for rounding, but that is just
1200 // discarded and this just avoids re-loading.
1201 output_data[c] = accum[c] >> 8;
1202 TFLITE_DCHECK_EQ(
1203 output_data[c],
1204 input_base_ptr[c_block + c + depth * (input_width - 1)]);
1205 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
1206 depth * 4] = output_data[c];
1207 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
1208 depth * 4 + depth] = output_data[c];
1209 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
1210 depth * 4 + depth * 2] = output_data[c];
1211 output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
1212 depth * 4 + depth * 3] = output_data[c];
1213 }
1214 }
1215 // Fill out remainder of bottom margin.
1216 std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
1217 output_row_stride * sizeof(uint8));
1218 std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
1219 output_row_stride * sizeof(uint8));
1220 std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
1221 output_row_stride * sizeof(uint8));
1222
1223 #endif // USE_NEON
1224 }
1225 } // NOLINT(readability/fn_size)
1226
1227 } // namespace resize_bilinear
1228
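// Accumulates one weighted corner contribution into an output pixel:
// output_ptr[d] += input_ptr[d] * scale for d in [0, depth).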
1229 #ifdef USE_NEON
1230 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
1231 float scale, float* output_ptr) {
1232 int ic = 0;
1233 // Handle 32 input channels at a time.
1234 for (; ic <= depth - 32; ic += 32) {
1235 float32x4x2_t input[4];
1236 for (int i = 0; i < 4; i++) {
1237 input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
1238 input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
1239 }
1240 float32x4x2_t acc[4];
1241 for (int i = 0; i < 4; i++) {
1242 acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
1243 acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
1244 }
1245 for (int i = 0; i < 4; i++) {
1246 acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
1247 acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
1248 }
1249 for (int i = 0; i < 4; i++) {
1250 vst1q_f32(output_ptr, acc[i].val[0]);
1251 vst1q_f32(output_ptr + 4, acc[i].val[1]);
1252 output_ptr += 8;
1253 }
1254 input_ptr += 32;
1255 }
1256 // Handle 16 input channels at a time.
1257 for (; ic <= depth - 16; ic += 16) {
1258 float32x4x2_t input[2];
1259 for (int i = 0; i < 2; i++) {
1260 input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
1261 input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
1262 }
1263 float32x4x2_t acc[2];
1264 for (int i = 0; i < 2; i++) {
1265 acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
1266 acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
1267 }
1268 for (int i = 0; i < 2; i++) {
1269 acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
1270 acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
1271 }
1272 for (int i = 0; i < 2; i++) {
1273 vst1q_f32(output_ptr, acc[i].val[0]);
1274 vst1q_f32(output_ptr + 4, acc[i].val[1]);
1275 output_ptr += 8;
1276 }
1277 input_ptr += 16;
1278 }
1279 // Handle 8 input channels at a time.
1280 for (; ic <= depth - 8; ic += 8) {
1281 float32x4x2_t input;
1282 input.val[0] = vld1q_f32(input_ptr);
1283 input.val[1] = vld1q_f32(input_ptr + 4);
1284
1285 float32x4x2_t acc;
1286 acc.val[0] = vld1q_f32(output_ptr);
1287 acc.val[1] = vld1q_f32(output_ptr + 4);
1288 acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
1289 acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
1290
1291 vst1q_f32(output_ptr, acc.val[0]);
1292 vst1q_f32(output_ptr + 4, acc.val[1]);
1293
1294 input_ptr += 8;
1295 output_ptr += 8;
1296 }
1297 // Handle 4 input channels at a time.
1298 for (; ic <= depth - 4; ic += 4) {
1299 float32x4_t input = vld1q_f32(input_ptr);
1300 float32x4_t acc = vld1q_f32(output_ptr);
1301
1302 acc = vmlaq_n_f32(acc, input, scale);
1303 vst1q_f32(output_ptr, acc);
1304
1305 input_ptr += 4;
1306 output_ptr += 4;
1307 }
1308 // Handle 1 input channel at a time.
1309 for (; ic < depth; ic++) {
1310 *output_ptr += *input_ptr * scale;
1311 output_ptr++;
1312 input_ptr++;
1313 }
1314 }
1315 #else
1316 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
1317 float scale, float* output_ptr) {
1318 for (int32 i = 0; i < depth; i++) {
1319 *output_ptr += *input_ptr * scale;
1320 output_ptr++;
1321 input_ptr++;
1322 }
1323 }
1324 #endif
1325
1326 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
1327 int32 x, int32 y, int32 depth, int32 batch,
1328 const RuntimeShape& input_shape,
1329 const float* input_data,
1330 const RuntimeShape& output_shape,
1331 float* output_data) {
1332 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1333 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1334 const int32 input_width = input_shape.Dims(2);
1335 const int32 output_width = output_shape.Dims(2);
1336
1337 const int32 input_x_offset = (x1 - x0) * depth;
1338 const int32 input_y_offset = (y1 - y0) * depth * input_width;
1339 const int32 output_x_offset = depth;
1340 const int32 output_y_offset = depth * output_width;
1341
1342 #ifdef USE_NEON
1343 TFLITE_DCHECK(x1 >= x0);
1344 TFLITE_DCHECK(y1 >= y0);
1345
1346 int ic = 0;
1347 // Handle 8 input channels at a time.
1348 for (; ic <= depth - 8; ic += 8) {
1349 const float* input_ptr = nullptr;
1350
1351 float32x4x2_t x0y0;
1352 input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
1353 x0y0.val[0] = vld1q_f32(input_ptr);
1354 x0y0.val[1] = vld1q_f32(input_ptr + 4);
1355
1356 float32x4x2_t x1y0;
1357 input_ptr += input_x_offset;
1358 x1y0.val[0] = vld1q_f32(input_ptr);
1359 x1y0.val[1] = vld1q_f32(input_ptr + 4);
1360
1361 float32x4x2_t x0y1;
1362 input_ptr += -input_x_offset + input_y_offset;
1363 x0y1.val[0] = vld1q_f32(input_ptr);
1364 x0y1.val[1] = vld1q_f32(input_ptr + 4);
1365
1366 float32x4x2_t x1y1;
1367 input_ptr += input_x_offset;
1368 x1y1.val[0] = vld1q_f32(input_ptr);
1369 x1y1.val[1] = vld1q_f32(input_ptr + 4);
1370
1371 // Top left corner.
1372 float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
1373 vst1q_f32(output_ptr, x0y0.val[0]);
1374 vst1q_f32(output_ptr + 4, x0y0.val[1]);
1375
1376 // Top right corner.
1377 output_ptr += output_x_offset;
1378 float32x4x2_t tr;
1379 tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
1380 tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
1381 tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
1382 tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
1383
1384 vst1q_f32(output_ptr, tr.val[0]);
1385 vst1q_f32(output_ptr + 4, tr.val[1]);
1386
1387 // Bottom left corner.
1388 output_ptr += -output_x_offset + output_y_offset;
1389 float32x4x2_t bl;
1390 bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
1391 bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
1392 bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
1393 bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
1394 vst1q_f32(output_ptr, bl.val[0]);
1395 vst1q_f32(output_ptr + 4, bl.val[1]);
1396
1397 // Bottom right corner.
1398 output_ptr += output_x_offset;
1399 float32x4x2_t br;
1400 br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
1401 br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
1402 br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
1403 br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
1404 br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
1405 br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
1406 vst1q_f32(output_ptr, br.val[0]);
1407 vst1q_f32(output_ptr + 4, br.val[1]);
1408 }
1409 // Handle 4 input channels at a time.
1410 for (; ic <= depth - 4; ic += 4) {
1411 const float* input_ptr =
1412 &input_data[Offset(input_shape, batch, y0, x0, ic)];
1413 float32x4_t x0y0 = vld1q_f32(input_ptr);
1414 float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
1415 float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
1416 float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
1417
1418 // Top left corner.
1419 float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
1420 vst1q_f32(output_ptr, x0y0);
1421
1422 // Top right corner.
1423 output_ptr += output_x_offset;
1424 float32x4_t tr = vaddq_f32(x0y0, x1y0);
1425 tr = vmulq_n_f32(tr, 0.5f);
1426 vst1q_f32(output_ptr, tr);
1427
1428 // Bottom left corner.
1429 output_ptr += -output_x_offset + output_y_offset;
1430 float32x4_t bl = vaddq_f32(x0y0, x0y1);
1431 bl = vmulq_n_f32(bl, 0.5f);
1432 vst1q_f32(output_ptr, bl);
1433
1434 // Bottom right corner.
1435 output_ptr += output_x_offset;
1436 float32x4_t br = vaddq_f32(x1y0, x1y1);
1437 br = vmlaq_n_f32(bl, br, 0.5f);
1438 br = vmulq_n_f32(br, 0.5f);
1439 vst1q_f32(output_ptr, br);
1440 }
1441 // Handle one input channel at a time.
1442 for (; ic < depth; ic++) {
1443 const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
1444
1445 float x0y0 = input_data[input_offset];
1446 float x1y0 = input_data[input_offset + input_x_offset];
1447 float x0y1 = input_data[input_offset + input_y_offset];
1448 float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
1449
1450 // Top left corner.
1451 const int32 output_offset = Offset(output_shape, batch, y, x, ic);
1452 output_data[output_offset] = x0y0;
1453
1454 // Top right corner.
1455 output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
1456
1457 // Bottom left corner.
1458 float output = (x0y0 + x0y1) / 2;
1459 output_data[output_offset + output_y_offset] = output;
1460
1461 // Bottom right corner.
1462 output_data[output_offset + output_x_offset + output_y_offset] =
1463 (output + ((x1y0 + x1y1) / 2)) / 2;
1464 }
1465 #else
1466 for (int ch = 0; ch < depth; ch++) {
1467 const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
1468
1469 float x0y0 = input_data[input_offset];
1470 float x1y0 = input_data[input_offset + input_x_offset];
1471 float x0y1 = input_data[input_offset + input_y_offset];
1472 float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
1473
1474 // Top left corner.
1475 const int32 output_offset = Offset(output_shape, batch, y, x, ch);
1476 output_data[output_offset] = x0y0;
1477
1478 // Top right corner.
1479 output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
1480
1481 // Bottom left corner.
1482 float output = (x0y0 + x0y1) / 2;
1483 output_data[output_offset + output_y_offset] = output;
1484
1485 // Bottom right corner.
1486 output_data[output_offset + output_x_offset + output_y_offset] =
1487 (output + ((x1y0 + x1y1) / 2)) / 2;
1488 }
1489 #endif
1490 }
1491
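// Exact 2x upsampling (used below when align_corners and half_pixel_centers
// are both false and the output is exactly twice the input size): source
// coordinates are output / 2, so even outputs copy an input pixel and odd
// outputs average the neighbouring input pixels, as written out per 2x2
// output block by ResizeBilinearKernel2x2 above.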
1492 inline void ResizeBilinear2x2(int32 batches, int32 input_height,
1493 int32 input_width, int32 depth,
1494 int32 output_height, int32 output_width,
1495 const RuntimeShape& input_shape,
1496 const float* input_data,
1497 const RuntimeShape& output_shape,
1498 float* output_data) {
1499 for (int b = 0; b < batches; b++) {
1500 for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
1501 for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
1502 int32 x1 = std::min(x0 + 1, input_width - 1);
1503 int32 y1 = std::min(y0 + 1, input_height - 1);
1504 ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
1505 input_data, output_shape, output_data);
1506 }
1507 }
1508 }
1509 }
1510
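// General float path: zero-fills the output, then accumulates the four corner
// taps of each output pixel with weights (1 - dy)(1 - dx), (1 - dy)dx,
// dy(1 - dx) and dy*dx, where dx = input_x - x0 and dy = input_y - y0; the
// four weights sum to 1.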
1511 inline void ResizeBilinearGeneric(
1512 int32 batches, int32 input_height, int32 input_width, int32 depth,
1513 int32 output_height, int32 output_width, float height_scale,
1514 float width_scale, const RuntimeShape& input_shape, const float* input_data,
1515 const RuntimeShape& output_shape, float* output_data,
1516 const bool half_pixel_centers) {
1517 memset(output_data, 0,
1518 batches * output_height * output_width * depth * sizeof(float));
1519
1520 int32 output_offset = 0;
1521 for (int b = 0; b < batches; ++b) {
1522 for (int y = 0; y < output_height; ++y) {
1523 float input_y;
1524 int32 y0, y1;
1525 reference_ops::ComputeInterpolationValues(
1526 y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
1527 &y1);
1528 for (int x = 0; x < output_width; ++x) {
1529 float input_x;
1530 int32 x0, x1;
1531 reference_ops::ComputeInterpolationValues(
1532 x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
1533 &x1);
1534 float* output_ptr = &output_data[output_offset];
1535
1536 // Run kernel on the 4 corners of the bilinear resize algorithm.
1537 int32 input_offset = Offset(input_shape, b, y0, x0, 0);
1538 float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
1539 const float* input_ptr = &input_data[input_offset];
1540 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1541
1542 input_offset = Offset(input_shape, b, y0, x1, 0);
1543 scale = (1 - (input_y - y0)) * (input_x - x0);
1544 input_ptr = &input_data[input_offset];
1545 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1546
1547 input_offset = Offset(input_shape, b, y1, x0, 0);
1548 scale = (input_y - y0) * (1 - (input_x - x0));
1549 input_ptr = &input_data[input_offset];
1550 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1551
1552 input_offset = Offset(input_shape, b, y1, x1, 0);
1553 scale = (input_y - y0) * (input_x - x0);
1554 input_ptr = &input_data[input_offset];
1555 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1556
1557 output_offset += depth;
1558 }
1559 }
1560 }
1561 }
1562
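// Variant that keeps the innermost loop over depth and applies all four taps
// per element, which suits small channel counts. For integer T, a 0.5 offset
// is added before truncation so that the result rounds to nearest.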
1563 template <typename T>
1564 inline void ResizeBilinearGenericSmallChannel(
1565 int32 batches, int32 input_height, int32 input_width, int32 depth,
1566 int32 output_height, int32 output_width, float height_scale,
1567 float width_scale, const RuntimeShape& input_shape, const T* input_data,
1568 const RuntimeShape& output_shape, T* output_data,
1569 const bool half_pixel_centers) {
1570 T* output_ptr = &output_data[0];
1571 const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;
1572
1573 for (int b = 0; b < batches; ++b) {
1574 for (int y = 0; y < output_height; ++y) {
1575 float input_y;
1576 int32 y0, y1;
1577 reference_ops::ComputeInterpolationValues(
1578 y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
1579 &y1);
1580 for (int x = 0; x < output_width; ++x) {
1581 float input_x;
1582 int32 x0, x1;
1583 reference_ops::ComputeInterpolationValues(
1584 x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
1585 &x1);
1586
1587 int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
1588 Offset(input_shape, b, y0, x1, 0),
1589 Offset(input_shape, b, y1, x0, 0),
1590 Offset(input_shape, b, y1, x1, 0)};
1591 float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
1592 (1 - (input_y - y0)) * (input_x - x0),
1593 (input_y - y0) * (1 - (input_x - x0)),
1594 (input_y - y0) * (input_x - x0)};
1595
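        // Interpolate each channel independently: input_offset[i] points at
        // channel 0 of corner i, so input_ptr[input_offset[i]] reads channel
        // d of that corner.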
        for (int d = 0; d < depth; d++) {
          const T* input_ptr = &input_data[d];
          *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
                                         input_ptr[input_offset[1]] * scale[1] +
                                         input_ptr[input_offset[2]] * scale[2] +
                                         input_ptr[input_offset[3]] * scale[3] +
                                         rounding_offset);
        }
      }
    }
  }
}

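// Float entry point. Dispatches to the exact 2x2 specialization when the
// output is exactly a 2x upsample and neither align_corners nor
// half_pixel_centers is set, and to ResizeBilinearGeneric otherwise.
// A minimal calling sketch (shape and buffer names are illustrative, not
// part of this header):
//
//   tflite::ResizeBilinearParams params;
//   params.align_corners = false;
//   params.half_pixel_centers = true;
//   const int32 size_data[2] = {out_height, out_width};
//   optimized_ops::ResizeBilinear(params, input_shape, input_data,
//                                 size_shape, size_data, output_shape,
//                                 output_data);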
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                           const RuntimeShape& unextended_input_shape,
                           const float* input_data,
                           const RuntimeShape& output_size_shape,
                           const int32* output_size_data,
                           const RuntimeShape& unextended_output_shape,
                           float* output_data) {
  ruy::profiler::ScopeLabel label("ResizeBilinear");
  // If half_pixel_centers is True, align_corners must be False.
  TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
  int32 input_height = input_shape.Dims(1);
  int32 input_width = input_shape.Dims(2);
  int32 depth = MatchingDim(input_shape, 3, output_shape, 3);

  TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
  int32 output_height = output_size_data[0];
  int32 output_width = output_size_data[1];

  // Specialize for 2x2 upsample.
  if (!op_params.align_corners && !op_params.half_pixel_centers &&
      output_height == 2 * input_height && output_width == 2 * input_width) {
    ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
                      output_width, input_shape, input_data, output_shape,
                      output_data);
  } else {
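    // Without align_corners the scale is input/output; with align_corners the
    // outermost pixels of input and output coincide, so the scale becomes
    // (input - 1) / (output - 1). E.g. resizing height 4 -> 8 gives a
    // height_scale of 0.5 in the default case and 3/7 with align_corners.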
    float height_scale = static_cast<float>(input_height) / output_height;
    float width_scale = static_cast<float>(input_width) / output_width;
    if (op_params.align_corners && output_height > 1) {
      height_scale =
          static_cast<float>(input_height - 1) / (output_height - 1);
    }
    if (op_params.align_corners && output_width > 1) {
      width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
    }

    ResizeBilinearGeneric(batches, input_height, input_width, depth,
                          output_height, output_width, height_scale,
                          width_scale, input_shape, input_data, output_shape,
                          output_data, op_params.half_pixel_centers);
  }
}

// Note: This is not a universal quantized bilinear. It does not use int8
// or int16 arithmetic.
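// Accumulation in the generic path below is done in float per channel; the
// NEON 8x8x8 fast path is taken only for exact 8x upsampling with
// half_pixel_centers and a depth that is a multiple of 8.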
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                           const RuntimeShape& unextended_input_shape,
                           const uint8* input_data,
                           const RuntimeShape& output_size_shape,
                           const int32* output_size_data,
                           const RuntimeShape& unextended_output_shape,
                           uint8* output_data) {
  ruy::profiler::ScopeLabel label("ResizeBilinearUint8");
  // If half_pixel_centers is True, align_corners must be False.
  TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
  int32 input_height = input_shape.Dims(1);
  int32 input_width = input_shape.Dims(2);
  int32 depth = MatchingDim(input_shape, 3, output_shape, 3);

  TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
  int32 output_height = output_size_data[0];
  int32 output_width = output_size_data[1];

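  // Fast path: hand-written NEON kernel for exact 8x upsampling. The integer
  // ratio is derived from the heights and both dimensions are verified to be
  // exact multiples, so non-integral resizes fall through to the generic
  // small-channel path.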
  if (!op_params.align_corners && op_params.half_pixel_centers &&
      ((depth % 8) == 0)) {
    const int32 scale = output_height / input_height;
    // Restricting the minimum output dimensions may not be necessary, but
    // ensures that kernels can use unrolling with minimal code size.
    if ((output_height >= 8) && (output_width >= 8) &&
        ((input_height * scale) == output_height) &&
        ((input_width * scale) == output_width)) {
      if (scale == 8) {
        resize_bilinear::ResizeBilinear888Uint8(
            batches, input_height, input_width, depth, input_data,
            output_data);
        return;
      }
    }
  }

  float height_scale =
      (op_params.align_corners && output_height > 1)
          ? (static_cast<float>(input_height - 1) / (output_height - 1))
          : (static_cast<float>(input_height) / output_height);

  float width_scale =
      (op_params.align_corners && output_width > 1)
          ? (static_cast<float>(input_width - 1) / (output_width - 1))
          : (static_cast<float>(input_width) / output_width);

  ResizeBilinearGenericSmallChannel<uint8>(
      batches, input_height, input_width, depth, output_height, output_width,
      height_scale, width_scale, input_shape, input_data, output_shape,
      output_data, op_params.half_pixel_centers);
}

// TODO(b/180609127) Create optimized int8 version from uint8. Call from here.
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                           const RuntimeShape& unextended_input_shape,
                           const int8* input_data,
                           const RuntimeShape& unextended_output_size_shape,
                           const int32* output_size_data,
                           const RuntimeShape& unextended_output_shape,
                           int8* output_data) {
  reference_ops::ResizeBilinearInteger(
      op_params, unextended_input_shape, input_data,
      unextended_output_size_shape, output_size_data, unextended_output_shape,
      output_data);
}

}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_