/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <memory>
#include <type_traits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {
namespace resize_bilinear {

#ifdef USE_NEON
// These utility functions are split off not just for convenience. Most
// incorporate packing or unpacking of data.
//
// (a) Optimizations can be tried experimentally.
// (b) Optimizations can be specialized for architectures, e.g. Intel vs ARM.
inline int16x8_t Load8IntoLowerS16(const uint8* data_ptr) {
  return vreinterpretq_s16_u16(vmovl_u8(vld1_u8(data_ptr)));
}

inline uint16x8_t Move8IntoUpperU16(const uint8x8_t vec_val) {
  // Alternatively one could zip with a zero vector.
  return vshlq_n_u16(vmovl_u8(vec_val), 8);
}

inline uint16x8_t Load8IntoUpperU16(const uint8* data_ptr) {
  return Move8IntoUpperU16(vld1_u8(data_ptr));
}

// Extract upper 8 bits from each 16-bit integer in vector registers. This is
// performed for a pair, because instructions often work on pairs.
inline void PairExtractUpper(const uint16x8_t accum_0, const uint16x8_t accum_1,
                             uint8x8_t* res_0, uint8x8_t* res_1) {
  uint8x16x2_t unzipped =
      vuzpq_u8(vreinterpretq_u8_u16(accum_0), vreinterpretq_u8_u16(accum_1));
  *res_0 = vget_low_u8(unzipped.val[1]);
  *res_1 = vget_high_u8(unzipped.val[1]);
}
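
// In this file the accumulators handed to PairExtractUpper hold 8.8
// fixed-point values with 128 (0.5 in 8.8) pre-added, so taking the upper
// byte amounts to round-to-nearest: e.g. 18.5 is stored as
// 18.5 * 256 + 128 = 0x1300, whose upper byte is 0x13 = 19.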

// This is an exceptional definition.
//
// Modify int16x8_t, adding operators.
//
// There are exceptional circumstances that make it reasonable to write code
// on vector types for quantized resize bilinear in *some cases*.
//
// (a) In exact quant resize bilinear, it should be possible to guarantee that
//     arithmetic never overflows.
// (b) When the resize scaling is 2 or 4 or 8 it is possible to guarantee
//     exact accumulation and exact incrementation.
// (c) In quant resize bilinear the choice of unsigned vs signed accumulation
//     and saturated vs unsaturated arithmetic is often unimportant.
//
// This pattern simplifies the code considerably. This pattern should not be
// used more widely in code since it can hide important numerical detail.
//
// DO NOT add to this any "class-like" methods: only those that do no more than
// redirect operators to specific intrinsics functions.
struct op_int16x8_t {
  inline op_int16x8_t() = default;
  inline explicit op_int16x8_t(const int16x8_t& initial_val) {
    val = initial_val;
  }
  inline op_int16x8_t& operator=(const int16x8_t& new_val) {
    val = new_val;
    return *this;
  }
  inline op_int16x8_t operator+=(const op_int16x8_t& add_val) {
    val = vaddq_s16(val, add_val.val);
    return *this;
  }
  inline op_int16x8_t operator-=(const op_int16x8_t& sub_val) {
    val = vsubq_s16(val, sub_val.val);
    return *this;
  }
  // This really selects vshlq_n_s16, but requires a longer implementation to
  // convert the shift argument back to a constant, because in some compilers
  // the intrinsics are macros requiring constant args.
  inline op_int16x8_t operator<<=(int32 left_shift) {
    switch (left_shift) {
      case 1:
        val = vshlq_n_s16(val, 1);
        break;
      case 4:
        val = vshlq_n_s16(val, 4);
        break;
      case 8:
        val = vshlq_n_s16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  // This really selects vshrq_n_s16, but requires a longer implementation to
  // convert the shift argument back to a constant, because in some compilers
  // the intrinsics are macros requiring constant args.
  inline op_int16x8_t operator>>=(int32 right_shift) {
    switch (right_shift) {
      case 1:
        val = vshrq_n_s16(val, 1);
        break;
      case 4:
        val = vshrq_n_s16(val, 4);
        break;
      case 8:
        val = vshrq_n_s16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  friend inline op_int16x8_t operator+(op_int16x8_t lhs,
                                       const op_int16x8_t& rhs) {
    lhs += rhs;
    return lhs;
  }
  friend inline op_int16x8_t operator-(op_int16x8_t lhs,
                                       const op_int16x8_t& rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend inline op_int16x8_t operator<<(op_int16x8_t lhs, int32 left_shift) {
    lhs <<= left_shift;
    return lhs;
  }
  friend inline op_int16x8_t operator>>(op_int16x8_t lhs, int32 right_shift) {
    lhs >>= right_shift;
    return lhs;
  }

  int16x8_t val;
};
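
// A minimal usage sketch: these wrappers let the interpolation code below
// read like scalar arithmetic, e.g.
//   op_int16x8_t wdelta = (tr_val - tl_val) << 4;  // vsubq_s16 + vshlq_n_s16
//   accum += wdelta;                               // vaddq_u16 (reinterpreted)
// with each overloaded operator compiling down to a single intrinsic.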

// This is an exceptional definition.
//
// Modify uint16x8_t, adding operators.
//
// Important: See above notes on op_int16x8_t.
struct op_uint16x8_t {
  inline op_uint16x8_t() = default;
  inline explicit op_uint16x8_t(const uint16x8_t initial_val) {
    val = initial_val;
  }
  inline op_uint16x8_t& operator=(const uint16x8_t& new_val) {
    val = new_val;
    return *this;
  }
  inline op_uint16x8_t operator+=(const op_int16x8_t& add_val) {
    val = vaddq_u16(val, vreinterpretq_u16_s16(add_val.val));
    return *this;
  }
  inline op_uint16x8_t operator-=(const op_int16x8_t& sub_val) {
    val = vsubq_u16(val, vreinterpretq_u16_s16(sub_val.val));
    return *this;
  }
  // This really selects vshlq_n_u16, but requires a longer implementation to
  // convert the shift argument back to a constant, because in some compilers
  // the intrinsics are macros requiring constant args.
  inline op_uint16x8_t operator<<=(int32 left_shift) {
    switch (left_shift) {
      case 1:
        val = vshlq_n_u16(val, 1);
        break;
      case 4:
        val = vshlq_n_u16(val, 4);
        break;
      case 8:
        val = vshlq_n_u16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  // This really selects vshrq_n_u16, but requires a longer implementation to
  // convert the shift argument back to a constant, because in some compilers
  // the intrinsics are macros requiring constant args.
  inline op_uint16x8_t operator>>=(int32 right_shift) {
    switch (right_shift) {
      case 1:
        val = vshrq_n_u16(val, 1);
        break;
      case 4:
        val = vshrq_n_u16(val, 4);
        break;
      case 8:
        val = vshrq_n_u16(val, 8);
        break;
      default:
        TFLITE_CHECK(false);
        break;
    }
    return *this;
  }
  friend inline op_uint16x8_t operator+(op_uint16x8_t lhs,
                                        const op_int16x8_t& rhs) {
    lhs += rhs;
    return lhs;
  }
  friend inline op_uint16x8_t operator-(op_uint16x8_t lhs,
                                        const op_int16x8_t& rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend inline op_uint16x8_t operator<<(op_uint16x8_t lhs, int32 left_shift) {
    lhs <<= left_shift;
    return lhs;
  }
  friend inline op_uint16x8_t operator>>(op_uint16x8_t lhs, int32 right_shift) {
    lhs >>= right_shift;
    return lhs;
  }

  uint16x8_t val;
};
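
// Note that op_uint16x8_t::operator+= adds a signed delta to an unsigned
// accumulator by reinterpreting bits. This relies on two's-complement modular
// arithmetic, which is well defined for the non-saturating NEON intrinsics
// used here.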

inline op_uint16x8_t VReinterpretQU16S16(const op_int16x8_t& other) {
  op_uint16x8_t ret_val(vreinterpretq_u16_s16(other.val));
  return ret_val;
}
#endif  // USE_NEON

// Optimized resize-bilinear for the special case where the scaling is x8 in
// width and height, and where we can operate on depth-8 blocks at a time. So
// the output blocks are 8x8x8 in width-height-depth.
//
// This optimization is for the half_pixel_centers == true version, for uint8.
// There are versions for NEON and non-NEON compilation.
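//
// Arithmetic is done in 8.8 fixed point. With half_pixel_centers == true and
// a scale of 8, the 8 output samples between two adjacent input samples sit
// at fractional offsets 1/16, 3/16, ..., 15/16. The first step from an input
// sample is therefore wdelta = (right - left) << 4, one sixteenth of the
// difference in 8.8, and each subsequent step is wdelta_twice = wdelta << 1.
// Pre-adding 128 (0.5 in 8.8) to each accumulator makes the later upper-byte
// extraction round to nearest.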
inline void ResizeBilinear888Uint8(int32 batches, int32 input_height,
                                   int32 input_width, int32 depth,
                                   const uint8* input_data,
                                   uint8* output_data) {
  TFLITE_DCHECK_GE(input_height, 1);
  TFLITE_DCHECK_GE(input_width, 1);
  TFLITE_DCHECK_EQ(depth % 8, 0);

  const int32 input_row_stride = input_width * depth;
  const int32 output_row_stride = input_row_stride * 8;
  for (int b = 0; b < batches; ++b) {
    const uint8* input_base_ptr =
        input_data + b * input_row_stride * input_height;
    uint8* output_base_ptr =
        output_data + b * output_row_stride * input_height * 8;

#ifdef USE_NEON
    for (int c_block = 0; c_block < depth; c_block += 8) {
      op_uint16x8_t accum_c_v;
      // Top-left margin corner.
      {
        uint8x8_t output_data = vld1_u8(&input_base_ptr[c_block]);
        vst1_u8(&output_base_ptr[c_block], output_data);
        vst1_u8(&output_base_ptr[c_block + depth], output_data);
        vst1_u8(&output_base_ptr[c_block + depth * 2], output_data);
        vst1_u8(&output_base_ptr[c_block + depth * 3], output_data);

        // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
        accum_c_v = vaddq_u16(Move8IntoUpperU16(output_data), vdupq_n_u16(128));
      }

      // Top-centre margin.
      op_int16x8_t wdelta_c_v;
      op_int16x8_t wdelta_twice_c_v;
      for (int j = 0; j < (input_width - 1); ++j) {
        {
          uint8x8_t output_data_alt;
          uint8x8_t output_data;

          const op_int16x8_t tl_val(
              Load8IntoLowerS16(&input_base_ptr[c_block + depth * j]));
          const op_int16x8_t tr_val(
              Load8IntoLowerS16(&input_base_ptr[c_block + depth * (j + 1)]));
          wdelta_c_v = (tr_val - tl_val) << 4;
          wdelta_twice_c_v = wdelta_c_v << 1;

          op_uint16x8_t accum_c_v_alt = accum_c_v + wdelta_c_v;
          accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
          PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
                           &output_data);

          vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * 4],
                  output_data_alt);
          vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth + depth * 4],
                  output_data);

          for (int p = 2; p < 8; p += 2) {
            accum_c_v_alt = accum_c_v + wdelta_twice_c_v;
            accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
            PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
                             &output_data);

            vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * p +
                                     depth * 4],
                    output_data_alt);
            vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * (p + 1) +
                                     depth * 4],
                    output_data);
          }
          accum_c_v += wdelta_c_v;
        }
      }

      // Top-right margin corner.
      {
        uint8x8_t output_data_discard;
        uint8x8_t output_data;

        // Accumulations have pre-added 0.5 for rounding, but that is just
        // discarded and this just avoids re-loading.
        PairExtractUpper(accum_c_v.val, accum_c_v.val, &output_data,
                         &output_data_discard);

        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4],
                output_data);
        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4 + depth],
                output_data);
        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4 + depth * 2],
                output_data);
        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4 + depth * 3],
                output_data);
      }
    }
    // Fill out remainder of top margin.
    std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
                output_row_stride * sizeof(uint8));
    output_base_ptr += output_row_stride * 4;

    // Main rows.
    for (int k = 0; k < (input_height - 1); ++k) {
      for (int c_block = 0; c_block < depth; c_block += 8) {
        uint8* output_base_ptr_0 = output_base_ptr;
        uint8* output_base_ptr_1;
        uint8* output_base_ptr_2;
        uint8* output_base_ptr_3;
        uint8* output_base_ptr_4;
        uint8* output_base_ptr_5;
        uint8* output_base_ptr_6;
        uint8* output_base_ptr_7;

        op_uint16x8_t accum_0_c_v;
        op_uint16x8_t accum_1_c_v;
        op_uint16x8_t accum_2_c_v;
        op_uint16x8_t accum_3_c_v;
        op_uint16x8_t accum_4_c_v;
        op_uint16x8_t accum_5_c_v;
        op_uint16x8_t accum_6_c_v;
        op_uint16x8_t accum_7_c_v;

        op_int16x8_t hdelta_c_v;
        op_int16x8_t hdelta_twice_c_v;

        // Left margin for 8 rows.
        {
          uint8x8_t output_data_0_c;
          uint8x8_t output_data_1_c;
          uint8x8_t output_data_2_c;
          uint8x8_t output_data_3_c;
          uint8x8_t output_data_4_c;
          uint8x8_t output_data_5_c;
          uint8x8_t output_data_6_c;
          uint8x8_t output_data_7_c;

          const op_int16x8_t tl_val(
              Load8IntoLowerS16(&input_base_ptr[c_block]));
          const op_int16x8_t bl_val(
              Load8IntoLowerS16(&input_base_ptr[c_block + input_row_stride]));
          hdelta_c_v = (bl_val - tl_val) << 4;

          // Accumulate in 8.8 representation, pre-adding 0.5 for later
          // rounding.
          accum_0_c_v = VReinterpretQU16S16(tl_val << 8);
          accum_0_c_v = vaddq_u16(accum_0_c_v.val, vdupq_n_u16(128));

          hdelta_twice_c_v = hdelta_c_v << 1;

          accum_0_c_v += hdelta_c_v;
          accum_1_c_v = accum_0_c_v + hdelta_twice_c_v;
          PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val, &output_data_0_c,
                           &output_data_1_c);

          vst1_u8(&output_base_ptr_0[c_block], output_data_0_c);
          vst1_u8(&output_base_ptr_0[c_block + depth], output_data_0_c);
          vst1_u8(&output_base_ptr_0[c_block + depth * 2], output_data_0_c);
          vst1_u8(&output_base_ptr_0[c_block + depth * 3], output_data_0_c);

          output_base_ptr_1 = output_base_ptr_0 + output_row_stride;
          vst1_u8(&output_base_ptr_1[c_block], output_data_1_c);
          vst1_u8(&output_base_ptr_1[c_block + depth], output_data_1_c);
          vst1_u8(&output_base_ptr_1[c_block + depth * 2], output_data_1_c);
          vst1_u8(&output_base_ptr_1[c_block + depth * 3], output_data_1_c);

          //

          output_base_ptr_2 = output_base_ptr_1 + output_row_stride;
          accum_2_c_v = accum_1_c_v + hdelta_twice_c_v;
          accum_3_c_v = accum_2_c_v + hdelta_twice_c_v;
          PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val, &output_data_2_c,
                           &output_data_3_c);

          vst1_u8(&output_base_ptr_2[c_block], output_data_2_c);
          vst1_u8(&output_base_ptr_2[c_block + depth], output_data_2_c);
          vst1_u8(&output_base_ptr_2[c_block + depth * 2], output_data_2_c);
          vst1_u8(&output_base_ptr_2[c_block + depth * 3], output_data_2_c);

          output_base_ptr_3 = output_base_ptr_2 + output_row_stride;
          vst1_u8(&output_base_ptr_3[c_block], output_data_3_c);
          vst1_u8(&output_base_ptr_3[c_block + depth], output_data_3_c);
          vst1_u8(&output_base_ptr_3[c_block + depth * 2], output_data_3_c);
          vst1_u8(&output_base_ptr_3[c_block + depth * 3], output_data_3_c);

          //

          output_base_ptr_4 = output_base_ptr_3 + output_row_stride;
          accum_4_c_v = accum_3_c_v + hdelta_twice_c_v;
          accum_5_c_v = accum_4_c_v + hdelta_twice_c_v;
          PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val, &output_data_4_c,
                           &output_data_5_c);

          vst1_u8(&output_base_ptr_4[c_block], output_data_4_c);
          vst1_u8(&output_base_ptr_4[c_block + depth], output_data_4_c);
          vst1_u8(&output_base_ptr_4[c_block + depth * 2], output_data_4_c);
          vst1_u8(&output_base_ptr_4[c_block + depth * 3], output_data_4_c);

          output_base_ptr_5 = output_base_ptr_4 + output_row_stride;
          vst1_u8(&output_base_ptr_5[c_block], output_data_5_c);
          vst1_u8(&output_base_ptr_5[c_block + depth], output_data_5_c);
          vst1_u8(&output_base_ptr_5[c_block + depth * 2], output_data_5_c);
          vst1_u8(&output_base_ptr_5[c_block + depth * 3], output_data_5_c);

          //

          output_base_ptr_6 = output_base_ptr_5 + output_row_stride;
          accum_6_c_v = accum_5_c_v + hdelta_twice_c_v;
          accum_7_c_v = accum_6_c_v + hdelta_twice_c_v;
          PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val, &output_data_6_c,
                           &output_data_7_c);

          vst1_u8(&output_base_ptr_6[c_block], output_data_6_c);
          vst1_u8(&output_base_ptr_6[c_block + depth], output_data_6_c);
          vst1_u8(&output_base_ptr_6[c_block + depth * 2], output_data_6_c);
          vst1_u8(&output_base_ptr_6[c_block + depth * 3], output_data_6_c);

          output_base_ptr_7 = output_base_ptr_6 + output_row_stride;
          vst1_u8(&output_base_ptr_7[c_block], output_data_7_c);
          vst1_u8(&output_base_ptr_7[c_block + depth], output_data_7_c);
          vst1_u8(&output_base_ptr_7[c_block + depth * 2], output_data_7_c);
          vst1_u8(&output_base_ptr_7[c_block + depth * 3], output_data_7_c);
        }

        // Main central body.
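        // Each accum_i_c_v holds the running interpolated value for output
        // row i of the current 8-row band; incr_i_c is that row's
        // per-output-pixel horizontal step. The rows sample the vertical
        // interpolation at offsets 1/16, 3/16, ..., 15/16, so the steps
        // differ from row to row by the hwdelta cross term computed below.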
        op_int16x8_t wdelta_c;
        op_int16x8_t wdelta_twice_c;
        op_int16x8_t hwdelta_c;
        op_int16x8_t hwdelta_twice_c;

        op_int16x8_t incr_0_c;
        op_int16x8_t incr_1_c;
        op_int16x8_t incr_2_c;
        op_int16x8_t incr_3_c;
        op_int16x8_t incr_4_c;
        op_int16x8_t incr_5_c;
        op_int16x8_t incr_6_c;
        op_int16x8_t incr_7_c;

        uint8x8_t output_data_0_c;
        uint8x8_t output_data_1_c;
        uint8x8_t output_data_2_c;
        uint8x8_t output_data_3_c;
        uint8x8_t output_data_4_c;
        uint8x8_t output_data_5_c;
        uint8x8_t output_data_6_c;
        uint8x8_t output_data_7_c;
        for (int j = 0; j < (input_width - 1); ++j) {
          // output_base_ptr_0 = output_base_ptr;
          // output_base_ptr_1 = output_base_ptr_0 + output_row_stride; ETC
          {
            const op_int16x8_t tl_val(
                Load8IntoLowerS16(&input_base_ptr[c_block + depth * j]));
            const op_int16x8_t bl_val(Load8IntoLowerS16(
                &input_base_ptr[c_block + depth * j + input_row_stride]));
            const op_int16x8_t tr_val(
                Load8IntoLowerS16(&input_base_ptr[c_block + depth * (j + 1)]));
            const op_int16x8_t br_val(Load8IntoLowerS16(
                &input_base_ptr[c_block + depth * (j + 1) + input_row_stride]));

            const op_int16x8_t tmp_diff = tr_val - tl_val;
            wdelta_c = tmp_diff << 4;
            wdelta_twice_c = wdelta_c << 1;
            hwdelta_c = (br_val - bl_val) - tmp_diff;
            hwdelta_twice_c = hwdelta_c << 1;

            op_int16x8_t incr_base = wdelta_c + hwdelta_c;
            accum_0_c_v += incr_base;
            incr_0_c = incr_base << 1;
            incr_base += hwdelta_twice_c;
            accum_1_c_v += incr_base;
            incr_1_c = incr_base << 1;

            PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val, &output_data_0_c,
                             &output_data_1_c);
            vst1_u8(&output_base_ptr_0[c_block + depth * j * 8 + depth * 4],
                    output_data_0_c);
            vst1_u8(&output_base_ptr_1[c_block + depth * j * 8 + depth * 4],
                    output_data_1_c);

            incr_base += hwdelta_twice_c;
            accum_2_c_v += incr_base;
            incr_2_c = incr_base << 1;
            incr_base += hwdelta_twice_c;
            accum_3_c_v += incr_base;
            incr_3_c = incr_base << 1;

            PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val, &output_data_2_c,
                             &output_data_3_c);
            vst1_u8(&output_base_ptr_2[c_block + depth * j * 8 + depth * 4],
                    output_data_2_c);
            vst1_u8(&output_base_ptr_3[c_block + depth * j * 8 + depth * 4],
                    output_data_3_c);

            incr_base += hwdelta_twice_c;
            accum_4_c_v += incr_base;
            incr_4_c = incr_base << 1;
            incr_base += hwdelta_twice_c;
            accum_5_c_v += incr_base;
            incr_5_c = incr_base << 1;

            PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val, &output_data_4_c,
                             &output_data_5_c);
            vst1_u8(&output_base_ptr_4[c_block + depth * j * 8 + depth * 4],
                    output_data_4_c);
            vst1_u8(&output_base_ptr_5[c_block + depth * j * 8 + depth * 4],
                    output_data_5_c);

            incr_base += hwdelta_twice_c;
            accum_6_c_v += incr_base;
            incr_6_c = incr_base << 1;
            incr_base += hwdelta_twice_c;
            accum_7_c_v += incr_base;
            incr_7_c = incr_base << 1;

            PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val, &output_data_6_c,
                             &output_data_7_c);
            vst1_u8(&output_base_ptr_6[c_block + depth * j * 8 + depth * 4],
                    output_data_6_c);
            vst1_u8(&output_base_ptr_7[c_block + depth * j * 8 + depth * 4],
                    output_data_7_c);

            for (int p = 1; p < 8; ++p) {
              accum_0_c_v += incr_0_c;
              accum_1_c_v += incr_1_c;
              PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val,
                               &output_data_0_c, &output_data_1_c);
              vst1_u8(&output_base_ptr_0[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_0_c);
              vst1_u8(&output_base_ptr_1[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_1_c);

              accum_2_c_v += incr_2_c;
              accum_3_c_v += incr_3_c;
              PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val,
                               &output_data_2_c, &output_data_3_c);
              vst1_u8(&output_base_ptr_2[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_2_c);
              vst1_u8(&output_base_ptr_3[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_3_c);

              accum_4_c_v += incr_4_c;
              accum_5_c_v += incr_5_c;
              PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val,
                               &output_data_4_c, &output_data_5_c);
              vst1_u8(&output_base_ptr_4[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_4_c);
              vst1_u8(&output_base_ptr_5[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_5_c);

              accum_6_c_v += incr_6_c;
              accum_7_c_v += incr_7_c;
              PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val,
                               &output_data_6_c, &output_data_7_c);
              vst1_u8(&output_base_ptr_6[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_6_c);
              vst1_u8(&output_base_ptr_7[c_block + depth * j * 8 + depth * p +
                                         depth * 4],
                      output_data_7_c);
            }

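            // The steps so far cover 15/16 of the interval to input column
            // j + 1; these half-increments add the final 1/16 so the next
            // iteration starts exactly on the input sample.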
            accum_0_c_v += (incr_0_c >> 1);
            accum_1_c_v += (incr_1_c >> 1);
            accum_2_c_v += (incr_2_c >> 1);
            accum_3_c_v += (incr_3_c >> 1);
            accum_4_c_v += (incr_4_c >> 1);
            accum_5_c_v += (incr_5_c >> 1);
            accum_6_c_v += (incr_6_c >> 1);
            accum_7_c_v += (incr_7_c >> 1);
          }
        }

        // Right margin.
        {
          // Accumulations have pre-added 0.5 for rounding, but that is just
          // discarded and this just avoids re-loading.
          PairExtractUpper(accum_0_c_v.val, accum_1_c_v.val, &output_data_0_c,
                           &output_data_1_c);
          PairExtractUpper(accum_2_c_v.val, accum_3_c_v.val, &output_data_2_c,
                           &output_data_3_c);
          PairExtractUpper(accum_4_c_v.val, accum_5_c_v.val, &output_data_4_c,
                           &output_data_5_c);
          PairExtractUpper(accum_6_c_v.val, accum_7_c_v.val, &output_data_6_c,
                           &output_data_7_c);
          for (int p = 0; p < 4; ++p) {
            vst1_u8(&output_base_ptr_0[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_0_c);
            vst1_u8(&output_base_ptr_1[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_1_c);
            vst1_u8(&output_base_ptr_2[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_2_c);
            vst1_u8(&output_base_ptr_3[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_3_c);
            vst1_u8(&output_base_ptr_4[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_4_c);
            vst1_u8(&output_base_ptr_5[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_5_c);
            vst1_u8(&output_base_ptr_6[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_6_c);
            vst1_u8(&output_base_ptr_7[c_block + depth * (input_width - 1) * 8 +
                                       depth * 4 + depth * p],
                    output_data_7_c);
          }
        }
      }

      output_base_ptr += output_row_stride * 8;
      input_base_ptr += input_row_stride;
    }

    //

    for (int c_block = 0; c_block < depth; c_block += 8) {
      op_uint16x8_t accum_c_v;
      // Bottom-left margin corner.
      {
        uint8x8_t output_data = vld1_u8(&input_base_ptr[c_block]);
        vst1_u8(&output_base_ptr[c_block], output_data);
        vst1_u8(&output_base_ptr[c_block + depth], output_data);
        vst1_u8(&output_base_ptr[c_block + depth * 2], output_data);
        vst1_u8(&output_base_ptr[c_block + depth * 3], output_data);

        // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
        accum_c_v = vaddq_u16(Move8IntoUpperU16(output_data), vdupq_n_u16(128));
      }

      // Bottom-centre margin.
      op_int16x8_t wdelta_c_v;
      op_int16x8_t wdelta_twice_c_v;
      for (int j = 0; j < (input_width - 1); ++j) {
        {
          uint8x8_t output_data_alt;
          uint8x8_t output_data;

          const op_int16x8_t tl_val(
              Load8IntoLowerS16(&input_base_ptr[c_block + depth * j]));
          const op_int16x8_t tr_val(
              Load8IntoLowerS16(&input_base_ptr[c_block + depth * (j + 1)]));
          wdelta_c_v = (tr_val - tl_val) << 4;
          wdelta_twice_c_v = wdelta_c_v << 1;

          op_uint16x8_t accum_c_v_alt = accum_c_v + wdelta_c_v;
          accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
          PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
                           &output_data);

          vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * 4],
                  output_data_alt);
          vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth + depth * 4],
                  output_data);

          for (int p = 2; p < 8; p += 2) {
            accum_c_v_alt = accum_c_v + wdelta_twice_c_v;
            accum_c_v = accum_c_v_alt + wdelta_twice_c_v;
            PairExtractUpper(accum_c_v_alt.val, accum_c_v.val, &output_data_alt,
                             &output_data);

            vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * p +
                                     depth * 4],
                    output_data_alt);
            vst1_u8(&output_base_ptr[c_block + depth * j * 8 + depth * (p + 1) +
                                     depth * 4],
                    output_data);
          }
          accum_c_v += wdelta_c_v;
        }
      }

      // Bottom-right margin corner.
      {
        uint8x8_t output_data_discard;
        uint8x8_t output_data;

        // Accumulations have pre-added 0.5 for rounding, but that is just
        // discarded and this just avoids re-loading.
        PairExtractUpper(accum_c_v.val, accum_c_v.val, &output_data,
                         &output_data_discard);

        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4],
                output_data);
        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4 + depth],
                output_data);
        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4 + depth * 2],
                output_data);
        vst1_u8(&output_base_ptr[c_block + depth * (input_width - 1) * 8 +
                                 depth * 4 + depth * 3],
                output_data);
      }
    }
    // Fill out remainder of bottom margin.
    std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
                output_row_stride * sizeof(uint8));

#else  // USE_NEON
    for (int c_block = 0; c_block < depth; c_block += 8) {
      uint8 output_data[8];
      uint16 accum[8];
      // Top-left margin corner.
      for (int c = 0; c < 8; ++c) {
        output_data[c] = input_base_ptr[c_block + c];
        output_base_ptr[c_block + c] = output_data[c];
        output_base_ptr[c_block + c + depth] = output_data[c];
        output_base_ptr[c_block + c + depth * 2] = output_data[c];
        output_base_ptr[c_block + c + depth * 3] = output_data[c];

        // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
        accum[c] =
            (output_data[c] << 8) + 128;  // 128 = 0.5 in 8.8 representation.
      }

      // Top-centre margin.
      uint16 wdelta[8];
      uint16 wdelta_twice[8];
      for (int j = 0; j < (input_width - 1); ++j) {
        for (int c = 0; c < 8; ++c) {
          wdelta[c] = static_cast<uint16>(
                          input_base_ptr[c_block + c + depth * (j + 1)] -
                          input_base_ptr[c_block + c + depth * j])
                      << 4;
          wdelta_twice[c] = wdelta[c] << 1;

          accum[c] += wdelta[c];
          output_base_ptr[c_block + c + depth * j * 8 + depth * 4] =
              accum[c] >> 8;
          for (int p = 1; p < 8; ++p) {
            accum[c] += wdelta_twice[c];
            output_base_ptr[c_block + c + depth * j * 8 + depth * p +
                            depth * 4] = accum[c] >> 8;
          }
          accum[c] += wdelta[c];
        }
      }

      // Top-right margin corner.
      for (int c = 0; c < 8; ++c) {
        // Accumulations have pre-added 0.5 for rounding, but that is just
        // discarded and this just avoids re-loading.
        output_data[c] = accum[c] >> 8;
        TFLITE_DCHECK_EQ(
            output_data[c],
            input_base_ptr[c_block + c + depth * (input_width - 1)]);
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4] = output_data[c];
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4 + depth] = output_data[c];
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4 + depth * 2] = output_data[c];
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4 + depth * 3] = output_data[c];
      }
    }
    // Fill out remainder of top margin.
    std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
                output_row_stride * sizeof(uint8));
    output_base_ptr += output_row_stride * 4;

    // Main rows.
    for (int k = 0; k < (input_height - 1); ++k) {
      for (int c_block = 0; c_block < depth; c_block += 8) {
        uint8* output_base_ptr_0 = output_base_ptr;
        uint8* output_base_ptr_1;
        uint8* output_base_ptr_2;
        uint8* output_base_ptr_3;
        uint8* output_base_ptr_4;
        uint8* output_base_ptr_5;
        uint8* output_base_ptr_6;
        uint8* output_base_ptr_7;
        uint16 accum_0[8];
        uint16 accum_1[8];
        uint16 accum_2[8];
        uint16 accum_3[8];
        uint16 accum_4[8];
        uint16 accum_5[8];
        uint16 accum_6[8];
        uint16 accum_7[8];

        // We would prefer to use accum_0[c], etc., as packed-data arrays
        // standing in for registers. However, the compiler will not reliably
        // optimize across an array, so we do most of the work in plain
        // scalar variables.
        uint16 accum_0_c;
        uint16 accum_1_c;
        uint16 accum_2_c;
        uint16 accum_3_c;
        uint16 accum_4_c;
        uint16 accum_5_c;
        uint16 accum_6_c;
        uint16 accum_7_c;

        int16 hdelta_c;
        int16 hdelta_twice_c;

        // Left margin for 8 rows.
        for (int c = 0; c < 8; ++c) {
          hdelta_c = static_cast<uint16>(
                         input_base_ptr[c_block + c + input_row_stride] -
                         input_base_ptr[c_block + c])
                     << 4;

          // Accumulate in 8.8 representation, pre-adding 0.5 for later
          // rounding.
          accum_0_c = (input_base_ptr[c_block + c] << 8) + 128;

          accum_0_c += hdelta_c;
          output_base_ptr_0[c_block + c] = accum_0_c >> 8;
          output_base_ptr_0[c_block + c + depth] = accum_0_c >> 8;
          output_base_ptr_0[c_block + c + depth * 2] = accum_0_c >> 8;
          output_base_ptr_0[c_block + c + depth * 3] = accum_0_c >> 8;

          hdelta_twice_c = hdelta_c << 1;

          output_base_ptr_1 = output_base_ptr_0 + output_row_stride;
          accum_1_c = accum_0_c + hdelta_twice_c;
          output_base_ptr_1[c_block + c] = accum_1_c >> 8;
          output_base_ptr_1[c_block + c + depth] = accum_1_c >> 8;
          output_base_ptr_1[c_block + c + depth * 2] = accum_1_c >> 8;
          output_base_ptr_1[c_block + c + depth * 3] = accum_1_c >> 8;

          output_base_ptr_2 = output_base_ptr_1 + output_row_stride;
          accum_2_c = accum_1_c + hdelta_twice_c;
          output_base_ptr_2[c_block + c] = accum_2_c >> 8;
          output_base_ptr_2[c_block + c + depth] = accum_2_c >> 8;
          output_base_ptr_2[c_block + c + depth * 2] = accum_2_c >> 8;
          output_base_ptr_2[c_block + c + depth * 3] = accum_2_c >> 8;

          output_base_ptr_3 = output_base_ptr_2 + output_row_stride;
          accum_3_c = accum_2_c + hdelta_twice_c;
          output_base_ptr_3[c_block + c] = accum_3_c >> 8;
          output_base_ptr_3[c_block + c + depth] = accum_3_c >> 8;
          output_base_ptr_3[c_block + c + depth * 2] = accum_3_c >> 8;
          output_base_ptr_3[c_block + c + depth * 3] = accum_3_c >> 8;

          output_base_ptr_4 = output_base_ptr_3 + output_row_stride;
          accum_4_c = accum_3_c + hdelta_twice_c;
          output_base_ptr_4[c_block + c] = accum_4_c >> 8;
          output_base_ptr_4[c_block + c + depth] = accum_4_c >> 8;
          output_base_ptr_4[c_block + c + depth * 2] = accum_4_c >> 8;
          output_base_ptr_4[c_block + c + depth * 3] = accum_4_c >> 8;

          output_base_ptr_5 = output_base_ptr_4 + output_row_stride;
          accum_5_c = accum_4_c + hdelta_twice_c;
          output_base_ptr_5[c_block + c] = accum_5_c >> 8;
          output_base_ptr_5[c_block + c + depth] = accum_5_c >> 8;
          output_base_ptr_5[c_block + c + depth * 2] = accum_5_c >> 8;
          output_base_ptr_5[c_block + c + depth * 3] = accum_5_c >> 8;

          output_base_ptr_6 = output_base_ptr_5 + output_row_stride;
          accum_6_c = accum_5_c + hdelta_twice_c;
          output_base_ptr_6[c_block + c] = accum_6_c >> 8;
          output_base_ptr_6[c_block + c + depth] = accum_6_c >> 8;
          output_base_ptr_6[c_block + c + depth * 2] = accum_6_c >> 8;
          output_base_ptr_6[c_block + c + depth * 3] = accum_6_c >> 8;

          output_base_ptr_7 = output_base_ptr_6 + output_row_stride;
          accum_7_c = accum_6_c + hdelta_twice_c;
          output_base_ptr_7[c_block + c] = accum_7_c >> 8;
          output_base_ptr_7[c_block + c + depth] = accum_7_c >> 8;
          output_base_ptr_7[c_block + c + depth * 2] = accum_7_c >> 8;
          output_base_ptr_7[c_block + c + depth * 3] = accum_7_c >> 8;

          accum_0[c] = accum_0_c;
          accum_1[c] = accum_1_c;
          accum_2[c] = accum_2_c;
          accum_3[c] = accum_3_c;
          accum_4[c] = accum_4_c;
          accum_5[c] = accum_5_c;
          accum_6[c] = accum_6_c;
          accum_7[c] = accum_7_c;
        }

        // Main central body.
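        // Scalar mirror of the NEON central body above: accum_i_c carries the
        // interpolated value for output row i and incr_i_c its per-pixel
        // horizontal step, updated with the same wdelta / hwdelta scheme.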
        int16 wdelta_c;
        int16 wdelta_twice_c;
        int16 hwdelta_c;
        int16 hwdelta_twice_c;

        int16 incr_0_c;
        int16 incr_1_c;
        int16 incr_2_c;
        int16 incr_3_c;
        int16 incr_4_c;
        int16 incr_5_c;
        int16 incr_6_c;
        int16 incr_7_c;
        for (int j = 0; j < (input_width - 1); ++j) {
          for (int c = 0; c < 8; ++c) {
            accum_0_c = accum_0[c];
            accum_1_c = accum_1[c];
            accum_2_c = accum_2[c];
            accum_3_c = accum_3[c];
            accum_4_c = accum_4[c];
            accum_5_c = accum_5[c];
            accum_6_c = accum_6[c];
            accum_7_c = accum_7[c];

            wdelta_c = static_cast<uint16>(
                           input_base_ptr[c_block + c + depth * (j + 1)] -
                           input_base_ptr[c_block + c + depth * j])
                       << 4;
            wdelta_twice_c = wdelta_c << 1;
            hwdelta_c = static_cast<uint16>(
                input_base_ptr[c_block + c + depth * (j + 1) +
                               input_row_stride] -
                input_base_ptr[c_block + c + depth * (j + 1)] -
                input_base_ptr[c_block + c + depth * j + input_row_stride] +
                input_base_ptr[c_block + c + depth * j]);
            hwdelta_twice_c = hwdelta_c << 1;

            uint16 incr_base = wdelta_c + hwdelta_c;
            accum_0_c += incr_base;
            output_base_ptr_0[c_block + c + depth * j * 8 + depth * 4] =
                accum_0_c >> 8;
            incr_0_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_1_c += incr_base;
            output_base_ptr_1[c_block + c + depth * j * 8 + depth * 4] =
                accum_1_c >> 8;
            incr_1_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_2_c += incr_base;
            output_base_ptr_2[c_block + c + depth * j * 8 + depth * 4] =
                accum_2_c >> 8;
            incr_2_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_3_c += incr_base;
            output_base_ptr_3[c_block + c + depth * j * 8 + depth * 4] =
                accum_3_c >> 8;
            incr_3_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_4_c += incr_base;
            output_base_ptr_4[c_block + c + depth * j * 8 + depth * 4] =
                accum_4_c >> 8;
            incr_4_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_5_c += incr_base;
            output_base_ptr_5[c_block + c + depth * j * 8 + depth * 4] =
                accum_5_c >> 8;
            incr_5_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_6_c += incr_base;
            output_base_ptr_6[c_block + c + depth * j * 8 + depth * 4] =
                accum_6_c >> 8;
            incr_6_c = incr_base << 1;

            incr_base += hwdelta_twice_c;
            accum_7_c += incr_base;
            output_base_ptr_7[c_block + c + depth * j * 8 + depth * 4] =
                accum_7_c >> 8;
            incr_7_c = incr_base << 1;

            for (int p = 1; p < 8; ++p) {
              accum_0_c += incr_0_c;
              output_base_ptr_0[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_0_c >> 8;
              accum_1_c += incr_1_c;
              output_base_ptr_1[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_1_c >> 8;
              accum_2_c += incr_2_c;
              output_base_ptr_2[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_2_c >> 8;
              accum_3_c += incr_3_c;
              output_base_ptr_3[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_3_c >> 8;
              accum_4_c += incr_4_c;
              output_base_ptr_4[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_4_c >> 8;
              accum_5_c += incr_5_c;
              output_base_ptr_5[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_5_c >> 8;
              accum_6_c += incr_6_c;
              output_base_ptr_6[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_6_c >> 8;
              accum_7_c += incr_7_c;
              output_base_ptr_7[c_block + c + depth * j * 8 + depth * p +
                                depth * 4] = accum_7_c >> 8;
            }
            accum_0_c += incr_0_c / 2;
            accum_1_c += incr_1_c / 2;
            accum_2_c += incr_2_c / 2;
            accum_3_c += incr_3_c / 2;
            accum_4_c += incr_4_c / 2;
            accum_5_c += incr_5_c / 2;
            accum_6_c += incr_6_c / 2;
            accum_7_c += incr_7_c / 2;

            accum_0[c] = accum_0_c;
            accum_1[c] = accum_1_c;
            accum_2[c] = accum_2_c;
            accum_3[c] = accum_3_c;
            accum_4[c] = accum_4_c;
            accum_5[c] = accum_5_c;
            accum_6[c] = accum_6_c;
            accum_7[c] = accum_7_c;
          }
        }

        // Right margin.
        uint8 output_data_0_c;
        uint8 output_data_1_c;
        uint8 output_data_2_c;
        uint8 output_data_3_c;
        uint8 output_data_4_c;
        uint8 output_data_5_c;
        uint8 output_data_6_c;
        uint8 output_data_7_c;
        for (int c = 0; c < 8; ++c) {
          accum_0_c = accum_0[c];
          accum_1_c = accum_1[c];
          accum_2_c = accum_2[c];
          accum_3_c = accum_3[c];
          accum_4_c = accum_4[c];
          accum_5_c = accum_5[c];
          accum_6_c = accum_6[c];
          accum_7_c = accum_7[c];

          // Accumulations have pre-added 0.5 for rounding, but that is just
          // discarded and this just avoids re-loading.
          output_data_0_c = accum_0_c >> 8;
          output_data_1_c = accum_1_c >> 8;
          output_data_2_c = accum_2_c >> 8;
          output_data_3_c = accum_3_c >> 8;
          output_data_4_c = accum_4_c >> 8;
          output_data_5_c = accum_5_c >> 8;
          output_data_6_c = accum_6_c >> 8;
          output_data_7_c = accum_7_c >> 8;
          for (int p = 0; p < 4; ++p) {
            output_base_ptr_0[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_0_c;
            output_base_ptr_1[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_1_c;
            output_base_ptr_2[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_2_c;
            output_base_ptr_3[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_3_c;
            output_base_ptr_4[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_4_c;
            output_base_ptr_5[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_5_c;
            output_base_ptr_6[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_6_c;
            output_base_ptr_7[c_block + c + depth * (input_width - 1) * 8 +
                              depth * 4 + depth * p] = output_data_7_c;
          }

          accum_0[c] = accum_0_c;
          accum_1[c] = accum_1_c;
          accum_2[c] = accum_2_c;
          accum_3[c] = accum_3_c;
          accum_4[c] = accum_4_c;
          accum_5[c] = accum_5_c;
          accum_6[c] = accum_6_c;
          accum_7[c] = accum_7_c;
        }
      }

      output_base_ptr += output_row_stride * 8;
      input_base_ptr += input_row_stride;
    }

    for (int c_block = 0; c_block < depth; c_block += 8) {
      uint8 output_data[8];
      uint16 accum[8];
      // Bottom-left margin corner.
      for (int c = 0; c < 8; ++c) {
        output_data[c] = input_base_ptr[c_block + c];
        output_base_ptr[c_block + c] = output_data[c];
        output_base_ptr[c_block + c + depth] = output_data[c];
        output_base_ptr[c_block + c + depth * 2] = output_data[c];
        output_base_ptr[c_block + c + depth * 3] = output_data[c];

        // Accumulate in 8.8 representation, pre-adding 0.5 for later rounding.
        accum[c] =
            (output_data[c] << 8) + 128;  // 128 = 0.5 in 8.8 representation.
      }

      // Bottom-centre margin.
      uint16 wdelta[8];
      uint16 wdelta_twice[8];
      for (int j = 0; j < (input_width - 1); ++j) {
        for (int c = 0; c < 8; ++c) {
          wdelta[c] = static_cast<uint16>(
                          input_base_ptr[c_block + c + depth * (j + 1)] -
                          input_base_ptr[c_block + c + depth * j])
                      << 4;
          wdelta_twice[c] = wdelta[c] << 1;

          accum[c] += wdelta[c];
          output_base_ptr[c_block + c + depth * j * 8 + depth * 4] =
              accum[c] >> 8;
          for (int p = 1; p < 8; ++p) {
            accum[c] += wdelta_twice[c];
            output_base_ptr[c_block + c + depth * j * 8 + depth * p +
                            depth * 4] = accum[c] >> 8;
          }
          accum[c] += wdelta[c];
        }
      }

      // Bottom-right margin corner.
      for (int c = 0; c < 8; ++c) {
        // Accumulations have pre-added 0.5 for rounding, but that is just
        // discarded and this just avoids re-loading.
        output_data[c] = accum[c] >> 8;
        TFLITE_DCHECK_EQ(
            output_data[c],
            input_base_ptr[c_block + c + depth * (input_width - 1)]);
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4] = output_data[c];
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4 + depth] = output_data[c];
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4 + depth * 2] = output_data[c];
        output_base_ptr[c_block + c + depth * (input_width - 1) * 8 +
                        depth * 4 + depth * 3] = output_data[c];
      }
    }
    // Fill out remainder of bottom margin.
    std::memcpy(output_base_ptr + output_row_stride, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr,
                output_row_stride * sizeof(uint8));
    std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr,
                output_row_stride * sizeof(uint8));

#endif  // USE_NEON
  }
}  // NOLINT(readability/fn_size)

}  // namespace resize_bilinear

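// Accumulates one scaled input pixel into the output across the depth
// dimension: output[c] += input[c] * scale for c in [0, depth). The NEON
// version works in blocks of 32/16/8/4 channels with a scalar tail.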
1229 #ifdef USE_NEON
ResizeBilinearKernel(const float * input_ptr,int32 depth,float scale,float * output_ptr)1230 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
1231                                  float scale, float* output_ptr) {
1232   int ic = 0;
1233   // Handle 32 input channels at a time.
1234   for (; ic <= depth - 32; ic += 32) {
1235     float32x4x2_t input[4];
1236     for (int i = 0; i < 4; i++) {
1237       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
1238       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
1239     }
1240     float32x4x2_t acc[4];
1241     for (int i = 0; i < 4; i++) {
1242       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
1243       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
1244     }
1245     for (int i = 0; i < 4; i++) {
1246       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
1247       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
1248     }
1249     for (int i = 0; i < 4; i++) {
1250       vst1q_f32(output_ptr, acc[i].val[0]);
1251       vst1q_f32(output_ptr + 4, acc[i].val[1]);
1252       output_ptr += 8;
1253     }
1254     input_ptr += 32;
1255   }
1256   // Handle 16 input channels at a time.
1257   for (; ic <= depth - 16; ic += 16) {
1258     float32x4x2_t input[2];
1259     for (int i = 0; i < 2; i++) {
1260       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
1261       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
1262     }
1263     float32x4x2_t acc[2];
1264     for (int i = 0; i < 2; i++) {
1265       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
1266       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
1267     }
1268     for (int i = 0; i < 2; i++) {
1269       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
1270       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
1271     }
1272     for (int i = 0; i < 2; i++) {
1273       vst1q_f32(output_ptr, acc[i].val[0]);
1274       vst1q_f32(output_ptr + 4, acc[i].val[1]);
1275       output_ptr += 8;
1276     }
1277     input_ptr += 16;
1278   }
1279   // Handle 8 input channels at a time.
1280   for (; ic <= depth - 8; ic += 8) {
1281     float32x4x2_t input;
1282     input.val[0] = vld1q_f32(input_ptr);
1283     input.val[1] = vld1q_f32(input_ptr + 4);
1284 
1285     float32x4x2_t acc;
1286     acc.val[0] = vld1q_f32(output_ptr);
1287     acc.val[1] = vld1q_f32(output_ptr + 4);
1288     acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
1289     acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
1290 
1291     vst1q_f32(output_ptr, acc.val[0]);
1292     vst1q_f32(output_ptr + 4, acc.val[1]);
1293 
1294     input_ptr += 8;
1295     output_ptr += 8;
1296   }
1297   // Handle 4 input channels at a time.
1298   for (; ic <= depth - 4; ic += 4) {
1299     float32x4_t input = vld1q_f32(input_ptr);
1300     float32x4_t acc = vld1q_f32(output_ptr);
1301 
1302     acc = vmlaq_n_f32(acc, input, scale);
1303     vst1q_f32(output_ptr, acc);
1304 
1305     input_ptr += 4;
1306     output_ptr += 4;
1307   }
1308   // Handle 1 input channel at a time.
1309   for (; ic < depth; ic++) {
1310     *output_ptr += *input_ptr * scale;
1311     output_ptr++;
1312     input_ptr++;
1313   }
1314 }
1315 #else
1316 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
1317                                  float scale, float* output_ptr) {
1318   for (int32 i = 0; i < depth; i++) {
1319     *output_ptr += *input_ptr * scale;
1320     output_ptr++;
1321     input_ptr++;
1322   }
1323 }
1324 #endif
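// Usage sketch (illustrative, not a library entry point): one output pixel in
// ResizeBilinearGeneric below is built from four accumulating calls whose
// scales are the bilinear weights and sum to 1. With fractional offsets
// fy = input_y - y0 and fx = input_x - x0, and hypothetical corner pointers:
//
//   // output_ptr addresses `depth` floats, already zeroed.
//   ResizeBilinearKernel(in_y0_x0, depth, (1 - fy) * (1 - fx), output_ptr);
//   ResizeBilinearKernel(in_y0_x1, depth, (1 - fy) * fx, output_ptr);
//   ResizeBilinearKernel(in_y1_x0, depth, fy * (1 - fx), output_ptr);
//   ResizeBilinearKernel(in_y1_x1, depth, fy * fx, output_ptr);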
1325 
1326 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
1327                                     int32 x, int32 y, int32 depth, int32 batch,
1328                                     const RuntimeShape& input_shape,
1329                                     const float* input_data,
1330                                     const RuntimeShape& output_shape,
1331                                     float* output_data) {
1332   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1333   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1334   const int32 input_width = input_shape.Dims(2);
1335   const int32 output_width = output_shape.Dims(2);
1336 
1337   const int32 input_x_offset = (x1 - x0) * depth;
1338   const int32 input_y_offset = (y1 - y0) * depth * input_width;
1339   const int32 output_x_offset = depth;
1340   const int32 output_y_offset = depth * output_width;
1341 
1342 #ifdef USE_NEON
1343   TFLITE_DCHECK(x1 >= x0);
1344   TFLITE_DCHECK(y1 >= y0);
1345 
1346   int ic = 0;
1347   // Handle 8 input channels at a time.
1348   for (; ic <= depth - 8; ic += 8) {
1349     const float* input_ptr = nullptr;
1350 
1351     float32x4x2_t x0y0;
1352     input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
1353     x0y0.val[0] = vld1q_f32(input_ptr);
1354     x0y0.val[1] = vld1q_f32(input_ptr + 4);
1355 
1356     float32x4x2_t x1y0;
1357     input_ptr += input_x_offset;
1358     x1y0.val[0] = vld1q_f32(input_ptr);
1359     x1y0.val[1] = vld1q_f32(input_ptr + 4);
1360 
1361     float32x4x2_t x0y1;
1362     input_ptr += -input_x_offset + input_y_offset;
1363     x0y1.val[0] = vld1q_f32(input_ptr);
1364     x0y1.val[1] = vld1q_f32(input_ptr + 4);
1365 
1366     float32x4x2_t x1y1;
1367     input_ptr += input_x_offset;
1368     x1y1.val[0] = vld1q_f32(input_ptr);
1369     x1y1.val[1] = vld1q_f32(input_ptr + 4);
1370 
1371     // Top left corner.
1372     float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
1373     vst1q_f32(output_ptr, x0y0.val[0]);
1374     vst1q_f32(output_ptr + 4, x0y0.val[1]);
1375 
1376     // Top right corner.
1377     output_ptr += output_x_offset;
1378     float32x4x2_t tr;
1379     tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
1380     tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
1381     tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
1382     tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
1383 
1384     vst1q_f32(output_ptr, tr.val[0]);
1385     vst1q_f32(output_ptr + 4, tr.val[1]);
1386 
1387     // Bottom left corner.
1388     output_ptr += -output_x_offset + output_y_offset;
1389     float32x4x2_t bl;
1390     bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
1391     bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
1392     bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
1393     bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
1394     vst1q_f32(output_ptr, bl.val[0]);
1395     vst1q_f32(output_ptr + 4, bl.val[1]);
1396 
1397     // Bottom right corner.
1398     output_ptr += output_x_offset;
1399     float32x4x2_t br;
1400     br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
1401     br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
1402     br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
1403     br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
1404     br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
1405     br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
1406     vst1q_f32(output_ptr, br.val[0]);
1407     vst1q_f32(output_ptr + 4, br.val[1]);
1408   }
1409   // Handle 4 input channels at a time.
1410   for (; ic <= depth - 4; ic += 4) {
1411     const float* input_ptr =
1412         &input_data[Offset(input_shape, batch, y0, x0, ic)];
1413     float32x4_t x0y0 = vld1q_f32(input_ptr);
1414     float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
1415     float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
1416     float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
1417 
1418     // Top left corner.
1419     float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
1420     vst1q_f32(output_ptr, x0y0);
1421 
1422     // Top right corner.
1423     output_ptr += output_x_offset;
1424     float32x4_t tr = vaddq_f32(x0y0, x1y0);
1425     tr = vmulq_n_f32(tr, 0.5f);
1426     vst1q_f32(output_ptr, tr);
1427 
1428     // Bottom left corner.
1429     output_ptr += -output_x_offset + output_y_offset;
1430     float32x4_t bl = vaddq_f32(x0y0, x0y1);
1431     bl = vmulq_n_f32(bl, 0.5f);
1432     vst1q_f32(output_ptr, bl);
1433 
1434     // Bottom right corner.
1435     output_ptr += output_x_offset;
1436     float32x4_t br = vaddq_f32(x1y0, x1y1);
1437     br = vmlaq_n_f32(bl, br, 0.5f);
1438     br = vmulq_n_f32(br, 0.5f);
1439     vst1q_f32(output_ptr, br);
1440   }
1441   // Handle one input channel at a time.
1442   for (; ic < depth; ic++) {
1443     const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
1444 
1445     float x0y0 = input_data[input_offset];
1446     float x1y0 = input_data[input_offset + input_x_offset];
1447     float x0y1 = input_data[input_offset + input_y_offset];
1448     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
1449 
1450     // Top left corner.
1451     const int32 output_offset = Offset(output_shape, batch, y, x, ic);
1452     output_data[output_offset] = x0y0;
1453 
1454     // Top right corner.
1455     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
1456 
1457     // Bottom left corner.
1458     float output = (x0y0 + x0y1) / 2;
1459     output_data[output_offset + output_y_offset] = output;
1460 
1461     // Bottom right corner.
1462     output_data[output_offset + output_x_offset + output_y_offset] =
1463         (output + ((x1y0 + x1y1) / 2)) / 2;
1464   }
1465 #else
1466   for (int ch = 0; ch < depth; ch++) {
1467     const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
1468 
1469     float x0y0 = input_data[input_offset];
1470     float x1y0 = input_data[input_offset + input_x_offset];
1471     float x0y1 = input_data[input_offset + input_y_offset];
1472     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
1473 
1474     // Top left corner.
1475     const int32 output_offset = Offset(output_shape, batch, y, x, ch);
1476     output_data[output_offset] = x0y0;
1477 
1478     // Top right corner.
1479     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
1480 
1481     // Bottom left corner.
1482     float output = (x0y0 + x0y1) / 2;
1483     output_data[output_offset + output_y_offset] = output;
1484 
1485     // Bottom right corner.
1486     output_data[output_offset + output_x_offset + output_y_offset] =
1487         (output + ((x1y0 + x1y1) / 2)) / 2;
1488   }
1489 #endif
1490 }
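// For this aligned 2x upsample the bilinear weights collapse to constants; a
// sketch of the arithmetic, matching the scalar paths above:
//   TL = x0y0
//   TR = (x0y0 + x1y0) / 2
//   BL = (x0y0 + x0y1) / 2
//   BR = (BL + (x1y0 + x1y1) / 2) / 2 = (x0y0 + x0y1 + x1y0 + x1y1) / 4
// i.e. the bottom-right output is the plain average of all four neighbors,
// which the NEON path computes with a fused multiply-add on bl.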
1491 
1492 inline void ResizeBilinear2x2(int32 batches, int32 input_height,
1493                               int32 input_width, int32 depth,
1494                               int32 output_height, int32 output_width,
1495                               const RuntimeShape& input_shape,
1496                               const float* input_data,
1497                               const RuntimeShape& output_shape,
1498                               float* output_data) {
1499   for (int b = 0; b < batches; b++) {
1500     for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
1501       for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
1502         int32 x1 = std::min(x0 + 1, input_width - 1);
1503         int32 y1 = std::min(y0 + 1, input_height - 1);
1504         ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
1505                                 input_data, output_shape, output_data);
1506       }
1507     }
1508   }
1509 }
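// Each input pixel (x0, y0) emits the 2x2 output block anchored at
// (2 * x0, 2 * y0); x1 and y1 are clamped so that the last column and row
// interpolate with themselves (edge replication).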
1510 
1511 inline void ResizeBilinearGeneric(
1512     int32 batches, int32 input_height, int32 input_width, int32 depth,
1513     int32 output_height, int32 output_width, float height_scale,
1514     float width_scale, const RuntimeShape& input_shape, const float* input_data,
1515     const RuntimeShape& output_shape, float* output_data,
1516     const bool half_pixel_centers) {
1517   memset(output_data, 0,
1518          batches * output_height * output_width * depth * sizeof(float));
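  // The zero-fill matters: ResizeBilinearKernel accumulates with
  // output += input * scale rather than overwriting.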
1519 
1520   int32 output_offset = 0;
1521   for (int b = 0; b < batches; ++b) {
1522     for (int y = 0; y < output_height; ++y) {
1523       float input_y;
1524       int32 y0, y1;
1525       reference_ops::ComputeInterpolationValues(
1526           y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
1527           &y1);
1528       for (int x = 0; x < output_width; ++x) {
1529         float input_x;
1530         int32 x0, x1;
1531         reference_ops::ComputeInterpolationValues(
1532             x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
1533             &x1);
1534         float* output_ptr = &output_data[output_offset];
1535 
1536         // Run the kernel once per bilinear corner; the four scales sum
1537         // to 1.
1537         int32 input_offset = Offset(input_shape, b, y0, x0, 0);
1538         float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
1539         const float* input_ptr = &input_data[input_offset];
1540         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1541 
1542         input_offset = Offset(input_shape, b, y0, x1, 0);
1543         scale = (1 - (input_y - y0)) * (input_x - x0);
1544         input_ptr = &input_data[input_offset];
1545         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1546 
1547         input_offset = Offset(input_shape, b, y1, x0, 0);
1548         scale = (input_y - y0) * (1 - (input_x - x0));
1549         input_ptr = &input_data[input_offset];
1550         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1551 
1552         input_offset = Offset(input_shape, b, y1, x1, 0);
1553         scale = (input_y - y0) * (input_x - x0);
1554         input_ptr = &input_data[input_offset];
1555         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
1556 
1557         output_offset += depth;
1558       }
1559     }
1560   }
1561 }
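// Worked example (illustrative, assuming the reference
// ComputeInterpolationValues semantics): upscaling width 4 -> 10 without
// half_pixel_centers or align_corners gives width_scale = 0.4. For output
// column x = 6, input_x = 6 * 0.4 = 2.4, so x0 = 2, x1 = 3, and the
// horizontal weights are 1 - 0.4 = 0.6 on x0 and 0.4 on x1.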
1562 
1563 template <typename T>
1564 inline void ResizeBilinearGenericSmallChannel(
1565     int32 batches, int32 input_height, int32 input_width, int32 depth,
1566     int32 output_height, int32 output_width, float height_scale,
1567     float width_scale, const RuntimeShape& input_shape, const T* input_data,
1568     const RuntimeShape& output_shape, T* output_data,
1569     const bool half_pixel_centers) {
1570   T* output_ptr = &output_data[0];
1571   const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;
1572 
1573   for (int b = 0; b < batches; ++b) {
1574     for (int y = 0; y < output_height; ++y) {
1575       float input_y;
1576       int32 y0, y1;
1577       reference_ops::ComputeInterpolationValues(
1578           y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
1579           &y1);
1580       for (int x = 0; x < output_width; ++x) {
1581         float input_x;
1582         int32 x0, x1;
1583         reference_ops::ComputeInterpolationValues(
1584             x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
1585             &x1);
1586 
1587         int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
1588                                  Offset(input_shape, b, y0, x1, 0),
1589                                  Offset(input_shape, b, y1, x0, 0),
1590                                  Offset(input_shape, b, y1, x1, 0)};
1591         float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
1592                           (1 - (input_y - y0)) * (input_x - x0),
1593                           (input_y - y0) * (1 - (input_x - x0)),
1594                           (input_y - y0) * (input_x - x0)};
1595 
1596         for (int d = 0; d < depth; d++) {
1597           const T* input_ptr = &input_data[d];
1598           *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
1599                                          input_ptr[input_offset[1]] * scale[1] +
1600                                          input_ptr[input_offset[2]] * scale[2] +
1601                                          input_ptr[input_offset[3]] * scale[3] +
1602                                          rounding_offset);
1603         }
1604       }
1605     }
1606   }
1607 }
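// Rounding note: for integer T the 0.5f rounding_offset turns the truncating
// cast into round-half-up. E.g. corners {10, 12, 14, 16} with equal weights
// 0.25 give 13.0 + 0.5 -> 13, while weights {0.75, 0.25, 0, 0} give
// 10.5 + 0.5 -> 11.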
1608 
1609 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
1610                            const RuntimeShape& unextended_input_shape,
1611                            const float* input_data,
1612                            const RuntimeShape& output_size_shape,
1613                            const int32* output_size_data,
1614                            const RuntimeShape& unextended_output_shape,
1615                            float* output_data) {
1616   ruy::profiler::ScopeLabel label("ResizeBilinear");
1617   // If half_pixel_centers is true, align_corners must be false.
1618   TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
1619   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1620   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1621   const RuntimeShape input_shape =
1622       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1623   const RuntimeShape output_shape =
1624       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1625 
1626   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
1627   int32 input_height = input_shape.Dims(1);
1628   int32 input_width = input_shape.Dims(2);
1629   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
1630 
1631   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
1632   int32 output_height = output_size_data[0];
1633   int32 output_width = output_size_data[1];
1634 
1635   // Specialize for 2x2 upsample.
1636   if (!op_params.align_corners && !op_params.half_pixel_centers &&
1637       output_height == 2 * input_height && output_width == 2 * input_width) {
1638     ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
1639                       output_width, input_shape, input_data, output_shape,
1640                       output_data);
1641   } else {
1642     float height_scale = static_cast<float>(input_height) / output_height;
1643     float width_scale = static_cast<float>(input_width) / output_width;
1644     if (op_params.align_corners && output_height > 1) {
1645       height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
1646     }
1647     if (op_params.align_corners && output_width > 1) {
1648       width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
1649     }
1650 
1651     ResizeBilinearGeneric(batches, input_height, input_width, depth,
1652                           output_height, output_width, height_scale,
1653                           width_scale, input_shape, input_data, output_shape,
1654                           output_data, op_params.half_pixel_centers);
1655   }
1656 }
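// Scale-selection example (illustrative): resizing height 4 -> 8 uses
// height_scale = 4 / 8 = 0.5 by default; with align_corners it becomes
// (4 - 1) / (8 - 1) = 3 / 7, so the last output row (y = 7) maps exactly
// onto the last input row (7 * 3 / 7 = 3 = input_height - 1).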
1657 
1658 // Note: this is not a fully general quantized bilinear implementation; it
1659 // does not use int8 or int16 arithmetic.
1660 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
1661                            const RuntimeShape& unextended_input_shape,
1662                            const uint8* input_data,
1663                            const RuntimeShape& output_size_shape,
1664                            const int32* output_size_data,
1665                            const RuntimeShape& unextended_output_shape,
1666                            uint8* output_data) {
1667   ruy::profiler::ScopeLabel label("ResizeBilinearUint8");
1668   // If half_pixel_centers is True, align_corners must be False.
1669   // If half_pixel_centers is true, align_corners must be false.
1670   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1671   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1672   const RuntimeShape input_shape =
1673       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1674   const RuntimeShape output_shape =
1675       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1676 
1677   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
1678   int32 input_height = input_shape.Dims(1);
1679   int32 input_width = input_shape.Dims(2);
1680   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
1681 
1682   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
1683   int32 output_height = output_size_data[0];
1684   int32 output_width = output_size_data[1];
1685 
1686   if (!op_params.align_corners && op_params.half_pixel_centers &&
1687       ((depth % 8) == 0)) {
1688     const int32 scale = output_height / input_height;
1689     // Restricting the minimum output dimensions may not be necessary, but
1690     // ensures that kernels can use unrolling with minimal code size.
1691     if ((output_height >= 8) && (output_width >= 8) &&
1692         ((input_height * scale) == output_height) &&
1693         ((input_width * scale) == output_width)) {
1694       if (scale == 8) {
1695         resize_bilinear::ResizeBilinear888Uint8(
1696             batches, input_height, input_width, depth, input_data, output_data);
1697         return;
1698       }
1699     }
1700   }
1701 
1702   float height_scale =
1703       (op_params.align_corners && output_height > 1)
1704           ? (static_cast<float>(input_height - 1) / (output_height - 1))
1705           : (static_cast<float>(input_height) / output_height);
1706 
1707   float width_scale =
1708       (op_params.align_corners && output_width > 1)
1709           ? (static_cast<float>(input_width - 1) / (output_width - 1))
1710           : (static_cast<float>(input_width) / output_width);
1711 
1712   ResizeBilinearGenericSmallChannel<uint8>(
1713       batches, input_height, input_width, depth, output_height, output_width,
1714       height_scale, width_scale, input_shape, input_data, output_shape,
1715       output_data, op_params.half_pixel_centers);
1716 }
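// Dispatch example (illustrative): a 16x16 uint8 input with depth a multiple
// of 8, half_pixel_centers set and align_corners unset, resized to 128x128,
// is an exact 8x scale in both dimensions and takes the
// ResizeBilinear888Uint8 fast path; any other configuration falls through to
// ResizeBilinearGenericSmallChannel.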
1717 
1718 // TODO(b/180609127) Create optimized int8 version from uint8. Call from here.
1719 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
1720                            const RuntimeShape& unextended_input_shape,
1721                            const int8* input_data,
1722                            const RuntimeShape& unextended_output_size_shape,
1723                            const int32* output_size_data,
1724                            const RuntimeShape& unextended_output_shape,
1725                            int8* output_data) {
1726   reference_ops::ResizeBilinearInteger(op_params, unextended_input_shape,
1727                                        input_data, unextended_output_size_shape,
1728                                        output_size_data,
1729                                        unextended_output_shape, output_data);
1730 }
1731 
1732 }  // namespace optimized_ops
1733 }  // namespace tflite
1734 
1735 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_RESIZE_BILINEAR_H_
1736