1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17 
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <cstdint>
25 #include <limits>
26 #include <memory>
27 #include <tuple>
28 #include <type_traits>
29 
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/reference/add.h"
33 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
34 
35 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
36 #include <Accelerate/Accelerate.h>
37 #endif
38 
39 #include "Eigen/Core"
40 #include "unsupported/Eigen/CXX11/Tensor"
41 #include "fixedpoint/fixedpoint.h"
42 #include "ruy/profiler/instrumentation.h"  // from @ruy
43 #include "tensorflow/lite/c/common.h"
44 #include "tensorflow/lite/kernels/cpu_backend_context.h"
45 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
46 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
47 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
48 #include "tensorflow/lite/kernels/internal/cppmath.h"
49 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
50 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
51 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops_utils.h"
52 #include "tensorflow/lite/kernels/internal/quantization_util.h"
53 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
54 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
55 #include "tensorflow/lite/kernels/internal/tensor.h"
56 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
57 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
58 #include "tensorflow/lite/kernels/internal/types.h"
59 
60 #if __aarch64__ && __clang__
61 #define TFLITE_SOFTMAX_USE_UINT16_LUT
62 #endif
63 
64 namespace tflite {
65 namespace optimized_ops {
66 
67 // Unoptimized reference ops:
68 using reference_ops::Broadcast4DSlowGreater;
69 using reference_ops::Broadcast4DSlowGreaterEqual;
70 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
71 using reference_ops::Broadcast4DSlowGreaterWithScaling;
72 using reference_ops::Broadcast4DSlowLess;
73 using reference_ops::Broadcast4DSlowLessEqual;
74 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
75 using reference_ops::Broadcast4DSlowLessWithScaling;
76 using reference_ops::BroadcastAdd4DSlow;
77 using reference_ops::BroadcastMul4DSlow;
78 using reference_ops::BroadcastSub16POTSlow;
79 using reference_ops::BroadcastSubSlow;
80 using reference_ops::Concatenation;
81 using reference_ops::ConcatenationWithScaling;
82 using reference_ops::DepthConcatenation;
83 using reference_ops::Div;
84 using reference_ops::Elu;
85 using reference_ops::FakeQuant;
86 using reference_ops::Fill;
87 using reference_ops::Gather;
88 using reference_ops::Greater;
89 using reference_ops::GreaterEqual;
90 using reference_ops::GreaterEqualWithScaling;
91 using reference_ops::GreaterWithScaling;
92 using reference_ops::LeakyRelu;
93 using reference_ops::Less;
94 using reference_ops::LessEqual;
95 using reference_ops::LessEqualWithScaling;
96 using reference_ops::LessWithScaling;
97 using reference_ops::ProcessBroadcastShapes;
98 using reference_ops::RankOneSelect;
99 using reference_ops::Relu0To1;  // NOLINT
100 using reference_ops::Relu1;
101 using reference_ops::Relu6;
102 using reference_ops::ReluX;
103 using reference_ops::Round;
104 using reference_ops::Select;
105 using reference_ops::SpaceToBatchND;
106 using reference_ops::Split;
107 using reference_ops::Sub16;
108 
109 // TODO(b/80247582) Remove this constant.
110 // This will be phased out as the shifts are revised with more thought. Use of a
111 // constant enables us to track progress on this work.
112 //
113 // Used to convert from old-style shifts (right) to new-style (left).
114 static constexpr int kReverseShift = -1;
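// A worked example of the conversion (editor's note): a legacy parameter
// expressed as "right shift by 4" becomes kReverseShift * 4 == -4 when
// interpreted as a left shift in the new convention.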
115 
116 // Copied from tensorflow/core/framework/tensor_types.h
117 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
118 struct TTypes {
119   // Rank-1 tensor (vector) of scalar type T.
120   typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
121                            Eigen::Aligned>
122       Flat;
123   typedef Eigen::TensorMap<
124       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
125       UnalignedConstMatrix;
126 };
127 
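// Editor's note: a minimal usage sketch of the TTypes aliases above (not part
// of the original header; buffer names are hypothetical). TTypes<T>::Flat is
// an Eigen::Aligned map, so the mapped buffer must satisfy Eigen's alignment
// requirement.
//
//   alignas(16) float buffer[8] = {0.0f};
//   TTypes<float>::Flat flat(buffer, 8);   // rank-1 view, no copy
//   flat.setConstant(1.0f);                // writes through to `buffer`
//
//   const float storage[6] = {1, 2, 3, 4, 5, 6};
//   TTypes<float>::UnalignedConstMatrix mat(storage, 2, 3);  // 2x3 view
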
128 // TODO(b/62193649): this function is only needed as long
129 // as we have the --variable_batch hack.
130 template <typename Scalar>
131 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
132                                                    const RuntimeShape& shape,
133                                                    int rows) {
134   const int flatsize = shape.FlatSize();
135   TFLITE_DCHECK_EQ(flatsize % rows, 0);
136   const int cols = flatsize / rows;
137   return MatrixMap<Scalar>(data, rows, cols);
138 }
139 
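// Editor's note: a minimal usage sketch for the helper above (illustration
// only). Viewing a buffer whose RuntimeShape has FlatSize() == 12 with
// rows == 3 yields a 3x4 column-major Eigen matrix map over the same storage:
//
//   float storage[12] = {0.0f};
//   RuntimeShape shape({2, 2, 3});  // FlatSize() == 12
//   auto mat = MapAsMatrixWithGivenNumberOfRows(storage, shape, /*rows=*/3);
//   // mat.rows() == 3, mat.cols() == 4; writes to mat(i, j) modify storage.
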
140 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
141 inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
142                                     const RuntimeShape& unswitched_input1_shape,
143                                     const T* unswitched_input1_data,
144                                     const RuntimeShape& unswitched_input2_shape,
145                                     const T* unswitched_input2_data,
146                                     const RuntimeShape& output_shape,
147                                     T* output_data, ElementwiseF elementwise_f,
148                                     ScalarBroadcastF scalar_broadcast_f) {
149   ArithmeticParams switched_params = unswitched_params;
150   switched_params.input1_offset = unswitched_params.input2_offset;
151   switched_params.input1_multiplier = unswitched_params.input2_multiplier;
152   switched_params.input1_shift = unswitched_params.input2_shift;
153   switched_params.input2_offset = unswitched_params.input1_offset;
154   switched_params.input2_multiplier = unswitched_params.input1_multiplier;
155   switched_params.input2_shift = unswitched_params.input1_shift;
156 
157   const bool use_unswitched =
158       unswitched_params.broadcast_category ==
159       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
160 
161   const ArithmeticParams& params =
162       use_unswitched ? unswitched_params : switched_params;
163   const T* input1_data =
164       use_unswitched ? unswitched_input1_data : unswitched_input2_data;
165   const T* input2_data =
166       use_unswitched ? unswitched_input2_data : unswitched_input1_data;
167 
168   // Fivefold nested loops. The second input resets its position for each
169   // iteration of the second loop. The first input is held fixed across the
170   // fourth (y3) loop and advances after it. The innermost loop applies
171   // elementwise_f to contiguous sections of the arrays.
172   T* output_data_ptr = output_data;
173   const T* input1_data_ptr = input1_data;
174   const T* input2_data_reset = input2_data;
175   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
176   // between input shapes. y3 for input 1 is always broadcast, and so the
177   // dimension there is 1, whereas optionally y1 might be broadcast for
178   // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
179   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
180   int y0 = params.broadcast_shape[0];
181   int y1 = params.broadcast_shape[1];
182   int y2 = params.broadcast_shape[2];
183   int y3 = params.broadcast_shape[3];
184   int y4 = params.broadcast_shape[4];
185   if (y4 > 1) {
186     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
187     // dimension.
188     for (int i0 = 0; i0 < y0; ++i0) {
189       const T* input2_data_ptr = nullptr;
190       for (int i1 = 0; i1 < y1; ++i1) {
191         input2_data_ptr = input2_data_reset;
192         for (int i2 = 0; i2 < y2; ++i2) {
193           for (int i3 = 0; i3 < y3; ++i3) {
194             elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
195                           output_data_ptr);
196             input2_data_ptr += y4;
197             output_data_ptr += y4;
198           }
199           // We have broadcast y4 of input1 data y3 times, and now move on.
200           input1_data_ptr += y4;
201         }
202       }
203       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
204       input2_data_reset = input2_data_ptr;
205     }
206   } else if (input1_data_ptr != nullptr) {
207     // Special case of y4 == 1, in which the innermost loop is a single
208     // element and can be combined with the next (y3) as an inner broadcast.
209     //
210     // Note that this handles the case of pure scalar broadcast when
211     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
212     // broadcast with batch (as y2 > 1).
213     //
214     // NOTE: The process is the same as in the general case above, except
215     // simplified for y4 == 1, with the loop over y3 contained within the
216     // scalar_broadcast_f call.
217     for (int i0 = 0; i0 < y0; ++i0) {
218       const T* input2_data_ptr = nullptr;
219       for (int i1 = 0; i1 < y1; ++i1) {
220         input2_data_ptr = input2_data_reset;
221         for (int i2 = 0; i2 < y2; ++i2) {
222           scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
223                              output_data_ptr);
224           input2_data_ptr += y3;
225           output_data_ptr += y3;
226           input1_data_ptr += 1;
227         }
228       }
229       input2_data_reset = input2_data_ptr;
230     }
231   }
232 }
233 
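// Editor's note: a hedged sketch of the callables BinaryBroadcastFiveFold
// expects, inferred from the call sites above (names are illustrative, not an
// official API):
//
//   // Elementwise section: output[i] = op(input1[i], input2[i]), i < size.
//   void ElementwiseF(int size, const ArithmeticParams& params,
//                     const T* input1, const T* input2, T* output);
//
//   // Scalar-broadcast section: output[i] = op(input1_scalar, input2[i]).
//   void ScalarBroadcastF(int size, const ArithmeticParams& params,
//                         T input1_scalar, const T* input2, T* output);
//
// Worked example of the flat-size identities in the comment above: with
// input1 of shape [2, 1, 3, 8] and input2 of shape [2, 5, 3, 8], the
// decomposition {y0, y1, y2, y3, y4} = {1, 1, 2, 5, 24} satisfies
// y0*y1*y2*y4 = 48 = input1.FlatSize() and y0*y2*y3*y4 = 240 =
// input2.FlatSize().
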
234 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
235 
236 // Looks up each element of <indices> in <table>, returns them in a vector.
237 inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
238                                         uint8x16_t indices) {
239   // Look up in 1st quarter of the table: top 2 bits of indices == 00
240   uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
241   // Look up in 2nd quarter of the table: top 2 bits of indices == 01
242   uint8x16_t output2 =
243       vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
244   // Look up in 3rd quarter of the table: top 2 bits of indices == 10
245   uint8x16_t output3 =
246       vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
247   // Look up in 4th quarter of the table: top 2 bits of indices == 11
248   uint8x16_t output4 =
249       vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
250 
251   // Combine result of the 4 lookups.
252   return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
253 }
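// Editor's note (illustration only): vqtbl4q_u8 yields 0 for any index lane
// that is out of range for its 64-byte table, which is what makes the XOR
// trick above work. For example, an index byte of 0x47 (logical entry 71):
//   * table[0]: 0x47 >= 64, so the lane reads as 0;
//   * table[1]: 0x47 ^ 0x40 == 0x07, so the lane reads entry 71 of the
//     logical 256-entry table;
//   * table[2], table[3]: 0x47 ^ 0x80 and 0x47 ^ 0xc0 are both >= 64, so
//     those lanes read as 0.
// ORing the four partial results therefore gives the full 256-entry lookup.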
254 
255 #endif
256 
257 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
258                                              float output_activation_max,
259                                              const RuntimeShape& bias_shape,
260                                              const float* bias_data,
261                                              const RuntimeShape& array_shape,
262                                              float* array_data) {
263   BiasAndClamp(output_activation_min, output_activation_max,
264                bias_shape.FlatSize(), bias_data, array_shape.FlatSize(),
265                array_data);
266 }
267 
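// Editor's note: a minimal usage sketch (shapes hypothetical; bias_data and
// array_data are caller-provided float buffers). BiasAndClamp broadcasts the
// bias vector cyclically over the flattened array, so an array of shape
// [batches, output_depth] gets bias_data[c] added to every element of
// channel c before clamping to the activation range:
//
//   RuntimeShape bias_shape({8});      // one bias per output channel
//   RuntimeShape array_shape({4, 8});  // 4 rows of 8 channels
//   AddBiasAndEvalActivationFunction(/*min=*/0.0f, /*max=*/6.0f, bias_shape,
//                                    bias_data, array_shape, array_data);
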
268 inline void FullyConnected(
269     const FullyConnectedParams& params, const RuntimeShape& input_shape,
270     const float* input_data, const RuntimeShape& weights_shape,
271     const float* weights_data, const RuntimeShape& bias_shape,
272     const float* optional_bias_data, const RuntimeShape& output_shape,
273     float* output_data, CpuBackendContext* cpu_backend_context) {
274   ruy::profiler::ScopeLabel label("FullyConnected");
275   const int dims_count = weights_shape.DimensionsCount();
276   const int input_rows = weights_shape.Dims(dims_count - 1);
277   cpu_backend_gemm::MatrixParams<float> rhs_params;
278   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
279   rhs_params.rows = input_rows;
280   rhs_params.cols = input_shape.FlatSize() / input_rows;
281   rhs_params.cache_policy =
282       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
283   TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
284   cpu_backend_gemm::MatrixParams<float> lhs_params;
285   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
286   lhs_params.cols = weights_shape.Dims(dims_count - 1);
287   lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
288   lhs_params.cache_policy =
289       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
290   cpu_backend_gemm::MatrixParams<float> dst_params;
291   dst_params.order = cpu_backend_gemm::Order::kColMajor;
292   dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
293   dst_params.cols =
294       FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
295   cpu_backend_gemm::GemmParams<float, float> gemm_params;
296   gemm_params.bias = optional_bias_data;
297   gemm_params.clamp_min = params.float_activation_min;
298   gemm_params.clamp_max = params.float_activation_max;
299   cpu_backend_gemm::Gemm(lhs_params, weights_data, rhs_params, input_data,
300                          dst_params, output_data, gemm_params,
301                          cpu_backend_context);
302 }
303 
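// Editor's note: a hedged shape sketch for the float FullyConnected above,
// derived from how the lhs/rhs/dst params are filled in (not an official
// contract). With weights_shape = [output_depth, accum_depth] and an input
// whose FlatSize() is batches * accum_depth, the GEMM computes
//   output[b, o] = clamp(sum_a weights[o, a] * input[b, a] + bias[o]),
// i.e. lhs is the row-major weights matrix, rhs is the column-major
// [accum_depth x batches] view of the input, and dst is the column-major
// [output_depth x batches] view of the output.
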
304 inline void FullyConnected(
305     const FullyConnectedParams& params, const RuntimeShape& input_shape,
306     const uint8* input_data, const RuntimeShape& filter_shape,
307     const uint8* filter_data, const RuntimeShape& bias_shape,
308     const int32* bias_data, const RuntimeShape& output_shape,
309     uint8* output_data, CpuBackendContext* cpu_backend_context) {
310   ruy::profiler::ScopeLabel label("FullyConnected/8bit");
311   const int32 input_offset = params.input_offset;
312   const int32 filter_offset = params.weights_offset;
313   const int32 output_offset = params.output_offset;
314   const int32 output_multiplier = params.output_multiplier;
315   const int output_shift = params.output_shift;
316   const int32 output_activation_min = params.quantized_activation_min;
317   const int32 output_activation_max = params.quantized_activation_max;
318   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
319   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
320   // TODO(b/62193649): This really should be:
321   //     const int batches = ArraySize(output_dims, 1);
322   // but the current --variable_batch hack consists in overwriting the 3rd
323   // dimension with the runtime batch size, as we don't keep track, for each
324   // array, of which of its dimensions is the batch dimension.
325   const int output_dim_count = output_shape.DimensionsCount();
326   const int filter_dim_count = filter_shape.DimensionsCount();
327   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
328   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
329   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
330   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
331   const int output_rows = output_shape.Dims(output_dim_count - 1);
332   TFLITE_DCHECK_EQ(output_rows, filter_rows);
333   if (bias_data) {
334     TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
335   }
336 
337   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
338   lhs_params.rows = filter_rows;
339   lhs_params.cols = filter_cols;
340   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
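  // Editor's note: params.weights_offset and params.input_offset follow
  // TFLite's convention of storing the negated zero points, so negating them
  // here (and below for rhs_params) recovers the actual zero points that
  // cpu_backend_gemm::MatrixParams expects.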
341   lhs_params.zero_point = -filter_offset;
342   lhs_params.cache_policy =
343       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
344   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
345   rhs_params.rows = filter_cols;
346   rhs_params.cols = batches;
347   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
348   rhs_params.zero_point = -input_offset;
349   rhs_params.cache_policy =
350       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
351   cpu_backend_gemm::MatrixParams<uint8> dst_params;
352   dst_params.rows = filter_rows;
353   dst_params.cols = batches;
354   dst_params.order = cpu_backend_gemm::Order::kColMajor;
355   dst_params.zero_point = output_offset;
356   cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
357   gemm_params.bias = bias_data;
358   gemm_params.clamp_min = output_activation_min;
359   gemm_params.clamp_max = output_activation_max;
360   gemm_params.multiplier_fixedpoint = output_multiplier;
361   gemm_params.multiplier_exponent = output_shift;
362   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
363                          dst_params, output_data, gemm_params,
364                          cpu_backend_context);
365 }
366 
367 inline void FullyConnected(
368     const FullyConnectedParams& params, const RuntimeShape& input_shape,
369     const uint8* input_data, const RuntimeShape& filter_shape,
370     const uint8* filter_data, const RuntimeShape& bias_shape,
371     const int32* bias_data_int32, const RuntimeShape& output_shape,
372     int16* output_data, CpuBackendContext* cpu_backend_context) {
373   ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
374   const int32 input_offset = params.input_offset;
375   const int32 filter_offset = params.weights_offset;
376   const int32 output_offset = params.output_offset;
377   const int32 output_multiplier = params.output_multiplier;
378   const int output_shift = params.output_shift;
379   const int32 output_activation_min = params.quantized_activation_min;
380   const int32 output_activation_max = params.quantized_activation_max;
381   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
382   TFLITE_DCHECK_EQ(output_offset, 0);
383   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
384   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
385 
386   // TODO(b/62193649): This really should be:
387   //     const int batches = ArraySize(output_dims, 1);
388   // but the current --variable_batch hack consists in overwriting the 3rd
389   // dimension with the runtime batch size, as we don't keep track, for each
390   // array, of which of its dimensions is the batch dimension.
391   const int output_dim_count = output_shape.DimensionsCount();
392   const int filter_dim_count = filter_shape.DimensionsCount();
393   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
394   const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
395                                        output_shape, output_dim_count - 1);
396   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
397 
398   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
399   lhs_params.rows = output_depth;
400   lhs_params.cols = accum_depth;
401   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
402   lhs_params.zero_point = -filter_offset;
403   lhs_params.cache_policy =
404       cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
405   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
406   rhs_params.rows = accum_depth;
407   rhs_params.cols = batches;
408   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
409   rhs_params.zero_point = -input_offset;
410   rhs_params.cache_policy =
411       cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
412   cpu_backend_gemm::MatrixParams<int16> dst_params;
413   dst_params.rows = output_depth;
414   dst_params.cols = batches;
415   dst_params.order = cpu_backend_gemm::Order::kColMajor;
416   dst_params.zero_point = 0;
417   cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
418   gemm_params.bias = bias_data_int32;
419   gemm_params.clamp_min = output_activation_min;
420   gemm_params.clamp_max = output_activation_max;
421   gemm_params.multiplier_fixedpoint = output_multiplier;
422   gemm_params.multiplier_exponent = output_shift;
423   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
424                          dst_params, output_data, gemm_params,
425                          cpu_backend_context);
426 }
427 
428 // Internal function doing the actual arithmetic work for
429 // ShuffledFullyConnected.
430 // May be called either directly by it (single-threaded case) or may be used
431 // as the 'task' for worker threads to run (multi-threaded case, see
432 // ShuffledFullyConnectedWorkerTask below).
433 inline void ShuffledFullyConnectedWorkerImpl(
434     const uint8* shuffled_input_workspace_data,
435     const int8* shuffled_weights_data, int batches, int output_depth,
436     int output_stride, int accum_depth, const int32* bias_data,
437     int32 output_multiplier, int output_shift, int16* output_data) {
438 #if defined USE_NEON
439   const int8* shuffled_weights_ptr = shuffled_weights_data;
440   if (batches == 1) {
441     const int right_shift = output_shift > 0 ? 0 : -output_shift;
442     const int left_shift = output_shift > 0 ? output_shift : 0;
443     for (int c = 0; c < output_depth; c += 4) {
444       // Accumulation loop.
445       int32x4_t row_accum0 = vdupq_n_s32(0);
446       int32x4_t row_accum1 = vdupq_n_s32(0);
447       int32x4_t row_accum2 = vdupq_n_s32(0);
448       int32x4_t row_accum3 = vdupq_n_s32(0);
449       for (int d = 0; d < accum_depth; d += 16) {
450         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
451         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
452         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
453         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
454         shuffled_weights_ptr += 64;
455         int8x16_t input =
456             vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
457         int16x8_t local_accum0 =
458             vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
459         int16x8_t local_accum1 =
460             vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
461         int16x8_t local_accum2 =
462             vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
463         int16x8_t local_accum3 =
464             vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
465         local_accum0 =
466             vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
467         local_accum1 =
468             vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
469         local_accum2 =
470             vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
471         local_accum3 =
472             vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
473         row_accum0 = vpadalq_s16(row_accum0, local_accum0);
474         row_accum1 = vpadalq_s16(row_accum1, local_accum1);
475         row_accum2 = vpadalq_s16(row_accum2, local_accum2);
476         row_accum3 = vpadalq_s16(row_accum3, local_accum3);
477       }
478       // Horizontally reduce accumulators
479       int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
480           pairwise_reduced_acc_2, pairwise_reduced_acc_3;
481       pairwise_reduced_acc_0 =
482           vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
483       pairwise_reduced_acc_1 =
484           vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
485       pairwise_reduced_acc_2 =
486           vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
487       pairwise_reduced_acc_3 =
488           vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
489       const int32x2_t reduced_lo =
490           vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
491       const int32x2_t reduced_hi =
492           vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
493       int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
494       // Add bias values.
495       int32x4_t bias_vec = vld1q_s32(bias_data + c);
496       reduced = vaddq_s32(reduced, bias_vec);
497       reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
498       // Multiply by the fixed-point multiplier.
499       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
500       // Rounding-shift-right.
501       using gemmlowp::RoundingDivideByPOT;
502       reduced = RoundingDivideByPOT(reduced, right_shift);
503       // Narrow values down to 16 bit signed.
504       const int16x4_t res16 = vqmovn_s32(reduced);
505       vst1_s16(output_data + c, res16);
506     }
507   } else if (batches == 4) {
508     const int right_shift = output_shift > 0 ? 0 : -output_shift;
509     const int left_shift = output_shift > 0 ? output_shift : 0;
510     for (int c = 0; c < output_depth; c += 4) {
511       const int8* shuffled_input_ptr =
512           reinterpret_cast<const int8*>(shuffled_input_workspace_data);
513       // Accumulation loop.
514       int32x4_t row_accum00 = vdupq_n_s32(0);
515       int32x4_t row_accum10 = vdupq_n_s32(0);
516       int32x4_t row_accum20 = vdupq_n_s32(0);
517       int32x4_t row_accum30 = vdupq_n_s32(0);
518       int32x4_t row_accum01 = vdupq_n_s32(0);
519       int32x4_t row_accum11 = vdupq_n_s32(0);
520       int32x4_t row_accum21 = vdupq_n_s32(0);
521       int32x4_t row_accum31 = vdupq_n_s32(0);
522       int32x4_t row_accum02 = vdupq_n_s32(0);
523       int32x4_t row_accum12 = vdupq_n_s32(0);
524       int32x4_t row_accum22 = vdupq_n_s32(0);
525       int32x4_t row_accum32 = vdupq_n_s32(0);
526       int32x4_t row_accum03 = vdupq_n_s32(0);
527       int32x4_t row_accum13 = vdupq_n_s32(0);
528       int32x4_t row_accum23 = vdupq_n_s32(0);
529       int32x4_t row_accum33 = vdupq_n_s32(0);
530       for (int d = 0; d < accum_depth; d += 16) {
531         int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
532         int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
533         int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
534         int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
535         shuffled_weights_ptr += 64;
536         int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
537         int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
538         int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
539         int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
540         shuffled_input_ptr += 64;
541         int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
542 #define TFLITE_SHUFFLED_FC_ACCUM(B)                                           \
543   local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B));      \
544   local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B));      \
545   local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B));      \
546   local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B));      \
547   local_accum0 =                                                              \
548       vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
549   local_accum1 =                                                              \
550       vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
551   local_accum2 =                                                              \
552       vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
553   local_accum3 =                                                              \
554       vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
555   row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0);                   \
556   row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1);                   \
557   row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2);                   \
558   row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
559 
560         TFLITE_SHUFFLED_FC_ACCUM(0)
561         TFLITE_SHUFFLED_FC_ACCUM(1)
562         TFLITE_SHUFFLED_FC_ACCUM(2)
563         TFLITE_SHUFFLED_FC_ACCUM(3)
564 
565 #undef TFLITE_SHUFFLED_FC_ACCUM
566       }
567       // Horizontally reduce accumulators
568 
569 #define TFLITE_SHUFFLED_FC_STORE(B)                                           \
570   {                                                                           \
571     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,                 \
572         pairwise_reduced_acc_2, pairwise_reduced_acc_3;                       \
573     pairwise_reduced_acc_0 =                                                  \
574         vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
575     pairwise_reduced_acc_1 =                                                  \
576         vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
577     pairwise_reduced_acc_2 =                                                  \
578         vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
579     pairwise_reduced_acc_3 =                                                  \
580         vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
581     const int32x2_t reduced_lo =                                              \
582         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);            \
583     const int32x2_t reduced_hi =                                              \
584         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);            \
585     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);                 \
586     int32x4_t bias_vec = vld1q_s32(bias_data + c);                            \
587     reduced = vaddq_s32(reduced, bias_vec);                                   \
588     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));                    \
589     reduced = vqrdmulhq_n_s32(reduced, output_multiplier);                    \
590     using gemmlowp::RoundingDivideByPOT;                                      \
591     reduced = RoundingDivideByPOT(reduced, right_shift);                      \
592     const int16x4_t res16 = vqmovn_s32(reduced);                              \
593     vst1_s16(output_data + c + B * output_stride, res16);                     \
594   }
595 
596       TFLITE_SHUFFLED_FC_STORE(0);
597       TFLITE_SHUFFLED_FC_STORE(1);
598       TFLITE_SHUFFLED_FC_STORE(2);
599       TFLITE_SHUFFLED_FC_STORE(3);
600 
601 #undef TFLITE_SHUFFLED_FC_STORE
602     }
603   } else {
604     TFLITE_DCHECK(false);
605     return;
606   }
607 #else
608   if (batches == 1) {
609     int16* output_ptr = output_data;
610     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
611     // so that just reinterpreting them as int8 values is equivalent to
612     // subtracting 128 from them, thus implementing for free the subtraction of
613     // the zero_point value 128.
614     const int8* shuffled_weights_ptr =
615         reinterpret_cast<const int8*>(shuffled_weights_data);
616     // Likewise, we preshuffled and pre-xored the input data above.
617     const int8* shuffled_input_data =
618         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
619     for (int c = 0; c < output_depth; c += 4) {
620       // Internal accumulation.
621       // Accumulators start at zero; the bias is added after the loop.
622       int32 accum[4] = {0};
623       // Accumulation loop.
624       for (int d = 0; d < accum_depth; d += 16) {
625         for (int i = 0; i < 4; i++) {
626           for (int j = 0; j < 16; j++) {
627             int8 input_val = shuffled_input_data[d + j];
628             int8 weights_val = *shuffled_weights_ptr++;
629             accum[i] += weights_val * input_val;
630           }
631         }
632       }
633       for (int i = 0; i < 4; i++) {
634         // Add bias value
635         int acc = accum[i] + bias_data[c + i];
636         // Down-scale the final int32 accumulator to the scale used by our
637         // (16-bit, typically 3 integer bits) fixed-point format. The quantized
638         // multiplier and shift here have been pre-computed offline
639         // (e.g. by toco).
640         acc =
641             MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
642         // Saturate, cast to int16, and store to output array.
643         acc = std::max(acc, -32768);
644         acc = std::min(acc, 32767);
645         output_ptr[c + i] = acc;
646       }
647     }
648   } else if (batches == 4) {
649     int16* output_ptr = output_data;
650     // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
651     // so that just reinterpreting them as int8 values is equivalent to
652     // subtracting 128 from them, thus implementing for free the subtraction of
653     // the zero_point value 128.
654     const int8* shuffled_weights_ptr =
655         reinterpret_cast<const int8*>(shuffled_weights_data);
656     // Likewise, we preshuffled and pre-xored the input data above.
657     const int8* shuffled_input_data =
658         reinterpret_cast<const int8*>(shuffled_input_workspace_data);
659     for (int c = 0; c < output_depth; c += 4) {
660       const int8* shuffled_input_ptr = shuffled_input_data;
661       // Internal accumulation.
662       // Accumulators start at zero; the bias is added after the
663       // accumulation loop.
664       int32 accum[4][4];
665       for (int i = 0; i < 4; i++) {
666         for (int b = 0; b < 4; b++) {
667           accum[i][b] = 0;
668         }
669       }
670       for (int d = 0; d < accum_depth; d += 16) {
671         for (int i = 0; i < 4; i++) {
672           for (int b = 0; b < 4; b++) {
673             for (int j = 0; j < 16; j++) {
674               int8 input_val = shuffled_input_ptr[16 * b + j];
675               int8 weights_val = shuffled_weights_ptr[16 * i + j];
676               accum[i][b] += weights_val * input_val;
677             }
678           }
679         }
680         shuffled_input_ptr += 64;
681         shuffled_weights_ptr += 64;
682       }
683       for (int i = 0; i < 4; i++) {
684         for (int b = 0; b < 4; b++) {
685           // Add bias value
686           int acc = accum[i][b] + bias_data[c + i];
687           // Down-scale the final int32 accumulator to the scale used by our
688           // (16-bit, typically 3 integer bits) fixed-point format. The
689           // quantized multiplier and shift here have been pre-computed offline
690           // (e.g. by toco).
691           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
692                                               output_shift);
693           // Saturate, cast to int16, and store to output array.
694           acc = std::max(acc, -32768);
695           acc = std::min(acc, 32767);
696           output_ptr[b * output_stride + c + i] = acc;
697         }
698       }
699     }
700   } else {
701     TFLITE_DCHECK(false);
702     return;
703   }
704 #endif
705 }
706 
707 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
708 // to allow using gemmlowp's threadpool.
709 struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task {
710   ShuffledFullyConnectedWorkerTask(const uint8* input_data,
711                                    const int8* shuffled_weights_data,
712                                    int batches, int output_depth,
713                                    int output_stride, int accum_depth,
714                                    const int32* bias_data,
715                                    int32 output_multiplier, int output_shift,
716                                    int16* output_data)
717       : input_data_(input_data),
718         shuffled_weights_data_(shuffled_weights_data),
719         batches_(batches),
720         output_depth_(output_depth),
721         output_stride_(output_stride),
722         accum_depth_(accum_depth),
723         bias_data_(bias_data),
724         output_multiplier_(output_multiplier),
725         output_shift_(output_shift),
726         output_data_(output_data) {}
727 
728   void Run() override {
729     ShuffledFullyConnectedWorkerImpl(
730         input_data_, shuffled_weights_data_, batches_, output_depth_,
731         output_stride_, accum_depth_, bias_data_, output_multiplier_,
732         output_shift_, output_data_);
733   }
734 
735   const uint8* input_data_;
736   const int8* shuffled_weights_data_;
737   int batches_;
738   int output_depth_;
739   int output_stride_;
740   int accum_depth_;
741   const int32* bias_data_;
742   int32 output_multiplier_;
743   int output_shift_;
744   int16* output_data_;
745 };
746 
747 inline void ShuffledFullyConnected(
748     const FullyConnectedParams& params, const RuntimeShape& input_shape,
749     const uint8* input_data, const RuntimeShape& weights_shape,
750     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
751     const int32* bias_data, const RuntimeShape& output_shape,
752     int16* output_data, uint8* shuffled_input_workspace_data,
753     CpuBackendContext* cpu_backend_context) {
754   ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
755   const int32 output_multiplier = params.output_multiplier;
756   const int output_shift = params.output_shift;
757   const int32 output_activation_min = params.quantized_activation_min;
758   const int32 output_activation_max = params.quantized_activation_max;
759   TFLITE_DCHECK_EQ(output_activation_min, -32768);
760   TFLITE_DCHECK_EQ(output_activation_max, 32767);
761   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
762   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
763   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
764   // TODO(b/62193649): This really should be:
765   //     const int batches = ArraySize(output_dims, 1);
766   // but the current --variable_batch hack consists in overwriting the 3rd
767   // dimension with the runtime batch size, as we don't keep track, for each
768   // array, of which of its dimensions is the batch dimension.
769   const int output_dim_count = output_shape.DimensionsCount();
770   const int weights_dim_count = weights_shape.DimensionsCount();
771   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
772   const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
773                                        output_shape, output_dim_count - 1);
774   const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
775   TFLITE_DCHECK((accum_depth % 16) == 0);
776   TFLITE_DCHECK((output_depth % 4) == 0);
777   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
778   // so that just reinterpreting them as int8 values is equivalent to
779   // subtracting 128 from them, thus implementing for free the subtraction of
780   // the zero_point value 128.
781   const int8* int8_shuffled_weights_data =
782       reinterpret_cast<const int8*>(shuffled_weights_data);
783 
784   // Shuffling and xoring of input activations into the workspace buffer
785   if (batches == 1) {
786 #ifdef USE_NEON
787     const uint8x16_t signbit = vdupq_n_u8(0x80);
788     for (int i = 0; i < accum_depth; i += 16) {
789       uint8x16_t val = vld1q_u8(input_data + i);
790       val = veorq_u8(val, signbit);
791       vst1q_u8(shuffled_input_workspace_data + i, val);
792     }
793 #else
794     for (int i = 0; i < accum_depth; i++) {
795       shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
796     }
797 #endif
798   } else if (batches == 4) {
799     uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
800     int c = 0;
801 #ifdef USE_NEON
802     const uint8x16_t signbit = vdupq_n_u8(0x80);
803     for (c = 0; c < accum_depth; c += 16) {
804       const uint8* src_data_ptr = input_data + c;
805       uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
806       uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
807       uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
808       uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
809       val0 = veorq_u8(val0, signbit);
810       val1 = veorq_u8(val1, signbit);
811       val2 = veorq_u8(val2, signbit);
812       val3 = veorq_u8(val3, signbit);
813       vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
814       vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
815       vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
816       vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
817       shuffled_input_workspace_ptr += 64;
818     }
819 #else
820     for (c = 0; c < accum_depth; c += 16) {
821       for (int b = 0; b < 4; b++) {
822         const uint8* src_data_ptr = input_data + b * accum_depth + c;
823         for (int j = 0; j < 16; j++) {
824           uint8 src_val = *src_data_ptr++;
825           // Flip the sign bit, so that the kernel will only need to
826           // reinterpret these uint8 values as int8, getting for free the
827           // subtraction of the zero_point value 128.
828           uint8 dst_val = src_val ^ 0x80;
829           *shuffled_input_workspace_ptr++ = dst_val;
830         }
831       }
832     }
833 #endif
834   } else {
835     TFLITE_DCHECK(false);
836     return;
837   }
838 
839   static constexpr int kKernelRows = 4;
840   const int thread_count =
841       LegacyHowManyThreads<kKernelRows>(cpu_backend_context->max_num_threads(),
842                                         output_depth, batches, accum_depth);
843   if (thread_count == 1) {
844     // Single-thread case: do the computation on the current thread, don't
845     // use a threadpool
846     ShuffledFullyConnectedWorkerImpl(
847         shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
848         output_depth, output_depth, accum_depth, bias_data, output_multiplier,
849         output_shift, output_data);
850     return;
851   }
852 
853   // Multi-threaded case: use the gemmlowp context's threadpool.
854   TFLITE_DCHECK_GT(thread_count, 1);
855   std::vector<ShuffledFullyConnectedWorkerTask> tasks;
856   // TODO(b/131746020) don't create new heap allocations every time.
857   // At least we make it a single heap allocation by using reserve().
858   tasks.reserve(thread_count);
859   const int kRowsPerWorker =
860       RoundUp<kKernelRows>(CeilQuotient(output_depth, thread_count));
861   int row_start = 0;
862   for (int i = 0; i < thread_count; i++) {
863     int row_end = std::min(output_depth, row_start + kRowsPerWorker);
864     tasks.emplace_back(shuffled_input_workspace_data,
865                        int8_shuffled_weights_data + row_start * accum_depth,
866                        batches, row_end - row_start, output_depth, accum_depth,
867                        bias_data + row_start, output_multiplier, output_shift,
868                        output_data + row_start);
869     row_start = row_end;
870   }
871   TFLITE_DCHECK_EQ(row_start, output_depth);
872   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
873                                   cpu_backend_context);
874 }
875 
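// Editor's note: two small worked examples for ShuffledFullyConnected above
// (illustration only).
//
// XOR-with-0x80: for a uint8 activation v, (int8)(v ^ 0x80) == v - 128.
// E.g. v = 0x93 (147): 0x93 ^ 0x80 = 0x13 = 19 = 147 - 128; and v = 0x05 (5):
// 0x05 ^ 0x80 = 0x85, which reinterpreted as int8 is -123 = 5 - 128. This is
// why the kernels can treat the shuffled data as int8 with an implicit zero
// point of 128.
//
// Row partitioning: with output_depth = 1000 and thread_count = 3,
// kRowsPerWorker = RoundUp<4>(CeilQuotient(1000, 3)) = RoundUp<4>(334) = 336,
// so the three tasks cover rows [0, 336), [336, 672) and [672, 1000).
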
876 #ifdef USE_NEON
877 
878 inline int32x4_t RoundToNearest(const float32x4_t input) {
879 #if defined(__aarch64__) || defined(__SSSE3__)
880   // Note: vcvtnq_s32_f32 is not available in ARMv7
881   return vcvtnq_s32_f32(input);
882 #else
883   static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
884   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
885   static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
886 
887   const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
888   const float32x4_t round =
889       vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
890   return vcvtq_s32_f32(vaddq_f32(input, round));
891 #endif  // defined(__aarch64__) || defined(__SSSE3__)
892 }
893 
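// Editor's note: the non-aarch64 fallback above rounds half away from zero
// (add +/-0.5, then truncate toward zero), whereas vcvtnq_s32_f32 rounds half
// to even, so results may differ on exact .5 ties. A scalar sketch of the
// fallback's behaviour (hypothetical helper, not part of this header):
//
//   int32_t RoundToNearestScalar(float x) {
//     return static_cast<int32_t>(x + (x < 0.0f ? -0.5f : 0.5f));
//   }
//   // RoundToNearestScalar(2.5f) == 3, while round-half-to-even gives 2.
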
894 inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
895 #if defined(__aarch64__)
896   // Note that vcvtnq_u32_f32 is not available in ARMv7 or in arm_neon_sse.h.
897   return vcvtnq_u32_f32(input);
898 #else
899   static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
900 
901   return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
902 #endif  // defined(__aarch64__)
903 }
904 
905 #endif  // USE_NEON
906 
907 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
908                  const float* input_data, const RuntimeShape& filter_shape,
909                  const float* filter_data, const RuntimeShape& bias_shape,
910                  const float* bias_data, const RuntimeShape& output_shape,
911                  float* output_data, const RuntimeShape& im2col_shape,
912                  float* im2col_data, CpuBackendContext* cpu_backend_context) {
913   const int stride_width = params.stride_width;
914   const int stride_height = params.stride_height;
915   const int dilation_width_factor = params.dilation_width_factor;
916   const int dilation_height_factor = params.dilation_height_factor;
917   const float output_activation_min = params.float_activation_min;
918   const float output_activation_max = params.float_activation_max;
919   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
920   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
921   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
922 
923   ruy::profiler::ScopeLabel label("Conv");
924 
925   // NB: the float 0.0f value is represented by all zero bytes.
926   const uint8 float_zero_byte = 0x00;
927   const float* gemm_input_data = nullptr;
928   const RuntimeShape* gemm_input_shape = nullptr;
929   const int filter_width = filter_shape.Dims(2);
930   const int filter_height = filter_shape.Dims(1);
931   const bool need_dilated_im2col =
932       dilation_width_factor != 1 || dilation_height_factor != 1;
933   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
934                            filter_width != 1 || filter_height != 1;
935   if (need_dilated_im2col) {
936     DilatedIm2col(params, float_zero_byte, input_shape, input_data,
937                   filter_shape, output_shape, im2col_data);
938     gemm_input_data = im2col_data;
939     gemm_input_shape = &im2col_shape;
940   } else if (need_im2col) {
941     TFLITE_DCHECK(im2col_data);
942     Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
943            input_data, im2col_shape, im2col_data);
944     gemm_input_data = im2col_data;
945     gemm_input_shape = &im2col_shape;
946   } else {
947     TFLITE_DCHECK(!im2col_data);
948     gemm_input_data = input_data;
949     gemm_input_shape = &input_shape;
950   }
951 
952   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
953   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
954   int n = output_shape.Dims(3);
955   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
956 
957 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
958   // The following code computes the matrix multiplication c = a * transpose(b)
959   // with CBLAS, where:
960   // * `a` is a matrix with dimensions (m, k).
961   // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
962   // * `c` is a matrix with dimensions (m, n).
963   // The naming of variables is aligned with the CBLAS specification here.
964   const float* a = gemm_input_data;
965   const float* b = filter_data;
966   float* c = output_data;
967   // The strides of matrices a, b, and c, respectively.
968   int stride_a = k;
969   int stride_b = k;
970   int stride_c = n;
971 
972   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
973               stride_a, b, stride_b, 0.0f, c, stride_c);
974   optimized_ops::AddBiasAndEvalActivationFunction(
975       output_activation_min, output_activation_max, bias_shape, bias_data,
976       output_shape, output_data);
977 #else
978   // When an optimized CBLAS implementation is not available, fall back
979   // to using cpu_backend_gemm.
980   cpu_backend_gemm::MatrixParams<float> lhs_params;
981   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
982   lhs_params.rows = n;
983   lhs_params.cols = k;
984   cpu_backend_gemm::MatrixParams<float> rhs_params;
985   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
986   rhs_params.rows = k;
987   rhs_params.cols = m;
988   cpu_backend_gemm::MatrixParams<float> dst_params;
989   dst_params.order = cpu_backend_gemm::Order::kColMajor;
990   dst_params.rows = n;
991   dst_params.cols = m;
992   cpu_backend_gemm::GemmParams<float, float> gemm_params;
993   gemm_params.bias = bias_data;
994   gemm_params.clamp_min = output_activation_min;
995   gemm_params.clamp_max = output_activation_max;
996   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
997                          dst_params, output_data, gemm_params,
998                          cpu_backend_context);
999 #endif  //  defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1000 }
1001 
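// Editor's note: a hedged sketch of the GEMM dimensions used in Conv above
// (illustration only; assumes SAME padding so the spatial size is preserved).
// For a 3x3 convolution with stride 1 on an input of shape [1, 8, 8, 16]
// producing an output of shape [1, 8, 8, 32]:
//   m = batches * output_height * output_width = 1 * 8 * 8 = 64
//       (rows of the im2col matrix),
//   k = filter_height * filter_width * input_depth = 3 * 3 * 16 = 144
//       (columns of the im2col matrix, and of each flattened filter),
//   n = output_depth = 32,
// and the GEMM computes output[m, n] = im2col[m, k] * transpose(filter[n, k]).
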
1002 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
1003                        const RuntimeShape& input_shape,
1004                        const int8_t* input_data,
1005                        const RuntimeShape& filter_shape,
1006                        const int8_t* filter_data,
1007                        const RuntimeShape& bias_shape, const float* bias_data,
1008                        const RuntimeShape& accum_scratch_shape,
1009                        int32_t* accum_scratch, const RuntimeShape& output_shape,
1010                        float* output_data, const RuntimeShape& im2col_shape,
1011                        int8_t* im2col_data, CpuBackendContext* context) {
1012   const int stride_width = params.stride_width;
1013   const int stride_height = params.stride_height;
1014   const int dilation_width_factor = params.dilation_width_factor;
1015   const int dilation_height_factor = params.dilation_height_factor;
1016   const float output_activation_min = params.float_activation_min;
1017   const float output_activation_max = params.float_activation_max;
1018   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1019   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1020   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1021 
1022   const int batch_size = input_shape.Dims(0);
1023   const int filter_width = filter_shape.Dims(2);
1024   const int filter_height = filter_shape.Dims(1);
1025 
1026   const int input_zero_point = 0;
1027   const int8_t* gemm_input_data = nullptr;
1028   int num_input;
1029   const bool need_dilated_im2col =
1030       dilation_width_factor != 1 || dilation_height_factor != 1;
1031   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1032                            filter_width != 1 || filter_height != 1;
1033 
1034   if (need_dilated_im2col) {
1035     DilatedIm2col(params, input_zero_point, input_shape, input_data,
1036                   filter_shape, output_shape, im2col_data);
1037     gemm_input_data = im2col_data;
1038     num_input = im2col_shape.FlatSize();
1039   } else if (need_im2col) {
1040     TFLITE_DCHECK(im2col_data);
1041     // Symmetric quantization assumes a zero point of 0.
1042 
1043     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1044            input_data, im2col_shape, im2col_data);
1045     gemm_input_data = im2col_data;
1046     num_input = im2col_shape.FlatSize();
1047   } else {
1048     TFLITE_DCHECK(!im2col_data);
1049     gemm_input_data = input_data;
1050     num_input = input_shape.FlatSize();
1051   }
1052 
1053   // Flatten 4D matrices into 2D matrices for matrix multiplication.
1054 
1055   // Flatten so that each filter has its own row.
1056   const int filter_rows = filter_shape.Dims(0);
1057   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1058 
1059   // In MatrixBatchVectorMultiplyAccumulate, each output value is the
1060   // dot product of one row of the first matrix with one row of the second
1061   // matrix. Therefore, the number of columns in each matrix is the same.
1062   //
1063   // After Im2Col, each input patch becomes a row.
1064   const int gemm_input_cols = filter_cols;
1065   const int gemm_input_rows = num_input / gemm_input_cols;
1066 
1067   const int output_cols = output_shape.Dims(3);
1068   const int output_rows = FlatSizeSkipDim(output_shape, 3);
1069   TFLITE_DCHECK_EQ(output_cols, filter_rows);
1070   TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
1071   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
1072 
1073   // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
1074   // input matrix has its own scale factor. This code duplicates the scale
1075   // factors for each row in the same batch.
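  // For example, with batch_size = 2 and rows_per_batch = 3, the per-batch
  // scales [s0, s1] are expanded in place to [s0, s0, s0, s1, s1, s1], one
  // entry per im2col row; iterating backwards keeps the original per-batch
  // entries intact until they have been read.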
1076   const int rows_per_batch = gemm_input_rows / batch_size;
1077   for (int i = gemm_input_rows - 1; i >= 0; --i) {
1078     scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
1079   }
1080 
1081   std::fill_n(output_data, output_rows * output_cols, 0.0f);
1082 
1083   // The scratch buffer must have the same size as the output.
1084   TFLITE_DCHECK_EQ(accum_scratch_shape.FlatSize(), output_shape.FlatSize());
1085   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
1086       filter_data, filter_rows, filter_cols, gemm_input_data,
1087       scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
1088       output_data, context);
1089   AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
1090                                    bias_shape, bias_data, output_shape,
1091                                    output_data);
1092 }
1093 
1094 inline void HybridConvPerChannel(
1095     const ConvParams& params, float* scaling_factors_ptr,
1096     const RuntimeShape& input_shape, const int8_t* input_data,
1097     const RuntimeShape& filter_shape, const int8_t* filter_data,
1098     const RuntimeShape& bias_shape, const float* bias_data,
1099     const RuntimeShape& output_shape, float* output_data,
1100     const RuntimeShape& im2col_shape, int8_t* im2col_data,
1101     const float* per_channel_scale, int32_t* input_offset,
1102     const RuntimeShape& scratch_shape, int32_t* scratch, int32_t* row_sums,
1103     bool* compute_row_sums, CpuBackendContext* cpu_backend_context) {
1104   ruy::profiler::ScopeLabel label("ConvHybridPerChannel");
1105   const int stride_width = params.stride_width;
1106   const int stride_height = params.stride_height;
1107   const int dilation_width_factor = params.dilation_width_factor;
1108   const int dilation_height_factor = params.dilation_height_factor;
1109   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1110   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1111   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1112 
1113   const int8* gemm_input_data = nullptr;
1114   const RuntimeShape* gemm_input_shape = nullptr;
1115   const int filter_width = filter_shape.Dims(2);
1116   const int filter_height = filter_shape.Dims(1);
1117   const bool need_dilated_im2col =
1118       dilation_width_factor != 1 || dilation_height_factor != 1;
1119   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1120                            filter_width != 1 || filter_height != 1;
1121 
1122   const int batch_size = input_shape.Dims(0);
1123 
1124   if (need_dilated_im2col) {
1125     TFLITE_DCHECK(im2col_data);
1126     optimized_ops::DilatedIm2col(params, input_shape, input_data, filter_shape,
1127                                  output_shape, im2col_data, input_offset,
1128                                  batch_size);
1129     gemm_input_data = im2col_data;
1130     gemm_input_shape = &im2col_shape;
1131   } else if (need_im2col) {
1132     Im2col(params, filter_height, filter_width, input_offset, batch_size,
1133            input_shape, input_data, im2col_shape, im2col_data);
1134     gemm_input_data = im2col_data;
1135     gemm_input_shape = &im2col_shape;
1136   } else {
1137     TFLITE_DCHECK(!im2col_data);
1138     gemm_input_data = input_data;
1139     gemm_input_shape = &input_shape;
1140   }
1141 
1142   const int filter_rows = filter_shape.Dims(0);
1143   const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1144 
1145   const int gemm_input_rows = gemm_input_shape->Dims(3);
1146   const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1147   const int output_rows = output_shape.Dims(3);
1148   const int output_cols =
1149       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1150 
1151   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1152   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1153   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1154   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1155   TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
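  // row_sums caches the per-output-channel sums of the filter weights so the
  // per-batch input zero point can be subtracted from the int32 accumulators
  // after the GEMM. It is computed once and reused while *compute_row_sums is
  // false.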
1156   if (!compute_row_sums || *compute_row_sums) {
1157     tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
1158                                      filter_cols);
1159     if (compute_row_sums) {
1160       *compute_row_sums = false;
1161     }
1162   }
1163 
1164   cpu_backend_gemm::MatrixParams<int8> lhs_params;
1165   lhs_params.rows = filter_rows;
1166   lhs_params.cols = filter_cols;
1167   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1168 
1169   cpu_backend_gemm::MatrixParams<int8> rhs_params;
1170   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1171   rhs_params.rows = gemm_input_rows;
1172   rhs_params.cols = gemm_input_cols;
1173 
1174   cpu_backend_gemm::MatrixParams<int32> dst_params;
1175   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1176   dst_params.rows = output_rows;
1177   dst_params.cols = output_cols;
1178 
1179   // TODO(b/149003801): Use hybrid gemm once supported in Ruy.
1180   cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
1181   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1182                          dst_params, scratch, gemm_params, cpu_backend_context);
1183 
1184   MatrixMap<float> out_mat(output_data, filter_rows, output_cols);
1185   MatrixMap<int32_t> in_mat(scratch, filter_rows, output_cols);
1186   VectorMap<const float> bias_data_vec(bias_data, filter_rows, 1);
1187   VectorMap<int32_t> row_sums_vec(row_sums, filter_rows, 1);
1188   VectorMap<const float> per_channel_scale_vec(per_channel_scale, filter_rows,
1189                                                1);
1190   const int cols_per_batch = output_cols / batch_size;
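  // Dequantize the accumulators column by column:
  //   out(r, c) = (acc(r, c) - row_sums(r) * input_zero_point(b))
  //               * per_channel_scale(r) * input_scale(b) + bias(r),
  // clamped to the fused activation range, where b is the batch that column c
  // belongs to.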
1191   for (int c = 0; c < output_cols; c++) {
1192     const int b = c / cols_per_batch;
1193     const float input_scale = scaling_factors_ptr[b];
1194     const int32_t zero_point = input_offset[b];
1195     out_mat.col(c) =
1196         (((in_mat.col(c) - (row_sums_vec * zero_point))
1197               .cast<float>()
1198               .cwiseProduct((per_channel_scale_vec * input_scale))) +
1199          bias_data_vec)
1200             .cwiseMin(params.float_activation_max)
1201             .cwiseMax(params.float_activation_min);
1202   }
1203 }
1204 
1205 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1206                  const uint8* input_data, const RuntimeShape& filter_shape,
1207                  const uint8* filter_data, const RuntimeShape& bias_shape,
1208                  const int32* bias_data, const RuntimeShape& output_shape,
1209                  uint8* output_data, const RuntimeShape& im2col_shape,
1210                  uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
1211   ruy::profiler::ScopeLabel label("Conv/8bit");
1212 
1213   const int stride_width = params.stride_width;
1214   const int stride_height = params.stride_height;
1215   const int dilation_width_factor = params.dilation_width_factor;
1216   const int dilation_height_factor = params.dilation_height_factor;
1217   const int32 input_offset = params.input_offset;
1218   const int32 filter_offset = params.weights_offset;
1219   const int32 output_offset = params.output_offset;
1220   const int32 output_multiplier = params.output_multiplier;
1221   const int output_shift = params.output_shift;
1222   const int32 output_activation_min = params.quantized_activation_min;
1223   const int32 output_activation_max = params.quantized_activation_max;
1224   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1225   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1226   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1227 
1228   const uint8* gemm_input_data = nullptr;
1229   const RuntimeShape* gemm_input_shape = nullptr;
1230   const int filter_width = filter_shape.Dims(2);
1231   const int filter_height = filter_shape.Dims(1);
1232   const bool need_dilated_im2col =
1233       dilation_width_factor != 1 || dilation_height_factor != 1;
1234   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1235                            filter_width != 1 || filter_height != 1;
1236   if (need_dilated_im2col) {
1237     TFLITE_DCHECK(im2col_data);
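    // params.input_offset holds the negated uint8 zero point; negating it back
    // must yield a zero point in the valid uint8 range [0, 255].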
1238     const int input_zero_point = -input_offset;
1239     TFLITE_DCHECK_GE(input_zero_point, 0);
1240     TFLITE_DCHECK_LE(input_zero_point, 255);
1241     DilatedIm2col(params, input_zero_point, input_shape, input_data,
1242                   filter_shape, output_shape, im2col_data);
1243     gemm_input_data = im2col_data;
1244     gemm_input_shape = &im2col_shape;
1245   } else if (need_im2col) {
1246     TFLITE_DCHECK(im2col_data);
1247     const int input_zero_point = -input_offset;
1248     TFLITE_DCHECK_GE(input_zero_point, 0);
1249     TFLITE_DCHECK_LE(input_zero_point, 255);
1250     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1251            input_data, im2col_shape, im2col_data);
1252     gemm_input_data = im2col_data;
1253     gemm_input_shape = &im2col_shape;
1254   } else {
1255     TFLITE_DCHECK(!im2col_data);
1256     gemm_input_data = input_data;
1257     gemm_input_shape = &input_shape;
1258   }
1259 
1260   const int gemm_input_rows = gemm_input_shape->Dims(3);
1261   // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
1262   // The root cause has not yet been identified though. Same applies below for
1263   // the other calls commented out. This is a partial rollback of cl/196819423.
1264   // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1265   const int gemm_input_cols = gemm_input_shape->Dims(0) *
1266                               gemm_input_shape->Dims(1) *
1267                               gemm_input_shape->Dims(2);
1268   const int filter_rows = filter_shape.Dims(0);
1269   // See b/79927784.
1270   // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1271   const int filter_cols =
1272       filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
1273   const int output_rows = output_shape.Dims(3);
1274   // See b/79927784.
1275   // const int output_cols = FlatSizeSkipDim(output_shape, 3);
1276   const int output_cols =
1277       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1278   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1279   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1280   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1281   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1282 
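  // GEMM mapping: the LHS is the filter matrix (one output channel per row,
  // one patch element per column), the RHS holds one im2col patch per column,
  // and the column-major destination writes the result in NHWC order with the
  // channel dimension contiguous.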
1283   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
1284   lhs_params.rows = filter_rows;
1285   lhs_params.cols = filter_cols;
1286   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1287   lhs_params.zero_point = -filter_offset;
1288   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
1289   rhs_params.rows = gemm_input_rows;
1290   rhs_params.cols = gemm_input_cols;
1291   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1292   rhs_params.zero_point = -input_offset;
1293   cpu_backend_gemm::MatrixParams<uint8> dst_params;
1294   dst_params.rows = output_rows;
1295   dst_params.cols = output_cols;
1296   dst_params.order = cpu_backend_gemm::Order::kColMajor;
1297   dst_params.zero_point = output_offset;
1298   cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
1299   gemm_params.bias = bias_data;
1300   gemm_params.clamp_min = output_activation_min;
1301   gemm_params.clamp_max = output_activation_max;
1302   gemm_params.multiplier_fixedpoint = output_multiplier;
1303   gemm_params.multiplier_exponent = output_shift;
1304   cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1305                          dst_params, output_data, gemm_params,
1306                          cpu_backend_context);
1307 }
1308 
1309 template <typename T>
1310 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
1311                          const RuntimeShape& unextended_input_shape,
1312                          const T* input_data,
1313                          const RuntimeShape& unextended_output_shape,
1314                          T* output_data) {
1315   ruy::profiler::ScopeLabel label("DepthToSpace");
1316 
1317   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1318   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1319   const RuntimeShape input_shape =
1320       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1321   const RuntimeShape output_shape =
1322       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1323 
1324   const int input_depth = input_shape.Dims(3);
1325   const int input_width = input_shape.Dims(2);
1326   const int input_height = input_shape.Dims(1);
1327 
1328   const int output_depth = output_shape.Dims(3);
1329   const int batch_size = output_shape.Dims(0);
1330 
1331   // Number of contiguous values that we can copy in one iteration.
1332   const int stride = op_params.block_size * output_depth;
1333 
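  // For each input pixel, the depth dimension holds block_size x block_size
  // spatial sub-blocks of output_depth channels each; every memcpy below moves
  // one sub-block row of block_size * output_depth values into the output.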
1334   for (int batch = 0; batch < batch_size; ++batch) {
1335     for (int in_h = 0; in_h < input_height; ++in_h) {
1336       const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
1337       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1338         const T* src = input_ptr;
1339         for (int in_w = 0; in_w < input_width; ++in_w) {
1340           memcpy(output_data, src, stride * sizeof(T));
1341           output_data += stride;
1342           src += input_depth;
1343         }
1344         input_ptr += stride;
1345       }
1346     }
1347   }
1348 }
1349 
1350 template <typename T>
1351 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
1352                          const RuntimeShape& unextended_input_shape,
1353                          const T* input_data,
1354                          const RuntimeShape& unextended_output_shape,
1355                          T* output_data) {
1356   ruy::profiler::ScopeLabel label("SpaceToDepth");
1357 
1358   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1359   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1360   const RuntimeShape input_shape =
1361       RuntimeShape::ExtendedShape(4, unextended_input_shape);
1362   const RuntimeShape output_shape =
1363       RuntimeShape::ExtendedShape(4, unextended_output_shape);
1364 
1365   const int output_depth = output_shape.Dims(3);
1366   const int output_width = output_shape.Dims(2);
1367   const int output_height = output_shape.Dims(1);
1368 
1369   const int input_depth = input_shape.Dims(3);
1370   const int batch_size = input_shape.Dims(0);
1371 
1372   // Number of contiguous values that we can copy in one iteration.
1373   const int stride = op_params.block_size * input_depth;
1374 
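  // Mirror of DepthToSpace above: the input is read as contiguous runs of
  // block_size * input_depth values, and each run is written output_depth
  // elements apart into the deeper output tensor.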
1375   for (int batch = 0; batch < batch_size; ++batch) {
1376     for (int out_h = 0; out_h < output_height; ++out_h) {
1377       T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
1378       for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1379         T* dst = output_ptr;
1380         for (int out_w = 0; out_w < output_width; ++out_w) {
1381           memcpy(dst, input_data, stride * sizeof(T));
1382           input_data += stride;
1383           dst += output_depth;
1384         }
1385         output_ptr += stride;
1386       }
1387     }
1388   }
1389 }
1390 
1391 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
1392                  const RuntimeShape& output_shape, float* output_data) {
1393   ruy::profiler::ScopeLabel label("Relu (not fused)");
1394 
1395   const auto input = MapAsVector(input_data, input_shape);
1396   auto output = MapAsVector(output_data, output_shape);
1397   output = input.cwiseMax(0.0f);
1398 }
1399 
1400 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1401                             const RuntimeShape& input_shape,
1402                             const float* input_data,
1403                             const RuntimeShape& output_shape,
1404                             float* output_data, float epsilon = 1e-6) {
1405   ruy::profiler::ScopeLabel label("L2Normalization");
1406   const int trailing_dim = input_shape.DimensionsCount() - 1;
1407   const int outer_size =
1408       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1409   const int depth =
1410       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
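  // Each inner slice of `depth` values is scaled by 1 / max(||x||_2, epsilon);
  // the epsilon floor keeps an all-zero slice from dividing by zero.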
1411   for (int i = 0; i < outer_size; ++i) {
1412     float squared_l2_norm = 0;
1413     for (int c = 0; c < depth; ++c) {
1414       const float val = input_data[c];
1415       squared_l2_norm += val * val;
1416     }
1417     float l2_norm = std::sqrt(squared_l2_norm);
1418     l2_norm = std::max(l2_norm, epsilon);
1419     for (int c = 0; c < depth; ++c) {
1420       *output_data = *input_data / l2_norm;
1421       ++output_data;
1422       ++input_data;
1423     }
1424   }
1425 }
1426 
1427 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1428                             const RuntimeShape& input_shape,
1429                             const uint8* input_data,
1430                             const RuntimeShape& output_shape,
1431                             uint8* output_data) {
1432   ruy::profiler::ScopeLabel label("L2Normalization/8bit");
1433   const int trailing_dim = input_shape.DimensionsCount() - 1;
1434   const int depth =
1435       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1436   const int outer_size =
1437       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1438   const int32 input_zero_point = op_params.input_zero_point;
1439   for (int i = 0; i < outer_size; ++i) {
1440     int32 square_l2_norm = 0;
1441     for (int c = 0; c < depth; c++) {
1442       // Note that input_data advances by depth in the second pass below.
1443       int32 diff = input_data[c] - input_zero_point;
1444       square_l2_norm += diff * diff;
1445     }
1446     // TODO(b/29395854): add clamping to TOCO and TF Lite kernel
1447     // for all-zero tensors in the input_data
1448     int32 inv_l2norm_multiplier;
1449     int inv_l2norm_shift;
1450     GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
1451                                      &inv_l2norm_multiplier, &inv_l2norm_shift);
1452 
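    // The uint8 output uses a fixed scale of 1/128 and a zero point of 128,
    // hence the 128 * diff rescaling and the +128 re-centering below.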
1453     for (int c = 0; c < depth; c++) {
1454       int32 diff = *input_data - input_zero_point;
1455       int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
1456           128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
1457       int32 unclamped_output_val = 128 + rescaled_diff;
1458       int32 output_val = std::min(255, std::max(0, unclamped_output_val));
1459       *output_data = static_cast<uint8>(output_val);
1460       ++input_data;
1461       ++output_data;
1462     }
1463   }
1464 }
1465 
1466 inline void AddElementwise(int size, const ArithmeticParams& params,
1467                            const float* input1_data, const float* input2_data,
1468                            float* output_data) {
1469   int i = 0;
1470 
1471 #ifdef USE_NEON
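  // Vectorized fast path: add 16 floats per iteration (four NEON registers),
  // then 4 at a time, and leave any remaining tail to the scalar loop at the
  // bottom.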
1472   const auto activation_min = vdupq_n_f32(params.float_activation_min);
1473   const auto activation_max = vdupq_n_f32(params.float_activation_max);
1474   for (; i <= size - 16; i += 16) {
1475     auto a10 = vld1q_f32(input1_data + i);
1476     auto a11 = vld1q_f32(input1_data + i + 4);
1477     auto a12 = vld1q_f32(input1_data + i + 8);
1478     auto a13 = vld1q_f32(input1_data + i + 12);
1479     auto a20 = vld1q_f32(input2_data + i);
1480     auto a21 = vld1q_f32(input2_data + i + 4);
1481     auto a22 = vld1q_f32(input2_data + i + 8);
1482     auto a23 = vld1q_f32(input2_data + i + 12);
1483     auto x0 = vaddq_f32(a10, a20);
1484     auto x1 = vaddq_f32(a11, a21);
1485     auto x2 = vaddq_f32(a12, a22);
1486     auto x3 = vaddq_f32(a13, a23);
1487     x0 = vmaxq_f32(activation_min, x0);
1488     x1 = vmaxq_f32(activation_min, x1);
1489     x2 = vmaxq_f32(activation_min, x2);
1490     x3 = vmaxq_f32(activation_min, x3);
1491     x0 = vminq_f32(activation_max, x0);
1492     x1 = vminq_f32(activation_max, x1);
1493     x2 = vminq_f32(activation_max, x2);
1494     x3 = vminq_f32(activation_max, x3);
1495     vst1q_f32(output_data + i, x0);
1496     vst1q_f32(output_data + i + 4, x1);
1497     vst1q_f32(output_data + i + 8, x2);
1498     vst1q_f32(output_data + i + 12, x3);
1499   }
1500   for (; i <= size - 4; i += 4) {
1501     auto a1 = vld1q_f32(input1_data + i);
1502     auto a2 = vld1q_f32(input2_data + i);
1503     auto x = vaddq_f32(a1, a2);
1504     x = vmaxq_f32(activation_min, x);
1505     x = vminq_f32(activation_max, x);
1506     vst1q_f32(output_data + i, x);
1507   }
1508 #endif  // NEON
1509 
1510   for (; i < size; i++) {
1511     auto x = input1_data[i] + input2_data[i];
1512     output_data[i] = ActivationFunctionWithMinMax(
1513         x, params.float_activation_min, params.float_activation_max);
1514   }
1515 }
1516 
1517 inline void Add(const ArithmeticParams& params,
1518                 const RuntimeShape& input1_shape, const float* input1_data,
1519                 const RuntimeShape& input2_shape, const float* input2_data,
1520                 const RuntimeShape& output_shape, float* output_data) {
1521   ruy::profiler::ScopeLabel label("Add");
1522   const int flat_size =
1523       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1524   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1525 }
1526 
1527 // Element-wise add that can often be used for inner loop of broadcast add as
1528 // well as the non-broadcast add.
1529 inline void AddElementwise(int size, const ArithmeticParams& params,
1530                            const uint8* input1_data, const uint8* input2_data,
1531                            uint8* output_data) {
1532   ruy::profiler::ScopeLabel label("AddElementwise/8bit");
1533   int i = 0;
1534   TFLITE_DCHECK_GT(params.input1_offset, -256);
1535   TFLITE_DCHECK_GT(params.input2_offset, -256);
1536   TFLITE_DCHECK_LT(params.input1_offset, 256);
1537   TFLITE_DCHECK_LT(params.input2_offset, 256);
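  // Both inputs are first rescaled to a common intermediate scale (shift left
  // by params.left_shift, then per-input multiplier and shift), added in
  // int32, rescaled to the output scale, re-offset and clamped. The NEON path
  // below mirrors the scalar tail loop.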
1538 #ifdef USE_NEON
1539   const uint8x8_t output_activation_min_vector =
1540       vdup_n_u8(params.quantized_activation_min);
1541   const uint8x8_t output_activation_max_vector =
1542       vdup_n_u8(params.quantized_activation_max);
1543   for (; i <= size - 8; i += 8) {
1544     const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
1545     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1546     const int16x8_t input1_val_s16 =
1547         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1548     const int16x8_t input2_val_s16 =
1549         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1550     const int16x8_t input1_val =
1551         vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1552     const int16x8_t input2_val =
1553         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1554     const int16x4_t input1_val_high = vget_high_s16(input1_val);
1555     const int16x4_t input1_val_low = vget_low_s16(input1_val);
1556     const int16x4_t input2_val_high = vget_high_s16(input2_val);
1557     const int16x4_t input2_val_low = vget_low_s16(input2_val);
1558     int32x4_t x11 = vmovl_s16(input1_val_low);
1559     int32x4_t x12 = vmovl_s16(input1_val_high);
1560     int32x4_t x21 = vmovl_s16(input2_val_low);
1561     int32x4_t x22 = vmovl_s16(input2_val_high);
1562     const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1563     x11 = vshlq_s32(x11, left_shift_dup);
1564     x12 = vshlq_s32(x12, left_shift_dup);
1565     x21 = vshlq_s32(x21, left_shift_dup);
1566     x22 = vshlq_s32(x22, left_shift_dup);
1567     x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1568     x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1569     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1570     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1571     const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1572     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1573     x11 = vshlq_s32(x11, input1_shift_dup);
1574     x12 = vshlq_s32(x12, input1_shift_dup);
1575     x21 = vshlq_s32(x21, input2_shift_dup);
1576     x22 = vshlq_s32(x22, input2_shift_dup);
1577     int32x4_t s1 = vaddq_s32(x11, x21);
1578     int32x4_t s2 = vaddq_s32(x12, x22);
1579     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1580     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1581     using gemmlowp::RoundingDivideByPOT;
1582     s1 = RoundingDivideByPOT(s1, -params.output_shift);
1583     s2 = RoundingDivideByPOT(s2, -params.output_shift);
1584     const int16x4_t s1_narrowed = vmovn_s32(s1);
1585     const int16x4_t s2_narrowed = vmovn_s32(s2);
1586     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1587                                   vdupq_n_s16(params.output_offset));
1588     const uint8x8_t clamped =
1589         vmax_u8(output_activation_min_vector,
1590                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1591     vst1_u8(output_data + i, clamped);
1592   }
1593 #endif  // NEON
1594 
1595   for (; i < size; ++i) {
1596     const int32 input1_val = params.input1_offset + input1_data[i];
1597     const int32 input2_val = params.input2_offset + input2_data[i];
1598     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1599     const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1600     const int32 scaled_input1_val =
1601         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1602             shifted_input1_val, params.input1_multiplier, params.input1_shift);
1603     const int32 scaled_input2_val =
1604         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1605             shifted_input2_val, params.input2_multiplier, params.input2_shift);
1606     const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1607     const int32 raw_output =
1608         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1609             raw_sum, params.output_multiplier, params.output_shift) +
1610         params.output_offset;
1611     const int32 clamped_output =
1612         std::min(params.quantized_activation_max,
1613                  std::max(params.quantized_activation_min, raw_output));
1614     output_data[i] = static_cast<uint8>(clamped_output);
1615   }
1616 }
1617 
1618 // Scalar-broadcast add that can be used for inner loop of more general
1619 // broadcast add, so that, for example, scalar-broadcast with batch will still
1620 // be fast.
1621 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1622                                uint8 input1_data, const uint8* input2_data,
1623                                uint8* output_data) {
1624   using gemmlowp::RoundingDivideByPOT;
1625 
1626   ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit");
1627   TFLITE_DCHECK_GT(params.input1_offset, -256);
1628   TFLITE_DCHECK_GT(params.input2_offset, -256);
1629   TFLITE_DCHECK_LT(params.input1_offset, 256);
1630   TFLITE_DCHECK_LT(params.input2_offset, 256);
1631 
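  // The broadcast scalar operand is offset and rescaled once, up front, so the
  // loops below only have to process the vector operand per element.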
1632   int i = 0;
1633 
1634 #ifdef USE_NEON
1635   const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1636   const uint8x8_t output_activation_min_vector =
1637       vdup_n_u8(params.quantized_activation_min);
1638   const uint8x8_t output_activation_max_vector =
1639       vdup_n_u8(params.quantized_activation_max);
1640 
1641   // Process broadcast scalar.
1642   const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
1643   const int16x8_t input1_val_s16 =
1644       vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1645   const int16x8_t input1_val =
1646       vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1647   const int16x4_t input1_val_high = vget_high_s16(input1_val);
1648   const int16x4_t input1_val_low = vget_low_s16(input1_val);
1649   int32x4_t x11 = vmovl_s16(input1_val_low);
1650   int32x4_t x12 = vmovl_s16(input1_val_high);
1651   x11 = vshlq_s32(x11, left_shift_dup);
1652   x12 = vshlq_s32(x12, left_shift_dup);
1653   x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1654   x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1655   const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1656   x11 = vshlq_s32(x11, input1_shift_dup);
1657   x12 = vshlq_s32(x12, input1_shift_dup);
1658 
1659   for (; i <= size - 8; i += 8) {
1660     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1661     const int16x8_t input2_val_s16 =
1662         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1663     const int16x8_t input2_val =
1664         vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1665     const int16x4_t input2_val_high = vget_high_s16(input2_val);
1666     const int16x4_t input2_val_low = vget_low_s16(input2_val);
1667     int32x4_t x21 = vmovl_s16(input2_val_low);
1668     int32x4_t x22 = vmovl_s16(input2_val_high);
1669     x21 = vshlq_s32(x21, left_shift_dup);
1670     x22 = vshlq_s32(x22, left_shift_dup);
1671     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1672     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1673     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1674     x21 = vshlq_s32(x21, input2_shift_dup);
1675     x22 = vshlq_s32(x22, input2_shift_dup);
1676     int32x4_t s1 = vaddq_s32(x11, x21);
1677     int32x4_t s2 = vaddq_s32(x12, x22);
1678     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1679     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1680     s1 = RoundingDivideByPOT(s1, -params.output_shift);
1681     s2 = RoundingDivideByPOT(s2, -params.output_shift);
1682     const int16x4_t s1_narrowed = vmovn_s32(s1);
1683     const int16x4_t s2_narrowed = vmovn_s32(s2);
1684     const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1685                                   vdupq_n_s16(params.output_offset));
1686     const uint8x8_t clamped =
1687         vmax_u8(output_activation_min_vector,
1688                 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1689     vst1_u8(output_data + i, clamped);
1690   }
1691 #endif  // NEON
1692 
1693   if (i < size) {
1694     // Process broadcast scalar.
1695     const int32 input1_val = params.input1_offset + input1_data;
1696     const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1697     const int32 scaled_input1_val =
1698         MultiplyByQuantizedMultiplierSmallerThanOneExp(
1699             shifted_input1_val, params.input1_multiplier, params.input1_shift);
1700 
1701     for (; i < size; ++i) {
1702       const int32 input2_val = params.input2_offset + input2_data[i];
1703       const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1704       const int32 scaled_input2_val =
1705           MultiplyByQuantizedMultiplierSmallerThanOneExp(
1706               shifted_input2_val, params.input2_multiplier,
1707               params.input2_shift);
1708       const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1709       const int32 raw_output =
1710           MultiplyByQuantizedMultiplierSmallerThanOneExp(
1711               raw_sum, params.output_multiplier, params.output_shift) +
1712           params.output_offset;
1713       const int32 clamped_output =
1714           std::min(params.quantized_activation_max,
1715                    std::max(params.quantized_activation_min, raw_output));
1716       output_data[i] = static_cast<uint8>(clamped_output);
1717     }
1718   }
1719 }
1720 
1721 // Scalar-broadcast add that can be used for inner loop of more general
1722 // broadcast add, so that, for example, scalar-broadcast with batch will still
1723 // be fast.
1724 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1725                                float broadcast_value, const float* input2_data,
1726                                float* output_data) {
1727   int i = 0;
1728 #ifdef USE_NEON
1729   const float32x4_t output_activation_min_vector =
1730       vdupq_n_f32(params.float_activation_min);
1731   const float32x4_t output_activation_max_vector =
1732       vdupq_n_f32(params.float_activation_max);
1733   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
1734   for (; i <= size - 4; i += 4) {
1735     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
1736 
1737     const float32x4_t output =
1738         vaddq_f32(input2_val_original, broadcast_value_dup);
1739 
1740     const float32x4_t clamped =
1741         vmaxq_f32(output_activation_min_vector,
1742                   vminq_f32(output_activation_max_vector, output));
1743     vst1q_f32(output_data + i, clamped);
1744   }
1745 #endif  // NEON
1746 
1747   for (; i < size; ++i) {
1748     auto x = broadcast_value + input2_data[i];
1749     output_data[i] = ActivationFunctionWithMinMax(
1750         x, params.float_activation_min, params.float_activation_max);
1751   }
1752 }
1753 
1754 inline void Add(const ArithmeticParams& params,
1755                 const RuntimeShape& input1_shape, const uint8* input1_data,
1756                 const RuntimeShape& input2_shape, const uint8* input2_data,
1757                 const RuntimeShape& output_shape, uint8* output_data) {
1758   TFLITE_DCHECK_LE(params.quantized_activation_min,
1759                    params.quantized_activation_max);
1760   ruy::profiler::ScopeLabel label("Add/8bit");
1761   const int flat_size =
1762       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1763 
1764   TFLITE_DCHECK_GT(params.input1_offset, -256);
1765   TFLITE_DCHECK_GT(params.input2_offset, -256);
1766   TFLITE_DCHECK_LT(params.input1_offset, 256);
1767   TFLITE_DCHECK_LT(params.input2_offset, 256);
1768   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1769 }
1770 
1771 inline void Add(const ArithmeticParams& params,
1772                 const RuntimeShape& input1_shape, const int16* input1_data,
1773                 const RuntimeShape& input2_shape, const int16* input2_data,
1774                 const RuntimeShape& output_shape, int16* output_data) {
1775   ruy::profiler::ScopeLabel label("Add/Int16");
1776   TFLITE_DCHECK_LE(params.quantized_activation_min,
1777                    params.quantized_activation_max);
1778 
1779   const int input1_shift = params.input1_shift;
1780   const int flat_size =
1781       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1782   const int16 output_activation_min = params.quantized_activation_min;
1783   const int16 output_activation_max = params.quantized_activation_max;
1784 
1785   TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
1786   TFLITE_DCHECK_LE(input1_shift, 0);
1787   TFLITE_DCHECK_LE(params.input2_shift, 0);
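  // At most one input carries a non-zero (non-positive) shift; that input is
  // right-shifted onto the other input's scale before the saturating Q0.15
  // fixed-point add.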
1788   const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
1789   const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
1790   const int input_right_shift =
1791       input1_shift == 0 ? -params.input2_shift : -input1_shift;
1792 
1793   for (int i = 0; i < flat_size; i++) {
1794     // F0 uses 0 integer bits, range [-1, 1].
1795     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
1796 
1797     F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
1798     F0 scaled_input = F0::FromRaw(
1799         gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
1800     F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
1801     const int16 raw_output = result.raw();
1802     const int16 clamped_output = std::min(
1803         output_activation_max, std::max(output_activation_min, raw_output));
1804     output_data[i] = clamped_output;
1805   }
1806 }
1807 
1808 template <typename T>
1809 inline typename std::enable_if<is_int32_or_int64<T>::value, void>::type Add(
1810     const ArithmeticParams& params, const RuntimeShape& input1_shape,
1811     const T* input1_data, const RuntimeShape& input2_shape,
1812     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
1813   ruy::profiler::ScopeLabel label("Add/int32or64");
1814 
1815   T activation_min, activation_max;
1816   GetActivationParams(params, &activation_min, &activation_max);
1817 
1818   auto input1_map = MapAsVector(input1_data, input1_shape);
1819   auto input2_map = MapAsVector(input2_data, input2_shape);
1820   auto output_map = MapAsVector(output_data, output_shape);
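  // Fast paths: identical shapes reduce to one Eigen elementwise expression
  // and a single-element operand becomes a scalar broadcast; every other
  // broadcast pattern falls back to the generic 4D slow path.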
1821   if (input1_shape == input2_shape) {
1822     output_map.array() = (input1_map.array() + input2_map.array())
1823                              .cwiseMax(activation_min)
1824                              .cwiseMin(activation_max);
1825   } else if (input2_shape.FlatSize() == 1) {
1826     auto scalar = input2_data[0];
1827     output_map.array() = (input1_map.array() + scalar)
1828                              .cwiseMax(activation_min)
1829                              .cwiseMin(activation_max);
1830   } else if (input1_shape.FlatSize() == 1) {
1831     auto scalar = input1_data[0];
1832     output_map.array() = (scalar + input2_map.array())
1833                              .cwiseMax(activation_min)
1834                              .cwiseMin(activation_max);
1835   } else {
1836     reference_ops::BroadcastAdd4DSlow<T>(params, input1_shape, input1_data,
1837                                          input2_shape, input2_data,
1838                                          output_shape, output_data);
1839   }
1840 }
1841 
1842 template <typename T>
1843 inline void BroadcastAddDispatch(
1844     const ArithmeticParams& params, const RuntimeShape& input1_shape,
1845     const T* input1_data, const RuntimeShape& input2_shape,
1846     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
1847   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
1848     return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
1849                               input2_data, output_shape, output_data);
1850   }
1851 
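  // Otherwise the shapes have already been collapsed into the fivefold
  // broadcast pattern, and the elementwise / scalar-broadcast kernels above
  // serve as the inner loops of BinaryBroadcastFiveFold.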
1852   BinaryBroadcastFiveFold(
1853       params, input1_shape, input1_data, input2_shape, input2_data,
1854       output_shape, output_data,
1855       static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
1856                            T*)>(AddElementwise),
1857       static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
1858           AddScalarBroadcast));
1859 }
1860 
1861 inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
1862                                  const RuntimeShape& unswitched_input1_shape,
1863                                  const uint8* unswitched_input1_data,
1864                                  const RuntimeShape& unswitched_input2_shape,
1865                                  const uint8* unswitched_input2_data,
1866                                  const RuntimeShape& output_shape,
1867                                  uint8* output_data) {
1868   BroadcastAddDispatch(unswitched_params, unswitched_input1_shape,
1869                        unswitched_input1_data, unswitched_input2_shape,
1870                        unswitched_input2_data, output_shape, output_data);
1871 }
1872 
1873 inline void BroadcastAddFivefold(const ArithmeticParams& params,
1874                                  const RuntimeShape& unswitched_input1_shape,
1875                                  const float* unswitched_input1_data,
1876                                  const RuntimeShape& unswitched_input2_shape,
1877                                  const float* unswitched_input2_data,
1878                                  const RuntimeShape& output_shape,
1879                                  float* output_data) {
1880   BroadcastAddDispatch(params, unswitched_input1_shape, unswitched_input1_data,
1881                        unswitched_input2_shape, unswitched_input2_data,
1882                        output_shape, output_data);
1883 }
1884 
1885 inline void MulElementwise(int size, const ArithmeticParams& params,
1886                            const float* input1_data, const float* input2_data,
1887                            float* output_data) {
1888   const float output_activation_min = params.float_activation_min;
1889   const float output_activation_max = params.float_activation_max;
1890 
1891   int i = 0;
1892 #ifdef USE_NEON
1893   const auto activation_min = vdupq_n_f32(output_activation_min);
1894   const auto activation_max = vdupq_n_f32(output_activation_max);
1895   for (; i <= size - 16; i += 16) {
1896     auto a10 = vld1q_f32(input1_data + i);
1897     auto a11 = vld1q_f32(input1_data + i + 4);
1898     auto a12 = vld1q_f32(input1_data + i + 8);
1899     auto a13 = vld1q_f32(input1_data + i + 12);
1900     auto a20 = vld1q_f32(input2_data + i);
1901     auto a21 = vld1q_f32(input2_data + i + 4);
1902     auto a22 = vld1q_f32(input2_data + i + 8);
1903     auto a23 = vld1q_f32(input2_data + i + 12);
1904     auto x0 = vmulq_f32(a10, a20);
1905     auto x1 = vmulq_f32(a11, a21);
1906     auto x2 = vmulq_f32(a12, a22);
1907     auto x3 = vmulq_f32(a13, a23);
1908 
1909     x0 = vmaxq_f32(activation_min, x0);
1910     x1 = vmaxq_f32(activation_min, x1);
1911     x2 = vmaxq_f32(activation_min, x2);
1912     x3 = vmaxq_f32(activation_min, x3);
1913     x0 = vminq_f32(activation_max, x0);
1914     x1 = vminq_f32(activation_max, x1);
1915     x2 = vminq_f32(activation_max, x2);
1916     x3 = vminq_f32(activation_max, x3);
1917 
1918     vst1q_f32(output_data + i, x0);
1919     vst1q_f32(output_data + i + 4, x1);
1920     vst1q_f32(output_data + i + 8, x2);
1921     vst1q_f32(output_data + i + 12, x3);
1922   }
1923   for (; i <= size - 4; i += 4) {
1924     auto a1 = vld1q_f32(input1_data + i);
1925     auto a2 = vld1q_f32(input2_data + i);
1926     auto x = vmulq_f32(a1, a2);
1927 
1928     x = vmaxq_f32(activation_min, x);
1929     x = vminq_f32(activation_max, x);
1930 
1931     vst1q_f32(output_data + i, x);
1932   }
1933 #endif  // NEON
1934 
1935   for (; i < size; i++) {
1936     auto x = input1_data[i] * input2_data[i];
1937     output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
1938                                                   output_activation_max);
1939   }
1940 }
1941 
1942 inline void Mul(const ArithmeticParams& params,
1943                 const RuntimeShape& input1_shape, const float* input1_data,
1944                 const RuntimeShape& input2_shape, const float* input2_data,
1945                 const RuntimeShape& output_shape, float* output_data) {
1946   ruy::profiler::ScopeLabel label("Mul");
1947 
1948   const int flat_size =
1949       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1950   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
1951 }
1952 
1953 inline void Mul(const ArithmeticParams& params,
1954                 const RuntimeShape& input1_shape, const int32* input1_data,
1955                 const RuntimeShape& input2_shape, const int32* input2_data,
1956                 const RuntimeShape& output_shape, int32* output_data) {
1957   ruy::profiler::ScopeLabel label("Mul/int32/activation");
1958 
1959   const int flat_size =
1960       MatchingElementsSize(input1_shape, input2_shape, output_shape);
1961   const int32 output_activation_min = params.quantized_activation_min;
1962   const int32 output_activation_max = params.quantized_activation_max;
1963   for (int i = 0; i < flat_size; ++i) {
1964     output_data[i] = ActivationFunctionWithMinMax(
1965         input1_data[i] * input2_data[i], output_activation_min,
1966         output_activation_max);
1967   }
1968 }
1969 
1970 inline void MulNoActivation(const ArithmeticParams& params,
1971                             const RuntimeShape& input1_shape,
1972                             const int32* input1_data,
1973                             const RuntimeShape& input2_shape,
1974                             const int32* input2_data,
1975                             const RuntimeShape& output_shape,
1976                             int32* output_data) {
1977   ruy::profiler::ScopeLabel label("Mul/int32");
1978 
1979   auto input1_map = MapAsVector(input1_data, input1_shape);
1980   auto input2_map = MapAsVector(input2_data, input2_shape);
1981   auto output_map = MapAsVector(output_data, output_shape);
1982   if (input1_shape == input2_shape) {
1983     output_map.array() = input1_map.array() * input2_map.array();
1984   } else if (input2_shape.FlatSize() == 1) {
1985     auto scalar = input2_data[0];
1986     output_map.array() = input1_map.array() * scalar;
1987   } else if (input1_shape.FlatSize() == 1) {
1988     auto scalar = input1_data[0];
1989     output_map.array() = scalar * input2_map.array();
1990   } else {
1991     reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
1992                                       input2_shape, input2_data, output_shape,
1993                                       output_data);
1994   }
1995 }
1996 
1997 inline void Mul(const ArithmeticParams& params,
1998                 const RuntimeShape& input1_shape, const int16* input1_data,
1999                 const RuntimeShape& input2_shape, const int16* input2_data,
2000                 const RuntimeShape& output_shape, int16* output_data) {
2001   ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation");
2002   // This is a copy of the reference implementation. We do not currently have a
2003   // properly optimized version.
2004 
2005   const int flat_size =
2006       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2007 
2008   for (int i = 0; i < flat_size; i++) {
2009     // F0 uses 0 integer bits, range [-1, 1].
2010     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2011 
2012     F0 unclamped_result =
2013         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2014     output_data[i] = unclamped_result.raw();
2015   }
2016 }
2017 
2018 inline void Mul(const ArithmeticParams& params,
2019                 const RuntimeShape& input1_shape, const int16* input1_data,
2020                 const RuntimeShape& input2_shape, const int16* input2_data,
2021                 const RuntimeShape& output_shape, uint8* output_data) {
2022   ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
2023   // This is a copy of the reference implementation. We do not currently have a
2024   // properly optimized version.
2025   const int32 output_activation_min = params.quantized_activation_min;
2026   const int32 output_activation_max = params.quantized_activation_max;
2027   const int32 output_offset = params.output_offset;
2028   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2029 
2030   const int flat_size =
2031       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2032 
2033   for (int i = 0; i < flat_size; i++) {
2034     // F0 uses 0 integer bits, range [-1, 1].
2035     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2036 
2037     F0 unclamped_result =
2038         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2039     int16 rescaled_result =
2040         gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
2041     int16 clamped_result =
2042         std::min<int16>(output_activation_max - output_offset, rescaled_result);
2043     clamped_result =
2044         std::max<int16>(output_activation_min - output_offset, clamped_result);
2045     output_data[i] = output_offset + clamped_result;
2046   }
2047 }
2048 
2049 // Element-wise mul that can often be used for inner loop of broadcast Mul as
2050 // well as the non-broadcast Mul.
2051 inline void MulElementwise(int size, const ArithmeticParams& params,
2052                            const uint8* input1_data, const uint8* input2_data,
2053                            uint8* output_data) {
2054   int i = 0;
2055   TFLITE_DCHECK_GT(params.input1_offset, -256);
2056   TFLITE_DCHECK_LT(params.input1_offset, 256);
2057   TFLITE_DCHECK_GT(params.input2_offset, -256);
2058   TFLITE_DCHECK_LT(params.input2_offset, 256);
2059   TFLITE_DCHECK_GT(params.output_offset, -256);
2060   TFLITE_DCHECK_LT(params.output_offset, 256);
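  // Quantized multiply needs only a single output rescale: the int32 product
  // of the offset inputs is scaled by output_multiplier / output_shift, then
  // re-offset and clamped. The NEON path below mirrors the scalar tail loop.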
2061 #ifdef USE_NEON
2062   const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
2063   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2064   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2065   const auto output_activation_min_vector =
2066       vdup_n_u8(params.quantized_activation_min);
2067   const auto output_activation_max_vector =
2068       vdup_n_u8(params.quantized_activation_max);
2069   const int left_shift = std::max(0, params.output_shift);
2070   const int right_shift = std::max(0, -params.output_shift);
2071   const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2072   for (; i <= size - 8; i += 8) {
2073     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2074     const auto input1_val_original = vld1_u8(input1_data + i);
2075     const auto input2_val_original = vld1_u8(input2_data + i);
2076     const auto input1_val_s16 =
2077         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2078     const auto input2_val_s16 =
2079         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2080     const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
2081     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2082 
2083     const auto input1_val_low = vget_low_s16(input1_val);
2084     const auto input1_val_high = vget_high_s16(input1_val);
2085     const auto input2_val_low = vget_low_s16(input2_val);
2086     const auto input2_val_high = vget_high_s16(input2_val);
2087 
2088     auto p1 = vmull_s16(input2_val_low, input1_val_low);
2089     auto p2 = vmull_s16(input2_val_high, input1_val_high);
2090 
2091     p1 = vshlq_s32(p1, left_shift_vec);
2092     p2 = vshlq_s32(p2, left_shift_vec);
2093     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2094     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2095     using gemmlowp::RoundingDivideByPOT;
2096     p1 = RoundingDivideByPOT(p1, right_shift);
2097     p2 = RoundingDivideByPOT(p2, right_shift);
2098 
2099     const auto p1_narrowed = vqmovn_s32(p1);
2100     const auto p2_narrowed = vqmovn_s32(p2);
2101     const auto p =
2102         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2103     const auto clamped =
2104         vmax_u8(output_activation_min_vector,
2105                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2106     vst1_u8(output_data + i, clamped);
2107   }
2108 #endif  // NEON
2109 
2110   for (; i < size; ++i) {
2111     const int32 input1_val = params.input1_offset + input1_data[i];
2112     const int32 input2_val = params.input2_offset + input2_data[i];
2113     const int32 unclamped_result =
2114         params.output_offset +
2115         MultiplyByQuantizedMultiplier(input1_val * input2_val,
2116                                       params.output_multiplier,
2117                                       params.output_shift);
2118     const int32 clamped_output =
2119         std::min(params.quantized_activation_max,
2120                  std::max(params.quantized_activation_min, unclamped_result));
2121     output_data[i] = static_cast<uint8>(clamped_output);
2122   }
2123 }
2124 
2125 // Broadcast mul that can often be used for inner loop of broadcast Mul.
2126 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2127                                const uint8 broadcast_value,
2128                                const uint8* input2_data, uint8* output_data) {
2129   const int16 input1_val = params.input1_offset + broadcast_value;
2130 
2131   int i = 0;
2132   TFLITE_DCHECK_GT(params.input1_offset, -256);
2133   TFLITE_DCHECK_LT(params.input1_offset, 256);
2134   TFLITE_DCHECK_GT(params.input2_offset, -256);
2135   TFLITE_DCHECK_LT(params.input2_offset, 256);
2136   TFLITE_DCHECK_GT(params.output_offset, -256);
2137   TFLITE_DCHECK_LT(params.output_offset, 256);
2138 #ifdef USE_NEON
2139   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2140   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2141   const auto output_activation_min_vector =
2142       vdup_n_u8(params.quantized_activation_min);
2143   const auto output_activation_max_vector =
2144       vdup_n_u8(params.quantized_activation_max);
2145   const int left_shift = std::max(0, params.output_shift);
2146   const int right_shift = std::max(0, -params.output_shift);
2147   const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2148   for (; i <= size - 8; i += 8) {
2149     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2150     const auto input2_val_original = vld1_u8(input2_data + i);
2151     const auto input2_val_s16 =
2152         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2153     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2154 
2155     const auto input2_val_low = vget_low_s16(input2_val);
2156     const auto input2_val_high = vget_high_s16(input2_val);
2157 
2158     auto p1 = vmull_n_s16(input2_val_low, input1_val);
2159     auto p2 = vmull_n_s16(input2_val_high, input1_val);
2160 
2161     p1 = vshlq_s32(p1, left_shift_vec);
2162     p2 = vshlq_s32(p2, left_shift_vec);
2163     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2164     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2165     using gemmlowp::RoundingDivideByPOT;
2166     p1 = RoundingDivideByPOT(p1, right_shift);
2167     p2 = RoundingDivideByPOT(p2, right_shift);
2168 
2169     const auto p1_narrowed = vmovn_s32(p1);
2170     const auto p2_narrowed = vmovn_s32(p2);
2171     const auto p =
2172         vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2173     const auto clamped =
2174         vmax_u8(output_activation_min_vector,
2175                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2176     vst1_u8(output_data + i, clamped);
2177   }
2178 #endif  // NEON
2179 
2180   for (; i < size; ++i) {
2181     const int32 input2_val = params.input2_offset + input2_data[i];
2182     const int32 unclamped_result =
2183         params.output_offset +
2184         MultiplyByQuantizedMultiplier(input1_val * input2_val,
2185                                       params.output_multiplier,
2186                                       params.output_shift);
2187     const int32 clamped_output =
2188         std::min(params.quantized_activation_max,
2189                  std::max(params.quantized_activation_min, unclamped_result));
2190     output_data[i] = static_cast<uint8>(clamped_output);
2191   }
2192 }
2193 
2194 // Broadcast mul that can often be used for inner loop of broadcast Mul.
2195 // This function handles scalar_value (LHS) * vector_values (RHS).
2196 // Being a float path, it only uses the float activation bounds from params.
2197 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2198                                const float broadcast_value,
2199                                const float* input2_data, float* output_data) {
2200   int i = 0;
2201 #ifdef USE_NEON
2202   const float32x4_t output_activation_min_vector =
2203       vdupq_n_f32(params.float_activation_min);
2204   const float32x4_t output_activation_max_vector =
2205       vdupq_n_f32(params.float_activation_max);
2206   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2207   for (; i <= size - 4; i += 4) {
2208     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2209 
2210     const float32x4_t output =
2211         vmulq_f32(input2_val_original, broadcast_value_dup);
2212 
2213     const float32x4_t clamped =
2214         vmaxq_f32(output_activation_min_vector,
2215                   vminq_f32(output_activation_max_vector, output));
2216     vst1q_f32(output_data + i, clamped);
2217   }
2218 #endif  // NEON
2219 
2220   for (; i < size; ++i) {
2221     float x = broadcast_value * input2_data[i];
2222     output_data[i] = ActivationFunctionWithMinMax(
2223         x, params.float_activation_min, params.float_activation_max);
2224   }
2225 }
2226 
2227 inline void Mul(const ArithmeticParams& params,
2228                 const RuntimeShape& input1_shape, const uint8* input1_data,
2229                 const RuntimeShape& input2_shape, const uint8* input2_data,
2230                 const RuntimeShape& output_shape, uint8* output_data) {
2231   TFLITE_DCHECK_LE(params.quantized_activation_min,
2232                    params.quantized_activation_max);
2233   ruy::profiler::ScopeLabel label("Mul/8bit");
2234   const int flat_size =
2235       MatchingElementsSize(input1_shape, input2_shape, output_shape);
2236 
2237   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2238 }
2239 
2240 template <typename T>
2241 inline void BroadcastMulDispatch(
2242     const ArithmeticParams& params, const RuntimeShape& input1_shape,
2243     const T* input1_data, const RuntimeShape& input2_shape,
2244     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2245   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2246     return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
2247                               input2_data, output_shape, output_data);
2248   }
2249 
2250   BinaryBroadcastFiveFold(
2251       params, input1_shape, input1_data, input2_shape, input2_data,
2252       output_shape, output_data,
2253       static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2254                            T*)>(MulElementwise),
2255       static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2256           MulSimpleBroadcast));
2257 }
2258 
2259 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
2260                                  const RuntimeShape& unswitched_input1_shape,
2261                                  const uint8* unswitched_input1_data,
2262                                  const RuntimeShape& unswitched_input2_shape,
2263                                  const uint8* unswitched_input2_data,
2264                                  const RuntimeShape& output_shape,
2265                                  uint8* output_data) {
2266   BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
2267                        unswitched_input1_data, unswitched_input2_shape,
2268                        unswitched_input2_data, output_shape, output_data);
2269 }
2270 
2271 inline void BroadcastMulFivefold(const ArithmeticParams& params,
2272                                  const RuntimeShape& unswitched_input1_shape,
2273                                  const float* unswitched_input1_data,
2274                                  const RuntimeShape& unswitched_input2_shape,
2275                                  const float* unswitched_input2_data,
2276                                  const RuntimeShape& output_shape,
2277                                  float* output_data) {
2278   BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2279                        unswitched_input2_shape, unswitched_input2_data,
2280                        output_shape, output_data);
2281 }
2282 
2283 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
2284 // dimensionality if the runtime code does a single loop over one dimension
2285 // that handles broadcasting as the base case. The code generator would then
2286 // generate max(D1, D2) nested for loops.
2287 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
2288 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
2289 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
2290 // reference_ops.h.
2291 template <typename T, int N = 5>
2292 void BroadcastDivSlow(const ArithmeticParams& params,
2293                       const RuntimeShape& unextended_input1_shape,
2294                       const T* input1_data,
2295                       const RuntimeShape& unextended_input2_shape,
2296                       const T* input2_data,
2297                       const RuntimeShape& unextended_output_shape,
2298                       T* output_data) {
2299   ruy::profiler::ScopeLabel label("BroadcastDivSlow");
2300   T output_activation_min;
2301   T output_activation_max;
2302   GetActivationParams(params, &output_activation_min, &output_activation_max);
2303 
2304   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2305   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2306   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2307 
2308   NdArrayDesc<N> desc1;
2309   NdArrayDesc<N> desc2;
2310   NdArrayDesc<N> output_desc;
2311   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2312                                       unextended_input2_shape, &desc1, &desc2);
2313   CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2314                  &output_desc);
2315 
2316   // In Tensorflow, the dimensions are canonically named (batch_number, row,
2317   // col, channel), with extents (batches, height, width, depth), with the
2318   // trailing dimension changing most rapidly (channels has the smallest stride,
2319   // typically 1 element).
2320   //
2321   // In generated C code, we store arrays with the dimensions reversed. The
2322   // first dimension has smallest stride.
2323   //
2324   // We name our variables by their Tensorflow convention, but generate C code
2325   // nesting loops such that the innermost loop has the smallest stride for the
2326   // best cache behavior.
2327   auto div_func = [&](int indexes[N]) {
2328     output_data[SubscriptToIndex(output_desc, indexes)] =
2329         ActivationFunctionWithMinMax(
2330             input1_data[SubscriptToIndex(desc1, indexes)] /
2331                 input2_data[SubscriptToIndex(desc2, indexes)],
2332             output_activation_min, output_activation_max);
2333   };
2334   NDOpsHelper<N>(output_desc, div_func);
2335 }
2336 
2337 // BroadcastDiv is intentionally duplicated from reference_ops.h.
2338 // For more details see the comment above the generic version of
2339 // BroadcastDivSlow.
2340 template <int N = 5>
2341 inline void BroadcastDivSlow(const ArithmeticParams& params,
2342                              const RuntimeShape& unextended_input1_shape,
2343                              const uint8* input1_data,
2344                              const RuntimeShape& unextended_input2_shape,
2345                              const uint8* input2_data,
2346                              const RuntimeShape& unextended_output_shape,
2347                              uint8* output_data) {
2348   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2349   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2350   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2351 
2352   NdArrayDesc<N> desc1;
2353   NdArrayDesc<N> desc2;
2354   NdArrayDesc<N> output_desc;
2355   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2356                                       unextended_input2_shape, &desc1, &desc2);
2357   CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2358                  &output_desc);
2359 
2360   TFLITE_DCHECK_GT(params.input1_offset, -256);
2361   TFLITE_DCHECK_LT(params.input1_offset, 256);
2362   TFLITE_DCHECK_GT(params.input2_offset, -256);
2363   TFLITE_DCHECK_LT(params.input2_offset, 256);
2364   TFLITE_DCHECK_GT(params.output_offset, -256);
2365   TFLITE_DCHECK_LT(params.output_offset, 256);
2366 
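  // Per-element quantized division: 1/input2_val is computed as a fixed-point
  // reciprocal (input2_inv with exponent recip_shift); input1_val is shifted
  // left by its headroom before the high multiply to preserve precision, and
  // both recip_shift and headroom are compensated in total_shift for the
  // final rescale by output_multiplier.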
2367   auto div_func = [&](int indexes[N]) {
2368     int32 input1_val =
2369         params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
2370     int32 input2_val =
2371         params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
2372     TFLITE_DCHECK_NE(input2_val, 0);
2373     if (input2_val < 0) {
2374       // Invert both signs to avoid a negative input2_val: input2_inv must be
2375       // positive to serve as the fixed-point multiplier below.
2376       input1_val = -input1_val;
2377       input2_val = -input2_val;
2378     }
2379     int recip_shift;
2380     const int32 input2_inv = GetReciprocal(input2_val, 31, &recip_shift);
2381     const int headroom = CountLeadingSignBits(input1_val);
2382     const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
2383         input1_val, input2_inv, headroom);
2384     const int total_shift = params.output_shift - recip_shift - headroom;
2385     const int32 unclamped_result =
2386         params.output_offset +
2387         MultiplyByQuantizedMultiplierSmallerThanOneExp(
2388             unscaled_quotient, params.output_multiplier, total_shift);
2389     const int32 clamped_output =
2390         std::min(params.quantized_activation_max,
2391                  std::max(params.quantized_activation_min, unclamped_result));
2392     output_data[SubscriptToIndex(output_desc, indexes)] =
2393         static_cast<uint8>(clamped_output);
2394   };
2395   NDOpsHelper<N>(output_desc, div_func);
2396 }
2397 
2398 template <typename T>
2399 inline void SubWithActivation(
2400     const ArithmeticParams& params, const RuntimeShape& input1_shape,
2401     const T* input1_data, const RuntimeShape& input2_shape,
2402     const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2403   ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
2404   TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
2405   auto input1_map = MapAsVector(input1_data, input1_shape);
2406   auto input2_map = MapAsVector(input2_data, input2_shape);
2407   auto output_map = MapAsVector(output_data, output_shape);
2408   T activation_min, activation_max;
2409   GetActivationParams(params, &activation_min, &activation_max);
2410   output_map.array() = (input1_map.array() - input2_map.array())
2411                            .cwiseMin(activation_max)
2412                            .cwiseMax(activation_min);
2413 }
2414 
2415 inline void SubNonBroadcast(const ArithmeticParams& params,
2416                             const RuntimeShape& input1_shape,
2417                             const float* input1_data,
2418                             const RuntimeShape& input2_shape,
2419                             const float* input2_data,
2420                             const RuntimeShape& output_shape,
2421                             float* output_data) {
2422   ruy::profiler::ScopeLabel label("SubNonBroadcast");
2423   SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
2424                            input2_data, output_shape, output_data);
2425 }
2426 
2427 template <typename T>
2428 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
2429          const T* input1_data, const RuntimeShape& input2_shape,
2430          const T* input2_data, const RuntimeShape& output_shape,
2431          T* output_data) {
2432   ruy::profiler::ScopeLabel label("Sub");
2433 
2434   auto input1_map = MapAsVector(input1_data, input1_shape);
2435   auto input2_map = MapAsVector(input2_data, input2_shape);
2436   auto output_map = MapAsVector(output_data, output_shape);
2437   if (input1_shape == input2_shape) {
2438     output_map.array() = input1_map.array() - input2_map.array();
2439   } else if (input1_shape.FlatSize() == 1) {
2440     auto scalar = input1_data[0];
2441     output_map.array() = scalar - input2_map.array();
2442   } else if (input2_shape.FlatSize() == 1) {
2443     auto scalar = input2_data[0];
2444     output_map.array() = input1_map.array() - scalar;
2445   } else {
2446     BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
2447                      input2_data, output_shape, output_data);
2448   }
2449 }
2450 
2451 inline void LstmCell(
2452     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2453     const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2454     const float* prev_activ_data, const RuntimeShape& weights_shape,
2455     const float* weights_data, const RuntimeShape& unextended_bias_shape,
2456     const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2457     const float* prev_state_data,
2458     const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2459     const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2460     const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2461     const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
2462     CpuBackendContext* cpu_backend_context) {
2463   ruy::profiler::ScopeLabel label("LstmCell");
2464   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2465   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2466   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2467   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2468   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2469   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2470   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2471   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2472   const RuntimeShape input_shape =
2473       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2474   const RuntimeShape prev_activ_shape =
2475       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2476   const RuntimeShape bias_shape =
2477       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2478   const RuntimeShape prev_state_shape =
2479       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2480   const RuntimeShape output_state_shape =
2481       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2482   const RuntimeShape output_activ_shape =
2483       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2484   const RuntimeShape concat_temp_shape =
2485       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2486   const RuntimeShape activ_temp_shape =
2487       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2488   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2489 
2490   const int weights_dim_count = weights_shape.DimensionsCount();
2491   MatchingDim(  // batches
2492       input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
2493       output_state_shape, 0, output_activ_shape, 0);
2494   MatchingDim(  // height
2495       input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
2496       output_state_shape, 1, output_activ_shape, 1);
2497   MatchingDim(  // width
2498       input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
2499       output_state_shape, 2, output_activ_shape, 2);
2500   const int input_depth = input_shape.Dims(3);
2501   const int prev_activ_depth = prev_activ_shape.Dims(3);
2502   const int total_input_depth = prev_activ_depth + input_depth;
2503   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2504                    total_input_depth);
2505   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2506   const int intern_activ_depth =
2507       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2508   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2509                    intern_activ_depth * total_input_depth);
2510   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2511   const int output_depth =
2512       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2513                   3, output_activ_shape, 3);
2514   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2515 
2516   // Concatenate prev_activ and input data together
2517   std::vector<float const*> concat_input_arrays_data;
2518   std::vector<RuntimeShape const*> concat_input_arrays_shapes;
2519   concat_input_arrays_data.push_back(input_data);
2520   concat_input_arrays_data.push_back(prev_activ_data);
2521   concat_input_arrays_shapes.push_back(&input_shape);
2522   concat_input_arrays_shapes.push_back(&prev_activ_shape);
2523   tflite::ConcatenationParams concat_params;
2524   concat_params.axis = 3;
2525   concat_params.inputs_count = concat_input_arrays_data.size();
2526   Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
2527                 &(concat_input_arrays_data[0]), concat_temp_shape,
2528                 concat_temp_data);
2529 
2530   // Fully connected
2531   tflite::FullyConnectedParams fc_params;
2532   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
2533   fc_params.float_activation_max = std::numeric_limits<float>::max();
2534   fc_params.lhs_cacheable = false;
2535   fc_params.rhs_cacheable = false;
2536   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
2537                  weights_data, bias_shape, bias_data, activ_temp_shape,
2538                  activ_temp_data, cpu_backend_context);
2539 
2540   // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
2541   // operations.
2542   ArrayMap<float> activ_temp_map =
2543       MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
2544   auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
2545                                             activ_temp_map.cols());
2546   auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
2547                                            activ_temp_map.cols());
2548   auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
2549                                              activ_temp_map.cols());
2550   auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
2551                                              activ_temp_map.cols());
2552   ArrayMap<const float> prev_state_map =
2553       MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
2554   ArrayMap<float> output_state_map =
2555       MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
2556   ArrayMap<float> output_activ_map =
2557       MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
2558 
2559   // Combined memory state and final output calculation
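  // This computes, per element:
  //   output_state = logistic(input_gate) * tanh(new_input) +
  //                  logistic(forget_gate) * prev_state
  //   output_activ = logistic(output_gate) * tanh(output_state)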
2560   ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
2561   output_state_map =
2562       input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2563           new_input_sm.tanh() +
2564       forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2565           prev_state_map;
2566   output_activ_map =
2567       output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2568       output_state_map.tanh();
2569 }
2570 
2571 template <int StateIntegerBits>
2572 inline void LstmCell(
2573     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2574     const uint8* input_data_uint8,
2575     const RuntimeShape& unextended_prev_activ_shape,
2576     const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
2577     const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
2578     const int32* bias_data_int32,
2579     const RuntimeShape& unextended_prev_state_shape,
2580     const int16* prev_state_data_int16,
2581     const RuntimeShape& unextended_output_state_shape,
2582     int16* output_state_data_int16,
2583     const RuntimeShape& unextended_output_activ_shape,
2584     uint8* output_activ_data_uint8,
2585     const RuntimeShape& unextended_concat_temp_shape,
2586     uint8* concat_temp_data_uint8,
2587     const RuntimeShape& unextended_activ_temp_shape,
2588     int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
2589   ruy::profiler::ScopeLabel label(
2590       "LstmCell/quantized (8bit external, 16bit internal)");
2591   int32 weights_zero_point = params.weights_zero_point;
2592   int32 accum_multiplier = params.accum_multiplier;
2593   int accum_shift = params.accum_shift;
2594   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2595   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2596   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2597   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2598   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2599   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2600   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2601   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2602   const RuntimeShape input_shape =
2603       RuntimeShape::ExtendedShape(4, unextended_input_shape);
2604   const RuntimeShape prev_activ_shape =
2605       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2606   const RuntimeShape bias_shape =
2607       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2608   const RuntimeShape prev_state_shape =
2609       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2610   const RuntimeShape output_state_shape =
2611       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2612   const RuntimeShape output_activ_shape =
2613       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2614   const RuntimeShape concat_temp_shape =
2615       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2616   const RuntimeShape activ_temp_shape =
2617       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2618   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2619 
2620   // Gather dimensions information, and perform consistency checks.
2621   const int weights_dim_count = weights_shape.DimensionsCount();
2622   const int outer_size = MatchingFlatSizeSkipDim(
2623       input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
2624       output_activ_shape);
2625   const int input_depth = input_shape.Dims(3);
2626   const int prev_activ_depth = prev_activ_shape.Dims(3);
2627   const int total_input_depth = prev_activ_depth + input_depth;
2628   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2629                    total_input_depth);
2630   const int intern_activ_depth =
2631       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2632   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2633                    intern_activ_depth * total_input_depth);
2634   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2635   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2636   const int output_depth =
2637       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2638                   3, output_activ_shape, 3);
2639   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2640   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
2641   const int fc_output_depth =
2642       MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
2643   const int fc_accum_depth = total_input_depth;
2644   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
2645 
2646   // Depth-concatenate prev_activ and input data together.
2647   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
2648                                               prev_activ_data_uint8};
2649   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
2650                                                        &prev_activ_shape};
2651   tflite::ConcatenationParams concat_params;
2652   concat_params.axis = 3;
2653   concat_params.inputs_count = 2;
2654   Concatenation(concat_params, concat_input_arrays_shapes,
2655                 concat_input_arrays_data, concat_temp_shape,
2656                 concat_temp_data_uint8);
2657 
2658   // Implementation of the fully connected node inside the LSTM cell.
2659   // The operands are 8-bit integers, the accumulators are internally 32bit
2660   // integers, and the output is 16-bit fixed-point with 3 integer bits so
2661   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
2662   // is explained in the function comment above.
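  // With 3 integer bits the 16-bit activ_temp values are in Q3.12 format,
  // i.e. real value = raw / 4096.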
2663   cpu_backend_gemm::MatrixParams<uint8> lhs_params;
2664   lhs_params.rows = fc_output_depth;
2665   lhs_params.cols = fc_accum_depth;
2666   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
2667   lhs_params.zero_point = weights_zero_point;
2668   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
2669   rhs_params.rows = fc_accum_depth;
2670   rhs_params.cols = fc_batches;
2671   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
2672   rhs_params.zero_point = 128;
2673   cpu_backend_gemm::MatrixParams<int16> dst_params;
2674   dst_params.rows = fc_output_depth;
2675   dst_params.cols = fc_batches;
2676   dst_params.order = cpu_backend_gemm::Order::kColMajor;
2677   dst_params.zero_point = 0;
2678   cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
2679   gemm_params.bias = bias_data_int32;
2680   gemm_params.multiplier_fixedpoint = accum_multiplier;
2681   gemm_params.multiplier_exponent = accum_shift;
2682   cpu_backend_gemm::Gemm(
2683       lhs_params, weights_data_uint8, rhs_params, concat_temp_data_uint8,
2684       dst_params, activ_temp_data_int16, gemm_params, cpu_backend_context);
2685 
2686   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
2687   // and muls, all done in 16-bit fixed-point.
2688   const int16* input_gate_input_ptr = activ_temp_data_int16;
2689   const int16* input_modulation_gate_input_ptr =
2690       activ_temp_data_int16 + output_depth;
2691   const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
2692   const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
2693   const int16* prev_state_ptr = prev_state_data_int16;
2694   int16* output_state_data_ptr = output_state_data_int16;
2695   uint8* output_activ_data_ptr = output_activ_data_uint8;
2696 
2697   for (int b = 0; b < outer_size; ++b) {
2698     int c = 0;
2699 #ifdef GEMMLOWP_NEON
2700     for (; c <= output_depth - 8; c += 8) {
2701       // Define the fixed-point data types that we will use here. All use
2702       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
2703       // They only differ by the number of integral vs. fractional bits,
2704       // determining the range of values that they can represent.
2705       //
2706       // F0 uses 0 integer bits, range [-1, 1].
2707       // This is the return type of math functions such as tanh, logistic,
2708       // whose range is in [-1, 1].
2709       using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
2710       // F3 uses 3 integer bits, range [-8, 8].
2711       // This is the range of the previous fully-connected node's output,
2712       // which is our input here.
2713       using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
2714       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
2715       // 2^StateIntegerBits]. It's used to represent the internal state, whose
2716       // number of integer bits is currently dictated by the model. See comment
2717       // on the StateIntegerBits template parameter above.
2718       using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
2719       // Implementation of input gate, using fixed-point logistic function.
2720       F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
2721       input_gate_input_ptr += 8;
2722       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
2723       // Implementation of input modulation gate, using fixed-point tanh
2724       // function.
2725       F3 input_modulation_gate_input =
2726           F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
2727       input_modulation_gate_input_ptr += 8;
2728       F0 input_modulation_gate_output =
2729           gemmlowp::tanh(input_modulation_gate_input);
2730       // Implementation of forget gate, using fixed-point logistic function.
2731       F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
2732       forget_gate_input_ptr += 8;
2733       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
2734       // Implementation of output gate, using fixed-point logistic function.
2735       F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
2736       output_gate_input_ptr += 8;
2737       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
2738       // Implementation of internal multiplication nodes, still in fixed-point.
2739       F0 input_times_input_modulation =
2740           input_gate_output * input_modulation_gate_output;
2741       FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
2742       prev_state_ptr += 8;
2743       FS prev_state_times_forget_state = forget_gate_output * prev_state;
2744       // Implementation of internal addition node, saturating.
2745       FS new_state = gemmlowp::SaturatingAdd(
2746           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
2747           prev_state_times_forget_state);
2748       // Implementation of last internal Tanh node, still in fixed-point.
2749       // Since a Tanh fixed-point implementation is specialized for a given
2750       // number of integer bits, and each specialization can have a substantial
2751       // code size, and we already used above a Tanh on an input with 3 integer
2752       // bits, and per the table in the above function comment there is no
2753       // significant accuracy to be lost by clamping to [-8, +8] for a
2754       // 3-integer-bits representation, let us just do that. This helps people
2755       // porting this to targets where code footprint must be minimized.
2756       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
2757       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
2758       // Store the new internal state back to memory, as 16-bit integers.
2759       // Note: here we store the original value with StateIntegerBits, not
2760       // the rescaled 3-integer-bits value fed to tanh.
2761       vst1q_s16(output_state_data_ptr, new_state.raw());
2762       output_state_data_ptr += 8;
2763       // Down-scale the output activations to 8-bit integers, saturating,
2764       // and store back to memory.
2765       int16x8_t rescaled_output_activ =
2766           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
2767       int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
2768       uint8x8_t uint8_output_activ =
2769           vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
2770       vst1_u8(output_activ_data_ptr, uint8_output_activ);
2771       output_activ_data_ptr += 8;
2772     }
2773 #endif
2774     for (; c < output_depth; ++c) {
2775       // Define the fixed-point data types that we will use here. All use
2776       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
2777       // They only differ by the number of integral vs. fractional bits,
2778       // determining the range of values that they can represent.
2779       //
2780       // F0 uses 0 integer bits, range [-1, 1].
2781       // This is the return type of math functions such as tanh, logistic,
2782       // whose range is in [-1, 1].
2783       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2784       // F3 uses 3 integer bits, range [-8, 8].
2785       // This is the range of the previous fully-connected node's output,
2786       // which is our input here.
2787       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
2788       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
2789       // 2^StateIntegerBits]. It's used to represent the internal state, whose
2790       // number of integer bits is currently dictated by the model. See comment
2791       // on the StateIntegerBits template parameter above.
2792       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
2793       // Implementation of input gate, using fixed-point logistic function.
2794       F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
2795       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
2796       // Implementation of input modulation gate, using fixed-point tanh
2797       // function.
2798       F3 input_modulation_gate_input =
2799           F3::FromRaw(*input_modulation_gate_input_ptr++);
2800       F0 input_modulation_gate_output =
2801           gemmlowp::tanh(input_modulation_gate_input);
2802       // Implementation of forget gate, using fixed-point logistic function.
2803       F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
2804       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
2805       // Implementation of output gate, using fixed-point logistic function.
2806       F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
2807       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
2808       // Implementation of internal multiplication nodes, still in fixed-point.
2809       F0 input_times_input_modulation =
2810           input_gate_output * input_modulation_gate_output;
2811       FS prev_state = FS::FromRaw(*prev_state_ptr++);
2812       FS prev_state_times_forget_state = forget_gate_output * prev_state;
2813       // Implementation of internal addition node, saturating.
2814       FS new_state = gemmlowp::SaturatingAdd(
2815           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
2816           prev_state_times_forget_state);
2817       // Implementation of last internal Tanh node, still in fixed-point.
2818       // Since a Tanh fixed-point implementation is specialized for a given
2819       // number of integer bits, and each specialization can have a substantial
2820       // code size, and we already used above a Tanh on an input with 3 integer
2821       // bits, and per the table in the above function comment there is no
2822       // significant accuracy to be lost by clamping to [-8, +8] for a
2823       // 3-integer-bits representation, let us just do that. This helps people
2824       // porting this to targets where code footprint must be minimized.
2825       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
2826       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
2827       // Store the new internal state back to memory, as 16-bit integers.
2828       // Note: here we store the original value with StateIntegerBits, not
2829       // the rescaled 3-integer-bits value fed to tanh.
2830       *output_state_data_ptr++ = new_state.raw();
2831       // Down-scale the output activations to 8-bit integers, saturating,
2832       // and store back to memory.
2833       int16 rescaled_output_activ =
2834           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
2835       int16 clamped_output_activ =
2836           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
2837       *output_activ_data_ptr++ = 128 + clamped_output_activ;
2838     }
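    // activ_temp holds 4 * output_depth gate values per batch element; each
    // gate pointer already advanced by output_depth in the loops above, so
    // skip the other three gates' blocks to reach the next batch element.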
2839     input_gate_input_ptr += 3 * output_depth;
2840     input_modulation_gate_input_ptr += 3 * output_depth;
2841     forget_gate_input_ptr += 3 * output_depth;
2842     output_gate_input_ptr += 3 * output_depth;
2843   }
2844 }
2845 
2846 inline int NodeOffset(int b, int h, int w, int height, int width) {
2847   return (b * height + h) * width + w;
2848 }
2849 
2850 inline bool AveragePool(const PoolParams& params,
2851                         const RuntimeShape& input_shape,
2852                         const float* input_data,
2853                         const RuntimeShape& output_shape, float* output_data) {
2854   ruy::profiler::ScopeLabel label("AveragePool");
2855   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2856   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2857   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2858   const int input_height = input_shape.Dims(1);
2859   const int input_width = input_shape.Dims(2);
2860   const int output_height = output_shape.Dims(1);
2861   const int output_width = output_shape.Dims(2);
2862   const int stride_height = params.stride_height;
2863   const int stride_width = params.stride_width;
2864 
2865   if (stride_height == 0) return false;
2866   if (stride_width == 0) return false;
2867 
2868   // TODO(benoitjacob) make this a proper reference impl without Eigen!
2869   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
2870   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
2871   // TODO(benoitjacob) get rid of the dynamic memory allocation here!
2872   Eigen::VectorXf out_count(out_mat.cols());
2873   out_count.setZero();
2874   // Prefill the output to 0.
2875   out_mat.setZero();
2876   for (int b = 0; b < batches; ++b) {
2877     for (int h = 0; h < input_height; ++h) {
2878       for (int w = 0; w < input_width; ++w) {
2879         // (h_start, h_end) * (w_start, w_end) is the range that the input
2880         // vector projects to.
2881         int hpad = h + params.padding_values.height;
2882         int wpad = w + params.padding_values.width;
2883         int h_start = (hpad < params.filter_height)
2884                           ? 0
2885                           : (hpad - params.filter_height) / stride_height + 1;
2886         int h_end = std::min(hpad / stride_height + 1, output_height);
2887         int w_start = (wpad < params.filter_width)
2888                           ? 0
2889                           : (wpad - params.filter_width) / stride_width + 1;
2890         int w_end = std::min(wpad / stride_width + 1, output_width);
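        // Input pixel (h, w) contributes to exactly the output pixels whose
        // pooling window covers it: rows ph with
        // ph * stride_height <= hpad < ph * stride_height + filter_height,
        // i.e. ph in [h_start, h_end), and similarly for columns.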
2891         // compute elementwise sum
2892         for (int ph = h_start; ph < h_end; ++ph) {
2893           for (int pw = w_start; pw < w_end; ++pw) {
2894             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
2895             out_mat.col(out_offset) +=
2896                 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
2897             out_count(out_offset)++;
2898           }
2899         }
2900       }
2901     }
2902   }
2903   // Divide the output by the actual number of elements being averaged over
2904   TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
2905   out_mat.array().rowwise() /= out_count.transpose().array();
2906 
2907   const int flat_size = output_shape.FlatSize();
2908   for (int i = 0; i < flat_size; ++i) {
2909     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
2910                                                   params.float_activation_min,
2911                                                   params.float_activation_max);
2912   }
2913 
2914   return true;
2915 }
2916 
2917 inline bool AveragePool(const PoolParams& params,
2918                         const RuntimeShape& input_shape,
2919                         const uint8* input_data,
2920                         const RuntimeShape& output_shape, uint8* output_data) {
2921   ruy::profiler::ScopeLabel label("AveragePool/8bit");
2922 
2923   // Here, and in other pooling ops, in order to maintain locality of reference,
2924   // to minimize some recalculations, and to load into NEON vector registers, we
2925   // use an inner loop down the depth. Since depths can be large and hence we
2926   // would need arbitrarily large temporary storage, we divide the work up into
2927   // depth tranches just within the batch loop.
2928   static constexpr int kPoolingAccTrancheSize = 256;
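  // With uint32 accumulators, this keeps the acc[] scratch array below at
  // 1 KiB.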
2929 
2930   TFLITE_DCHECK_LE(params.quantized_activation_min,
2931                    params.quantized_activation_max);
2932   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2933   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2934   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2935   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
2936   const int input_height = input_shape.Dims(1);
2937   const int input_width = input_shape.Dims(2);
2938   const int output_height = output_shape.Dims(1);
2939   const int output_width = output_shape.Dims(2);
2940   const int stride_height = params.stride_height;
2941   const int stride_width = params.stride_width;
2942 
2943   uint32 acc[kPoolingAccTrancheSize];
2944   for (int batch = 0; batch < batches; ++batch) {
2945     // We proceed through the depth in tranches (see comment above). The
2946     // depth_base is the depth at the beginning of the tranche. The
2947     // tranche_depth is the depth dimension of the tranche.
2948     for (int depth_base = 0; depth_base < depth;
2949          depth_base += kPoolingAccTrancheSize) {
2950       const int tranche_depth =
2951           std::min(depth - depth_base, kPoolingAccTrancheSize);
2952       for (int out_y = 0; out_y < output_height; ++out_y) {
2953         for (int out_x = 0; out_x < output_width; ++out_x) {
2954           const int in_x_origin =
2955               (out_x * stride_width) - params.padding_values.width;
2956           const int in_y_origin =
2957               (out_y * stride_height) - params.padding_values.height;
2958           const int filter_x_start = std::max(0, -in_x_origin);
2959           const int filter_x_end =
2960               std::min(params.filter_width, input_width - in_x_origin);
2961           const int filter_y_start = std::max(0, -in_y_origin);
2962           const int filter_y_end =
2963               std::min(params.filter_height, input_height - in_y_origin);
2964           const int filter_count =
2965               (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
2966           if (filter_count == 0) return false;
2967           memset(acc, 0, tranche_depth * sizeof(acc[0]));
2968           const uint8* input_ptr =
2969               input_data + depth_base +
2970               depth * (in_x_origin +
2971                        input_width * (in_y_origin + input_height * batch));
2972           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
2973             const uint8* input_row_ptr =
2974                 input_ptr + depth * (fy * input_width + filter_x_start);
2975             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
2976               const uint8* input_channel_ptr = input_row_ptr;
2977               int channel = 0;
2978 #ifdef USE_NEON
2979               for (; channel <= tranche_depth - 16; channel += 16) {
2980                 uint16x4_t acc_reg[4];
2981                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
2982                 input_channel_ptr += 16;
2983                 acc_reg[0] = vget_low_u16(vmovl_u8(vget_low_u8(input_reg)));
2984                 acc_reg[1] = vget_high_u16(vmovl_u8(vget_low_u8(input_reg)));
2985                 acc_reg[2] = vget_low_u16(vmovl_u8(vget_high_u8(input_reg)));
2986                 acc_reg[3] = vget_high_u16(vmovl_u8(vget_high_u8(input_reg)));
2987                 for (int i = 0; i < 4; i++) {
2988                   vst1q_u32(
2989                       acc + channel + 4 * i,
2990                       vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
2991                 }
2992               }
2993               for (; channel <= tranche_depth - 8; channel += 8) {
2994                 uint16x4_t acc_reg[2];
2995                 uint16x8_t input_reg = vmovl_u8(vld1_u8(input_channel_ptr));
2996                 input_channel_ptr += 8;
2997                 acc_reg[0] = vget_low_u16(input_reg);
2998                 acc_reg[1] = vget_high_u16(input_reg);
2999                 for (int i = 0; i < 2; i++) {
3000                   vst1q_u32(
3001                       acc + channel + 4 * i,
3002                       vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3003                 }
3004               }
3005 #endif
3006               for (; channel < tranche_depth; ++channel) {
3007                 acc[channel] += *input_channel_ptr++;
3008               }
3009               input_row_ptr += depth;
3010             }
3011           }
3012           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3013                                                    out_x, depth_base);
3014           int channel = 0;
3015 #ifdef USE_NEON
3016 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT)                               \
3017   if (filter_count == FILTER_COUNT) {                                   \
3018     for (; channel <= tranche_depth - 8; channel += 8) {                \
3019       uint16 buf[8];                                                    \
3020       for (int i = 0; i < 8; i++) {                                     \
3021         buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT;  \
3022       }                                                                 \
3023       uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));                      \
3024       buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
3025       buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
3026       vst1_u8(output_ptr + channel, buf8);                              \
3027     }                                                                   \
3028   }
3029           AVGPOOL_DIVIDING_BY(9)
3030           AVGPOOL_DIVIDING_BY(15)
3031 #undef AVGPOOL_DIVIDING_BY
3032           for (; channel <= tranche_depth - 8; channel += 8) {
3033             uint16 buf[8];
3034             for (int i = 0; i < 8; i++) {
3035               buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
3036             }
3037             uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
3038             buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
3039             buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
3040             vst1_u8(output_ptr + channel, buf8);
3041           }
3042 #endif
3043           for (; channel < tranche_depth; ++channel) {
3044             uint16 a = (acc[channel] + filter_count / 2) / filter_count;
3045             a = std::max<uint16>(a, params.quantized_activation_min);
3046             a = std::min<uint16>(a, params.quantized_activation_max);
3047             output_ptr[channel] = static_cast<uint8>(a);
3048           }
3049         }
3050       }
3051     }
3052   }
3053   return true;
3054 }
3055 
3056 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3057                     const float* input_data, const RuntimeShape& output_shape,
3058                     float* output_data) {
3059   ruy::profiler::ScopeLabel label("MaxPool");
3060   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3061   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3062   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3063   const int input_height = input_shape.Dims(1);
3064   const int input_width = input_shape.Dims(2);
3065   const int output_height = output_shape.Dims(1);
3066   const int output_width = output_shape.Dims(2);
3067   const int stride_height = params.stride_height;
3068   const int stride_width = params.stride_width;
3069 
3070   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3071   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3072   // Prefill the output to minimum representable float value
3073   out_mat.setConstant(std::numeric_limits<float>::lowest());
3074   for (int b = 0; b < batches; ++b) {
3075     for (int h = 0; h < input_height; ++h) {
3076       for (int w = 0; w < input_width; ++w) {
3077         // (h_start, h_end) * (w_start, w_end) is the range that the input
3078         // vector projects to.
3079         int hpad = h + params.padding_values.height;
3080         int wpad = w + params.padding_values.width;
3081         int h_start = (hpad < params.filter_height)
3082                           ? 0
3083                           : (hpad - params.filter_height) / stride_height + 1;
3084         int h_end = std::min(hpad / stride_height + 1, output_height);
3085         int w_start = (wpad < params.filter_width)
3086                           ? 0
3087                           : (wpad - params.filter_width) / stride_width + 1;
3088         int w_end = std::min(wpad / stride_width + 1, output_width);
3089         // compute elementwise max
3090         for (int ph = h_start; ph < h_end; ++ph) {
3091           for (int pw = w_start; pw < w_end; ++pw) {
3092             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3093             out_mat.col(out_offset) =
3094                 out_mat.col(out_offset)
3095                     .cwiseMax(in_mat.col(
3096                         NodeOffset(b, h, w, input_height, input_width)));
3097           }
3098         }
3099       }
3100     }
3101   }
3102   const int flat_size = output_shape.FlatSize();
3103   for (int i = 0; i < flat_size; ++i) {
3104     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3105                                                   params.float_activation_min,
3106                                                   params.float_activation_max);
3107   }
3108 }
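
// Worked example for the back-projection above: with filter_height = 3,
// stride_height = 2 and no padding, input row h = 4 gives hpad = 4,
// h_start = (4 - 3) / 2 + 1 = 1 and h_end = min(4 / 2 + 1, output_height).
// Output row 1 covers input rows 2..4 and output row 2 covers rows 4..6, so
// input row 4 is folded into output rows 1 and 2 only, as expected.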
3109 
3110 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3111                     const uint8* input_data, const RuntimeShape& output_shape,
3112                     uint8* output_data) {
3113   ruy::profiler::ScopeLabel label("MaxPool/8bit");
3114 
3115   // Here, and in other pooling ops, in order to maintain locality of reference,
3116   // to minimize some recalculations, and to load into NEON vector registers, we
3117   // use an inner loop down the depth. Since depths can be large and hence we
3118   // would need arbitrarily large temporary storage, we divide the work up into
3119   // depth tranches just within the batch loop.
3120   static constexpr int kPoolingAccTrancheSize = 256;
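  // For example, with depth = 600 the tranches cover channels [0, 256),
  // [256, 512) and [512, 600), so `acc` never needs more than
  // kPoolingAccTrancheSize entries.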
3121 
3122   TFLITE_DCHECK_LE(params.quantized_activation_min,
3123                    params.quantized_activation_max);
3124   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3125   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3126   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3127   const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3128   const int input_height = input_shape.Dims(1);
3129   const int input_width = input_shape.Dims(2);
3130   const int output_height = output_shape.Dims(1);
3131   const int output_width = output_shape.Dims(2);
3132   const int stride_height = params.stride_height;
3133   const int stride_width = params.stride_width;
3134 
3135   uint8 acc[kPoolingAccTrancheSize];
3136   for (int batch = 0; batch < batches; ++batch) {
3137     // We proceed through the depth in tranches (see comment above). The
3138     // depth_base is the depth at the beginning of the tranche. The
3139     // tranche_depth is the depth dimension of the tranche.
3140     for (int depth_base = 0; depth_base < depth;
3141          depth_base += kPoolingAccTrancheSize) {
3142       const int tranche_depth =
3143           std::min(depth - depth_base, kPoolingAccTrancheSize);
3144       for (int out_y = 0; out_y < output_height; ++out_y) {
3145         for (int out_x = 0; out_x < output_width; ++out_x) {
3146           const int in_x_origin =
3147               (out_x * stride_width) - params.padding_values.width;
3148           const int in_y_origin =
3149               (out_y * stride_height) - params.padding_values.height;
3150           const int filter_x_start = std::max(0, -in_x_origin);
3151           const int filter_x_end =
3152               std::min(params.filter_width, input_width - in_x_origin);
3153           const int filter_y_start = std::max(0, -in_y_origin);
3154           const int filter_y_end =
3155               std::min(params.filter_height, input_height - in_y_origin);
3156           memset(acc, 0, tranche_depth * sizeof(acc[0]));
3157           const uint8* input_ptr =
3158               input_data + depth_base +
3159               depth * (in_x_origin +
3160                        input_width * (in_y_origin + input_height * batch));
3161           for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3162             const uint8* input_row_ptr =
3163                 input_ptr + depth * (fy * input_width + filter_x_start);
3164             for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3165               const uint8* input_channel_ptr = input_row_ptr;
3166               int channel = 0;
3167 #ifdef USE_NEON
3168               for (; channel <= tranche_depth - 16; channel += 16) {
3169                 uint8x16_t acc_reg = vld1q_u8(acc + channel);
3170                 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3171                 input_channel_ptr += 16;
3172                 acc_reg = vmaxq_u8(acc_reg, input_reg);
3173                 vst1q_u8(acc + channel, acc_reg);
3174               }
3175 
3176               for (; channel <= tranche_depth - 8; channel += 8) {
3177                 uint8x8_t acc_reg = vld1_u8(acc + channel);
3178                 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
3179                 input_channel_ptr += 8;
3180                 acc_reg = vmax_u8(acc_reg, input_reg);
3181                 vst1_u8(acc + channel, acc_reg);
3182               }
3183 #endif
3184               for (; channel < tranche_depth; ++channel) {
3185                 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
3186               }
3187               input_row_ptr += depth;
3188             }
3189           }
3190           uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3191                                                    out_x, depth_base);
3192           int channel = 0;
3193 #ifdef USE_NEON
3194           for (; channel <= tranche_depth - 16; channel += 16) {
3195             uint8x16_t a = vld1q_u8(acc + channel);
3196             a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
3197             a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
3198             vst1q_u8(output_ptr + channel, a);
3199           }
3200           for (; channel <= tranche_depth - 8; channel += 8) {
3201             uint8x8_t a = vld1_u8(acc + channel);
3202             a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
3203             a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
3204             vst1_u8(output_ptr + channel, a);
3205           }
3206 #endif
3207           for (; channel < tranche_depth; ++channel) {
3208             uint8 a = acc[channel];
3209             a = std::max<uint8>(a, params.quantized_activation_min);
3210             a = std::min<uint8>(a, params.quantized_activation_max);
3211             output_ptr[channel] = static_cast<uint8>(a);
3212           }
3213         }
3214       }
3215     }
3216   }
3217 }
3218 
3219 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
3220                    const float* input_data, const RuntimeShape& output_shape,
3221                    float* output_data) {
3222   ruy::profiler::ScopeLabel label("L2Pool");
3223   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3224   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3225   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3226   const int input_height = input_shape.Dims(1);
3227   const int input_width = input_shape.Dims(2);
3228   const int output_height = output_shape.Dims(1);
3229   const int output_width = output_shape.Dims(2);
3230   const int stride_height = params.stride_height;
3231   const int stride_width = params.stride_width;
3232   // Actually carry out L2 Pool. Code is written in forward mode: we go through
3233   // the input values once, and write to all the pooled regions that they map to.
3234   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3235   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3236   Eigen::VectorXf in_square(in_mat.rows());
3237   Eigen::VectorXf out_count(out_mat.cols());
3238   out_count.setZero();
3239   // Prefill the output to 0.
3240   out_mat.setZero();
3241   for (int b = 0; b < batches; ++b) {
3242     for (int h = 0; h < input_height; ++h) {
3243       for (int w = 0; w < input_width; ++w) {
3244         // (h_start, h_end) * (w_start, w_end) is the range that the input
3245         // vector projects to.
3246         const int hpad = h + params.padding_values.height;
3247         const int wpad = w + params.padding_values.width;
3248         const int h_start =
3249             (hpad < params.filter_height)
3250                 ? 0
3251                 : (hpad - params.filter_height) / stride_height + 1;
3252         const int h_end = std::min(hpad / stride_height + 1, output_height);
3253         const int w_start =
3254             (wpad < params.filter_width)
3255                 ? 0
3256                 : (wpad - params.filter_width) / stride_width + 1;
3257         const int w_end = std::min(wpad / stride_width + 1, output_width);
3258         // pre-compute square
3259         const int in_offset = w + input_width * (h + input_height * b);
3260         in_square =
3261             in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
3262         // compute elementwise sum of squares
3263         for (int ph = h_start; ph < h_end; ++ph) {
3264           for (int pw = w_start; pw < w_end; ++pw) {
3265             const int out_offset = pw + output_width * (ph + output_height * b);
3266             out_mat.col(out_offset) += in_square;
3267             out_count(out_offset)++;
3268           }
3269         }
3270       }
3271     }
3272   }
3273 
3274   out_count = out_count.array().inverse();
3275   out_mat =
3276       (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3277 
3278   const int flat_size = output_shape.FlatSize();
3279   for (int i = 0; i < flat_size; ++i) {
3280     output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3281                                                   params.float_activation_min,
3282                                                   params.float_activation_max);
3283   }
3284 }
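
// Worked example for the forward-mode accumulation above: an output cell whose
// 2x2 window contains the input values {1, 2, 2, 4} accumulates the squared
// sum 1 + 4 + 4 + 16 = 25 with out_count = 4, and the final rescale produces
// sqrt(25 / 4) = 2.5 for that cell.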
3285 
3286 inline void LocalResponseNormalization(
3287     const tflite::LocalResponseNormalizationParams& op_params,
3288     const RuntimeShape& input_shape, const float* input_data,
3289     const RuntimeShape& output_shape, float* output_data) {
3290   ruy::profiler::ScopeLabel label("LocalResponseNormalization");
3291   MatchingFlatSize(input_shape, output_shape);
3292 
3293   const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3294   auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3295 
3296   // Carry out local response normalization, vector by vector.
3297   // Since the data are stored column-major, row-wise operations would
3298   // likely not be memory-efficient anyway, so we do an explicit for loop
3299   // over the columns.
3300   const int double_range = op_params.range * 2;
3301   Eigen::VectorXf padded_square(data_in.rows() + double_range);
3302   padded_square.setZero();
3303   const float bias = op_params.bias;
3304   for (int r = 0; r < data_in.cols(); ++r) {
3305     // Do local response normalization for data_in(:, r)
3306     // First, compute the squares and store them in a buffer for repeated use.
3307     padded_square.block(op_params.range, 0, data_in.rows(), 1) =
3308         data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
3309     // Then, compute the scale and write it to data_out.
3310     float accumulated_scale = 0;
3311     for (int i = 0; i < double_range; ++i) {
3312       accumulated_scale += padded_square(i);
3313     }
3314     for (int i = 0; i < data_in.rows(); ++i) {
3315       accumulated_scale += padded_square(i + double_range);
3316       data_out(i, r) = bias + accumulated_scale;
3317       accumulated_scale -= padded_square(i);
3318     }
3319   }
3320 
3321   // In a few cases, the pow computation could benefit from speedups.
3322   if (op_params.beta == 1) {
3323     data_out.array() = data_in.array() * data_out.array().inverse();
3324   } else if (op_params.beta == 0.5f) {
3325     data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
3326   } else {
3327     data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
3328   }
3329 }
3330 
3331 inline void SoftmaxImpl(const SoftmaxParams& params,
3332                         const RuntimeShape& input_shape,
3333                         const float* input_data,
3334                         const RuntimeShape& output_shape, float* output_data,
3335                         int start_batch, int end_batch) {
3336   ruy::profiler::ScopeLabel label("Softmax/Impl");
3337   MatchingFlatSize(input_shape, output_shape);
3338 
3339   const int logit_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
3340   const MatrixMap<const float> in_mat(input_data + logit_size * start_batch,
3341                                       logit_size, end_batch - start_batch);
3342   MatrixMap<float> out_mat(output_data + logit_size * start_batch, logit_size,
3343                            end_batch - start_batch);
3344   // Compute the exponential first, removing the max coefficient for numerical
3345   // stability.
3346   out_mat =
3347       (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
3348   // We are separating out the exp function so that exp can be vectorized.
3349   out_mat = out_mat.array().exp();
3350   // Normalize to get the activations.
3351   Eigen::Array<float, 1, Eigen::Dynamic> scale =
3352       out_mat.array().colwise().sum().inverse();
3353   out_mat.array().rowwise() *= scale;
3354 }
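
// Scalar sketch of the computation above (illustrative only; StableSoftmaxRow
// is not part of this header). It mirrors SoftmaxImpl's max-subtraction trick:
// for logits {1000, 1001} a naive exp() overflows, whereas exp(-1) and exp(0)
// normalize to roughly {0.27, 0.73}. Uses <algorithm> and <cmath>, which are
// already included at the top of this file.
inline void StableSoftmaxRow(const float* logits, float* probs, int n,
                             float beta) {
  if (n <= 0) return;
  // Subtract the max logit for numerical stability, scale by beta, then
  // exponentiate.
  float max_val = logits[0];
  for (int i = 1; i < n; ++i) max_val = std::max(max_val, logits[i]);
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    probs[i] = std::exp((logits[i] - max_val) * beta);
    sum += probs[i];
  }
  // Normalize so the row sums to 1.
  const float inv_sum = 1.0f / sum;
  for (int i = 0; i < n; ++i) probs[i] *= inv_sum;
}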
3355 
3356 struct SoftmaxWorkerTask : cpu_backend_threadpool::Task {
3357   SoftmaxWorkerTask(const SoftmaxParams& params,
3358                     const RuntimeShape& input_shape, const float* input_data,
3359                     const RuntimeShape& output_shape, float* output_data,
3360                     int start_batch, int end_batch)
3361       : params(params),
3362         input_shape(input_shape),
3363         input_data(input_data),
3364         output_shape(output_shape),
3365         output_data(output_data),
3366         start_batch(start_batch),
3367         end_batch(end_batch) {}
3368   void Run() override {
3369     SoftmaxImpl(params, input_shape, input_data, output_shape, output_data,
3370                 start_batch, end_batch);
3371   }
3372 
3373  private:
3374   const tflite::SoftmaxParams& params;
3375   const RuntimeShape& input_shape;
3376   const float* input_data;
3377   const RuntimeShape& output_shape;
3378   float* output_data;
3379   int start_batch;
3380   int end_batch;
3381 };
3382 
3383 inline void Softmax(const SoftmaxParams& params,
3384                     const RuntimeShape& input_shape, const float* input_data,
3385                     const RuntimeShape& output_shape, float* output_data,
3386                     CpuBackendContext* cpu_backend_context = nullptr) {
3387   ruy::profiler::ScopeLabel label("Softmax");
3388 
3389   // We picture the softmax input as a 2-D matrix whose last dim is the logit
3390   // dim; the remaining dims are flattened into the batch dim of the 2-D matrix.
3391   const int batch_size =
3392       FlatSizeSkipDim(input_shape, input_shape.DimensionsCount() - 1);
3393   constexpr int kMinBatchPerThread = 8;
3394   int thread_count = batch_size / kMinBatchPerThread;
3395   thread_count = thread_count > 0 ? thread_count : 1;
3396   const int capped_thread_count =
3397       cpu_backend_context == nullptr
3398           ? 1
3399           : std::min(thread_count, cpu_backend_context->max_num_threads());
3400   if (capped_thread_count == 1) {
3401     SoftmaxImpl(params, input_shape, input_data, output_shape, output_data, 0,
3402                 batch_size);
3403   } else {
3404     std::vector<SoftmaxWorkerTask> tasks;
3405     // TODO(b/131746020) don't create new heap allocations every time.
3406     // At least we make it a single heap allocation by using reserve().
3407     tasks.reserve(capped_thread_count);
3408     int batch_start = 0;
3409     for (int i = 0; i < capped_thread_count; ++i) {
3410       // Try to distribute the tasks as evenly as possible.
3411       int batch_end =
3412           batch_start + (batch_size - batch_start) / (capped_thread_count - i);
3413       tasks.emplace_back(params, input_shape, input_data, output_shape,
3414                          output_data, batch_start, batch_end);
3415       batch_start = batch_end;
3416     }
3417     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
3418                                     cpu_backend_context);
3419   }
3420 }
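
// Example of the batch partitioning above: with batch_size = 30,
// kMinBatchPerThread = 8 and at least 3 threads available, thread_count =
// 30 / 8 = 3 and the loop assigns batches [0, 10), [10, 20) and [20, 30) to
// the three SoftmaxWorkerTask instances.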
3421 
3422 template <typename T>
3423 inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point) {
3424   const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
3425   return prob_rnd + zero_point;
3426 }
3427 
3428 #if !__aarch64__
3429 // With ARM64, rounding is faster than add + truncation.
3430 template <>
3431 inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled,
3432                                               int32_t zero_point) {
3433   return static_cast<int32_t>(prob_rescaled + 0.5f);
3434 }
3435 #endif
3436 
3437 inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
3438                                        float beta) {
3439   const float scale = -input_scale * beta;
3440   const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3441   for (int32_t val = 0; val <= max_uint8; ++val) {
3442     data->table[max_uint8 - val] = expf(scale * val);
3443   }
3444 }
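
// How the table above is consumed: table[max_uint8 - val] = exp(-input_scale *
// beta * val), so indexing it as table[(max_uint8 - max_q) + q] in the
// quantized Softmax below yields exp(input_scale * beta * (q - max_q)), i.e.
// the max-subtracted exponential.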
3445 
3446 template <typename In, typename Out>
3447 inline void Softmax(const SoftmaxParams& params,
3448                     const RuntimeShape& input_shape, const In* input_data,
3449                     const RuntimeShape& output_shape, Out* output_data) {
3450   const int trailing_dim = input_shape.DimensionsCount() - 1;
3451   const int excluding_last_dim =
3452       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3453   const int last_dim =
3454       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3455 
3456   const int32_t clamp_max = std::numeric_limits<Out>::max();
3457   const int32_t clamp_min = std::numeric_limits<Out>::min();
3458   for (int i = 0; i < excluding_last_dim; ++i) {
3459     int32_t max_val = std::numeric_limits<In>::min();
3460     // Find max quantized value.
3461     for (int j = 0; j < last_dim; ++j) {
3462       max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
3463     }
3464 
3465     float sum_exp = 0.0f;
3466     const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3467     const float* table_offset = &params.table[max_uint8 - max_val];
3468     // Calculate normalizer sum(exp(x)).
3469     for (int j = 0; j < last_dim; ++j) {
3470       sum_exp += table_offset[input_data[j]];
3471     }
3472 
3473     const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3474     // Normalize and quantize probabilities.
3475     for (int j = 0; j < last_dim; ++j) {
3476       const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
3477       const int32_t prob_quantized =
3478           QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
3479       output_data[j] = static_cast<Out>(
3480           std::max(std::min(clamp_max, prob_quantized), clamp_min));
3481     }
3482     input_data += last_dim;
3483     output_data += last_dim;
3484   }
3485 }
3486 
3487 // Here's the softmax LUT optimization strategy:
3488 // For softmax, we can apply a mathematically equivalent transformation:
3489 //
3490 // softmax(x) = e^x / sum(e^x, 0...n)  ===> equals to
3491 // softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
3492 //
3493 // For quantization, `x` in our case is (input_q - input_zp) * input_s
3494 // For uint8 case (int8 can be handled similarly), the range is [0, 255]
3495 //
3496 // so if we let
3497 // CONST = (255 - input_zp) * input_s
3498 // then we will have:
3499 // softmax(x) = e^((input_q - 255) * input_s) --------- (1)
3500 //         /
3501 // sum(e^(input_q - 255) * input_s, 0...n)   -------- (2)
3502 //
3503 // the good thing about (1) is it's within the range of (0, 1), so we can
3504 // approximate its result with uint16.
3505 //  (1) = uint8_out * 1 / 2^16.
3506 //
3507 // so (1) is lookup_uint8_table(input_zp) * 1 / 2^16.
3508 // then (2) is essentially the following:
3509 // sum(lookup_uint8_table(input_zp), 0...n) / 2^16.
3510 //
3511 // since (output_q - output_zp) * output_s = softmax(x)
3512 // output_q = lookup_uint8_table(input_zp)
3513 //            /
3514 // (sum(lookup_uint8_table(input_zp), 0...n) * output_s)
3515 //             +
3516 //   output_zp
3517 //
3518 // We could further improve performance by using uint8 instead of uint16,
3519 // but then we may lose some accuracy, so we need to pay attention to
3520 // that.
3521 inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
3522                                             float input_scale, float beta) {
3523   const float scale = input_scale * beta;
3524   const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3525   const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
3526 
3527   for (int32_t val = 0; val <= max_uint8; ++val) {
3528     float input_to_exp = scale * (val - max_uint8);
3529     int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
3530     temp = std::min(max_uint16, temp);
3531     uint8_t part1 = temp >> 8;
3532     uint8_t part2 = temp & 0xff;
3533     data->uint8_table1[val] = static_cast<uint8_t>(part1);
3534     data->uint8_table2[val] = static_cast<uint8_t>(part2);
3535   }
3536 }
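
// Worked example for the split tables above: for the largest input (val ==
// 255) the exponent is exp(0) = 1.0, so temp = 65535, uint8_table1[255] = 0xFF
// and uint8_table2[255] = 0xFF; the consumer reconstructs (0xFF << 8) + 0xFF =
// 65535, i.e. roughly 1.0 in the ~Q16 fixed-point scale used when summing the
// exponentials.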
3537 
3538 inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
3539   int32_t max_val = std::numeric_limits<uint8_t>::min();
3540   int j = 0;
3541 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3542   uint8x16_t max_val_dup = vdupq_n_u8(max_val);
3543   uint8x16_t offset_dup = vdupq_n_u8(offset);
3544   for (; j <= size - 16; j += 16) {
3545     uint8x16_t input_value = vld1q_u8(input_data + j);
3546     input_value = veorq_u8(input_value, offset_dup);
3547     max_val_dup = vmaxq_u8(input_value, max_val_dup);
3548   }
3549   max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
3550 #endif
3551 
3552   for (; j < size; ++j) {
3553     max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
3554   }
3555   return max_val;
3556 }
3557 
3558 #ifdef USE_NEON
3559 // Value_to_store layout:
3560 // [high_high, high_low, low_high, low_low].
3561 inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
3562   const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
3563                                           vqmovn_s32(value_to_store.val[0]));
3564   const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
3565                                           vqmovn_s32(value_to_store.val[2]));
3566   const int8x16_t result =
3567       vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
3568   vst1q_s8(output, result);
3569 }
3570 
3571 // Value_to_store layout:
3572 // [high_high, high_low, low_high, low_low].
3573 inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
3574   const uint16x8_t result_1 =
3575       vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
3576                    vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
3577   const uint16x8_t result_2 =
3578       vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
3579                    vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
3580   const uint8x16_t result =
3581       vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
3582   vst1q_u8(output, result);
3583 }
3584 
3585 #endif
3586 
3587 template <typename In, typename Out>
3588 inline void SoftmaxInt8LUT(const SoftmaxParams& params,
3589                            const RuntimeShape& input_shape,
3590                            const In* input_data,
3591                            const RuntimeShape& output_shape, Out* output_data) {
3592   ruy::profiler::ScopeLabel label("SoftmaxInt8LUT");
3593 
3594   const int trailing_dim = input_shape.DimensionsCount() - 1;
3595   const int excluding_last_dim =
3596       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3597   const int last_dim =
3598       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3599 
3600   const int32_t clamp_max = std::numeric_limits<Out>::max();
3601   const int32_t clamp_min = std::numeric_limits<Out>::min();
3602 
3603   // Offset is used to interpret the input data "correctly".
3604   // If the input is uint8, the data will be unchanged.
3605   // If the input is int8, it will be reinterpreted as uint8, so we apply an
3606   // offset; e.g.,
3607   // int8 127 becomes 255 in uint8 after the offset is applied.
3608   uint8_t offset = 0;
3609   if (std::is_same<In, int8>::value) {
3610     offset = 0x80;
3611   }
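  // For example, the int8 value 127 (bit pattern 0x7F) becomes 0x7F ^ 0x80 =
  // 0xFF = 255, and int8 -128 (0x80) becomes 0x00, so XOR-ing with this offset
  // maps int8 inputs onto the uint8 ordering the tables expect.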
3612 
3613   const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
3614 
3615 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3616   // This code uses ARM64-only instructions.
3617   // TODO(b/143709993): Port to ARMv7
3618 
3619   // Load the tables into registers. (4*4 128-bit registers)
3620   uint8x16x4_t table1[4];
3621   table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
3622   table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
3623   table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
3624   table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
3625 
3626   uint8x16x4_t table2[4];
3627   table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
3628   table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
3629   table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
3630   table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
3631 #endif
3632 
3633   for (int i = 0; i < excluding_last_dim; ++i) {
3634     // Find max quantized value.
3635     int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
3636 
3637     int32 sum_exp = 0;
3638     const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3639     const uint8_t table_offset = max_uint8 - max_val;
3640 
3641     // Calculate normalizer sum(exp(x)).
3642     int sum_j = 0;
3643 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3644     uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
3645     uint8x16_t offset_dup = vdupq_n_u8(offset);
3646     uint32x4_t sum_4 = vdupq_n_u32(0);
3647     const int multiplier_shift = 8;
3648     for (; sum_j <= last_dim - 16; sum_j += 16) {
3649       uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
3650       input_value = veorq_u8(input_value, offset_dup);
3651       input_value = vaddq_u8(input_value, table_offset_dup);
3652 
3653       const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3654       const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3655 
3656       uint16x8_t exp_value1 =
3657           vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3658       uint16x8_t exp_value2 =
3659           vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3660 
3661       exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3662       exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3663 
3664       sum_4 = vpadalq_u16(sum_4, exp_value1);
3665       sum_4 = vpadalq_u16(sum_4, exp_value2);
3666     }
3667     int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
3668                vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
3669     sum_exp += temp;
3670 
3671 #endif
3672     for (; sum_j < last_dim; ++sum_j) {
3673       const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
3674 
3675       uint8_t part1 = params.uint8_table1[index];
3676       uint8_t part2 = params.uint8_table2[index];
3677       sum_exp += ((part1 << 8) + part2);
3678     }
3679 
3680     const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3681 
3682     int32 multiplier, shift;
3683     QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
3684 
3685     // Normalize and quantize probabilities.
3686     int j = 0;
3687 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3688     const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
3689     const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
3690     const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
3691 
3692     for (; j <= last_dim - 16; j += 16) {
3693       uint8x16_t input_value = vld1q_u8(input_data_uint + j);
3694       input_value = veorq_u8(input_value, offset_dup);
3695       input_value = vaddq_u8(input_value, table_offset_dup);
3696 
3697       const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3698       const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3699 
3700       uint16x8_t exp_value1 =
3701           vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3702       uint16x8_t exp_value2 =
3703           vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3704 
3705       exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3706       exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3707 
3708       int32x4x4_t output_value;
3709       output_value.val[0] =
3710           vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
3711       output_value.val[1] =
3712           vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
3713       output_value.val[2] =
3714           vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
3715       output_value.val[3] =
3716           vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
3717 
3718       int32x4x4_t temp_val =
3719           MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
3720 
3721       temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
3722       temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
3723       temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
3724       temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
3725 
3726       temp_val.val[0] =
3727           vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
3728       temp_val.val[1] =
3729           vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
3730       temp_val.val[2] =
3731           vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
3732       temp_val.val[3] =
3733           vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
3734 
3735       StoreValue(temp_val, output_data + j);
3736     }
3737 #endif
3738     for (; j < last_dim; ++j) {
3739       const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
3740       const uint8_t part1 = params.uint8_table1[index];
3741       const uint8_t part2 = params.uint8_table2[index];
3742       const int32_t exp_value = (part1 << 8) + part2;
3743       const int32_t output_value =
3744           MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
3745 
3746       output_data[j] = static_cast<Out>(std::max(
3747           std::min(clamp_max, output_value + params.zero_point), clamp_min));
3748     }
3749     input_data_uint += last_dim;
3750     output_data += last_dim;
3751   }
3752 }
3753 
3754 inline void LogSoftmax(const SoftmaxParams& params,
3755                        const RuntimeShape& input_shape, const float* input_data,
3756                        const RuntimeShape& output_shape, float* output_data) {
3757   ruy::profiler::ScopeLabel label("LogSoftmax");
3758   const int trailing_dim = input_shape.DimensionsCount() - 1;
3759   const int outer_size =
3760       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3761   const int depth =
3762       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3763 
3764   for (int i = 0; i < outer_size; ++i) {
3765     VectorMap<const float> block_input(input_data + i * depth, depth, 1);
3766     VectorMap<float> block_output(output_data + i * depth, depth, 1);
3767     // Find max element value which we'll use to ensure numerical stability
3768     // taking advantage of the following equality:
3769     // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
3770     const float max = block_input.maxCoeff();
3771     const float log_sum = std::log((block_input.array() - max).exp().sum());
3772     block_output = block_input.array() - max - log_sum;
3773   }
3774 }
3775 
3776 // Backwards compatibility. Less optimized than below version.
3777 inline void LogSoftmax(const SoftmaxParams& params,
3778                        const RuntimeShape& input_shape, const uint8* input_data,
3779                        const RuntimeShape& output_shape, uint8* output_data) {
3780   reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
3781                             output_data);
3782 }
3783 
3784 // Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max), ...))
3785 // as done in tf.nn.log_softmax to prevent underflow and overflow.
3786 // This is in contrast to just log(softmax(x))
3787 //
3788 // To handle quantization, first dequantize the inputs (from doing
3789 // e^(input scale * val) where we ignore the zero point since it cancels
3790 // out during subtraction due to the ln) and do a rescale at the end to int8.
3791 //
3792 // Notably this makes use of float and is intended as the optimized
3793 // form for quantized execution on CPU. For a fully integer version,
3794 // see the reference op.
3795 //
3796 // TODO(tflite): notes for optimization:
3797 // 1) See if e^ is also bottleneck in the reference fully-integer
3798 // version and apply lookup there and compare.
3799 template <typename T>
3800 inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
3801                        const RuntimeShape& input_shape, const T* input_data,
3802                        const RuntimeShape& output_shape, T* output_data) {
3803   ruy::profiler::ScopeLabel label("LogSoftmax");
3804   const int trailing_dim = input_shape.DimensionsCount() - 1;
3805   const int excluding_last_dim =
3806       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3807   const int last_dim =
3808       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3809 
3810   const int32_t clamp_max = std::numeric_limits<T>::max();
3811   const int32_t clamp_min = std::numeric_limits<T>::min();
3812 
3813   for (int i = 0; i < excluding_last_dim; ++i) {
3814     T max_val = std::numeric_limits<T>::min();
3815     // Find max quantized value.
3816     for (int j = 0; j < last_dim; ++j) {
3817       max_val = std::max(max_val, input_data[j]);
3818     }
3819 
3820     float sum_exp = 0.0f;
3821     const int32_t max_uint8 = std::numeric_limits<uint8>::max();
3822     // Offset into table to compute exp(scale*(x - xmax)) instead of
3823     // exp(scale*(x)) to prevent overflow.
3824     const float* table_offset = &params.table[max_uint8 - max_val];
3825     // Calculate sum(exp(scale*(x - x_max))).
3826     for (int j = 0; j < last_dim; ++j) {
3827       sum_exp += table_offset[input_data[j]];
3828     }
3829     const float log_sum_exp = std::log(sum_exp);
3830 
3831     // params.scale is the output scale.
3832     const float scale = input_scale / params.scale;
3833     const float precomputed =
3834         (input_scale * max_val + log_sum_exp) / params.scale;
3835     for (int j = 0; j < last_dim; ++j) {
3836       // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) /
3837       // output_scale.
3838       const float log_prob = scale * input_data[j] - precomputed;
3839 
3840       // TODO(tflite): look into better solution.
3841       // Use std::rint over std::round (which is used in
3842       // FakeQuant) since it's multiple times faster on tested arm32.
3843       const int32_t prob_quantized = std::rint(log_prob) + params.zero_point;
3844       output_data[j] = static_cast<T>(
3845           std::max(std::min(clamp_max, prob_quantized), clamp_min));
3846     }
3847     input_data += last_dim;
3848     output_data += last_dim;
3849   }
3850 }
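
// Worked algebra for the requantization above:
//   log_softmax(x_j) = input_scale * (q_j - q_max) - log_sum_exp
// and dividing by the output scale gives
//   (input_scale / output_scale) * q_j
//       - (input_scale * q_max + log_sum_exp) / output_scale,
// which is exactly `scale * input_data[j] - precomputed` in the loop above.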
3851 
3852 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
3853                      const RuntimeShape& output_shape, float* output_data) {
3854   ruy::profiler::ScopeLabel label("Logistic");
3855   auto input_map = MapAsVector(input_data, input_shape);
3856   auto output_map = MapAsVector(output_data, output_shape);
3857   output_map.array() =
3858       input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
3859 }
3860 
3861 // Convenience version that allows, for example, generated-code calls to be
3862 // uniform between data types.
3863 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
3864                      const float* input_data, const RuntimeShape& output_shape,
3865                      float* output_data) {
3866   // Drop params: not needed.
3867   Logistic(input_shape, input_data, output_shape, output_data);
3868 }
3869 
3870 inline void Logistic(const LogisticParams& params,
3871                      const RuntimeShape& input_shape, const int16* input_data,
3872                      const RuntimeShape& output_shape, int16* output_data) {
3873   ruy::profiler::ScopeLabel label("Logistic/Int16");
3874   const int flat_size = MatchingFlatSize(input_shape, output_shape);
3875 
3876   for (int i = 0; i < flat_size; i++) {
3877   }
3878 
3879   int c = 0;
3880   const int16* input_data_ptr = input_data;
3881   int16* output_data_ptr = output_data;
3882 #ifdef GEMMLOWP_NEON
3883   {
3884     // F0 uses 0 integer bits, range [-1, 1].
3885     // This is the return type of math functions such as tanh, logistic,
3886     // whose range is in [-1, 1].
3887     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3888     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
3889     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3890 
3891     for (; c <= flat_size - 16; c += 16) {
3892       F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
3893       F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
3894       F0 output0 = gemmlowp::logistic(input0);
3895       F0 output1 = gemmlowp::logistic(input1);
3896       vst1q_s16(output_data_ptr, output0.raw());
3897       vst1q_s16(output_data_ptr + 8, output1.raw());
3898 
3899       input_data_ptr += 16;
3900       output_data_ptr += 16;
3901     }
3902     for (; c <= flat_size - 8; c += 8) {
3903       F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
3904       F0 output = gemmlowp::logistic(input);
3905       vst1q_s16(output_data_ptr, output.raw());
3906 
3907       input_data_ptr += 8;
3908       output_data_ptr += 8;
3909     }
3910   }
3911 #endif
3912 #ifdef GEMMLOWP_SSE4
3913   {
3914     // F0 uses 0 integer bits, range [-1, 1].
3915     // This is the return type of math functions such as tanh, logistic,
3916     // whose range is in [-1, 1].
3917     using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
3918     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
3919     using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
3920 
3921     for (; c <= flat_size - 16; c += 16) {
3922       F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
3923           _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
3924       F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
3925           reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
3926       F0 output0 = gemmlowp::logistic(input0);
3927       F0 output1 = gemmlowp::logistic(input1);
3928       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
3929                        output0.raw().v);
3930       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
3931                        output1.raw().v);
3932       input_data_ptr += 16;
3933       output_data_ptr += 16;
3934     }
3935     for (; c <= flat_size - 8; c += 8) {
3936       F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
3937           _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
3938       F0 output = gemmlowp::logistic(input);
3939       _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
3940                        output.raw().v);
3941       input_data_ptr += 8;
3942       output_data_ptr += 8;
3943     }
3944   }
3945 #endif
3946 
3947   {
3948     // F0 uses 0 integer bits, range [-1, 1].
3949     // This is the return type of math functions such as tanh, logistic,
3950     // whose range is in [-1, 1].
3951     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3952     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
3953     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3954 
3955     for (; c < flat_size; ++c) {
3956       F3 input = F3::FromRaw(*input_data_ptr);
3957       F0 output = gemmlowp::logistic(input);
3958       *output_data_ptr = output.raw();
3959 
3960       ++input_data_ptr;
3961       ++output_data_ptr;
3962     }
3963   }
3964 }
3965 
3966 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
3967                  const RuntimeShape& output_shape, float* output_data) {
3968   ruy::profiler::ScopeLabel label("Tanh");
3969   auto input_map = MapAsVector(input_data, input_shape);
3970   auto output_map = MapAsVector(output_data, output_shape);
3971   output_map.array() = input_map.array().tanh();
3972 }
3973 
3974 // Convenience version that allows, for example, generated-code calls to be
3975 // uniform between data types.
3976 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
3977                  const float* input_data, const RuntimeShape& output_shape,
3978                  float* output_data) {
3979   // Drop params: not needed.
3980   Tanh(input_shape, input_data, output_shape, output_data);
3981 }
3982 
3983 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
3984                  const int16* input_data, const RuntimeShape& output_shape,
3985                  int16* output_data) {
3986   ruy::profiler::ScopeLabel label("Tanh/Int16");
3987   const int input_left_shift = params.input_left_shift;
3988   // Support for shifts is limited until we have a parameterized version of
3989   // SaturatingRoundingMultiplyByPOT().
3990   TFLITE_DCHECK_GE(input_left_shift, 0);
3991   TFLITE_DCHECK_LE(input_left_shift, 1);
3992 
3993   const int flat_size = MatchingFlatSize(input_shape, output_shape);
3994 
3995   int c = 0;
3996   const int16* input_data_ptr = input_data;
3997   int16* output_data_ptr = output_data;
3998 #ifdef GEMMLOWP_NEON
3999   {
4000     // F0 uses 0 integer bits, range [-1, 1].
4001     // This is the return type of math functions such as tanh, logistic,
4002     // whose range is in [-1, 1].
4003     using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4004     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4005     using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4006 
4007     if (input_left_shift == 0) {
4008       for (; c <= flat_size - 16; c += 16) {
4009         F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4010         F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4011         F0 output0 = gemmlowp::tanh(input0);
4012         F0 output1 = gemmlowp::tanh(input1);
4013         vst1q_s16(output_data_ptr, output0.raw());
4014         vst1q_s16(output_data_ptr + 8, output1.raw());
4015 
4016         input_data_ptr += 16;
4017         output_data_ptr += 16;
4018       }
4019       for (; c <= flat_size - 8; c += 8) {
4020         F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4021         F0 output = gemmlowp::tanh(input);
4022         vst1q_s16(output_data_ptr, output.raw());
4023 
4024         input_data_ptr += 8;
4025         output_data_ptr += 8;
4026       }
4027     } else {
4028       for (; c <= flat_size - 16; c += 16) {
4029         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4030             vld1q_s16(input_data_ptr)));
4031         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4032             vld1q_s16(input_data_ptr + 8)));
4033         F0 output0 = gemmlowp::tanh(input0);
4034         F0 output1 = gemmlowp::tanh(input1);
4035         vst1q_s16(output_data_ptr, output0.raw());
4036         vst1q_s16(output_data_ptr + 8, output1.raw());
4037 
4038         input_data_ptr += 16;
4039         output_data_ptr += 16;
4040       }
4041       for (; c <= flat_size - 8; c += 8) {
4042         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4043             vld1q_s16(input_data_ptr)));
4044         F0 output = gemmlowp::tanh(input);
4045         vst1q_s16(output_data_ptr, output.raw());
4046 
4047         input_data_ptr += 8;
4048         output_data_ptr += 8;
4049       }
4050     }
4051   }
4052 #endif
4053 #ifdef GEMMLOWP_SSE4
4054   {
4055     // F0 uses 0 integer bits, range [-1, 1].
4056     // This is the return type of math functions such as tanh, logistic,
4057     // whose range is in [-1, 1].
4058     using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4059     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4060     using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4061 
4062     if (input_left_shift == 0) {
4063       for (; c <= flat_size - 16; c += 16) {
4064         F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4065             _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4066         F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4067             reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4068         F0 output0 = gemmlowp::tanh(input0);
4069         F0 output1 = gemmlowp::tanh(input1);
4070         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4071                          output0.raw().v);
4072         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4073                          output1.raw().v);
4074 
4075         input_data_ptr += 16;
4076         output_data_ptr += 16;
4077       }
4078       for (; c <= flat_size - 8; c += 8) {
4079         F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4080             _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4081         F0 output = gemmlowp::tanh(input);
4082         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4083                          output.raw().v);
4084         input_data_ptr += 8;
4085         output_data_ptr += 8;
4086       }
4087     } else {
4088       for (; c <= flat_size - 16; c += 16) {
4089         F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4090             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4091                 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4092         F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4093             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4094                 reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
4095         F0 output0 = gemmlowp::tanh(input0);
4096         F0 output1 = gemmlowp::tanh(input1);
4097         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4098                          output0.raw().v);
4099         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4100                          output1.raw().v);
4101 
4102         input_data_ptr += 16;
4103         output_data_ptr += 16;
4104       }
4105       for (; c <= flat_size - 8; c += 8) {
4106         F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4107             gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4108                 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4109         F0 output = gemmlowp::tanh(input);
4110         _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4111                          output.raw().v);
4112         input_data_ptr += 8;
4113         output_data_ptr += 8;
4114       }
4115     }
4116   }
4117 #endif
4118 
4119   {
4120     // F0 uses 0 integer bits, range [-1, 1].
4121     // This is the return type of math functions such as tanh, logistic,
4122     // whose range is in [-1, 1].
4123     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4124     // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4125     using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4126 
4127     if (input_left_shift == 0) {
4128       for (; c < flat_size; ++c) {
4129         F3 input = F3::FromRaw(*input_data_ptr);
4130         F0 output = gemmlowp::tanh(input);
4131         *output_data_ptr = output.raw();
4132 
4133         ++input_data_ptr;
4134         ++output_data_ptr;
4135       }
4136     } else {
4137       for (; c < flat_size; ++c) {
4138         F3 input = F3::FromRaw(
4139             gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
4140         F0 output = gemmlowp::tanh(input);
4141         *output_data_ptr = output.raw();
4142 
4143         ++input_data_ptr;
4144         ++output_data_ptr;
4145       }
4146     }
4147   }
4148 }
4149 
4150 template <typename SrcT, typename DstT>
4151 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
4152                  const RuntimeShape& output_shape, DstT* output_data) {
4153   ruy::profiler::ScopeLabel label("Cast");
4154   auto input_map = MapAsVector(input_data, input_shape);
4155   auto output_map = MapAsVector(output_data, output_shape);
4156   output_map.array() = input_map.array().template cast<DstT>();
4157 }
4158 
4159 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4160                   const RuntimeShape& output_shape, float* output_data) {
4161   ruy::profiler::ScopeLabel label("Floor");
4162   auto input_map = MapAsVector(input_data, input_shape);
4163   auto output_map = MapAsVector(output_data, output_shape);
4164   output_map.array() = Eigen::floor(input_map.array());
4165 }
4166 
4167 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
4168                  const RuntimeShape& output_shape, float* output_data) {
4169   ruy::profiler::ScopeLabel label("Ceil");
4170   auto input_map = MapAsVector(input_data, input_shape);
4171   auto output_map = MapAsVector(output_data, output_shape);
4172   output_map.array() = Eigen::ceil(input_map.array());
4173 }
4174 
4175 // Helper methods for BatchToSpaceND.
4176 // `spatial_index_dim` specifies the post-crop offset index in this spatial
4177 // dimension, i.e. the spatial offset introduced by flattening the batch into the
4178 // spatial dimension, minus the crop size at the beginning. `block_shape_dim` is
4179 // the block size in the current dimension. `input_dim` and `output_dim` are the
4180 // input and output sizes of the BatchToSpaceND operation in the current dimension.
4181 // Output start index is inclusive and end index is exclusive.
4182 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
4183                           int input_dim, int output_dim, int* start_index,
4184                           int* end_index) {
4185   // (*start_index) * block_shape_dim is effectively rounded up to the next
4186   // multiple of block_shape_dim by the integer division.
4187   *start_index =
4188       std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4189   // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
4190   // end_index is exclusive).
4191   *end_index = std::min(
4192       input_dim,
4193       (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4194 }
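
// Worked example: GetIndexRange(/*spatial_index_dim=*/1, /*block_shape_dim=*/2,
// /*input_dim=*/4, /*output_dim=*/6, ...) yields *start_index = 0 and
// *end_index = 3, since in = 0, 1, 2 map to out = 1, 3, 5 (all inside [0, 6))
// while in = 3 would map to out = 7, which is out of range.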
4195 
4196 template <typename T>
4197 inline void BatchToSpaceND(
4198     const RuntimeShape& unextended_input1_shape, const T* input1_data,
4199     const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
4200     const RuntimeShape& unextended_input3_shape, const int32* crops_data,
4201     const RuntimeShape& unextended_output_shape, T* output_data) {
4202   ruy::profiler::ScopeLabel label("BatchToSpaceND");
4203 
4204   TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
4205   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
4206   TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
4207                    unextended_output_shape.DimensionsCount());
4208 
4209   // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
4210   auto extend_shape = [](const RuntimeShape& shape) {
4211     if (shape.DimensionsCount() == 4) {
4212       return shape;
4213     }
4214     RuntimeShape new_shape(4, 1);
4215     new_shape.SetDim(0, shape.Dims(0));
4216     new_shape.SetDim(1, shape.Dims(1));
4217     new_shape.SetDim(3, shape.Dims(2));
4218     return new_shape;
4219   };
4220   const RuntimeShape input1_shape = extend_shape(unextended_input1_shape);
4221   const RuntimeShape output_shape = extend_shape(unextended_output_shape);
4222 
4223   const int output_width = output_shape.Dims(2);
4224   const int output_height = output_shape.Dims(1);
4225   const int output_batch_size = output_shape.Dims(0);
4226 
4227   const int depth = input1_shape.Dims(3);
4228   const int input_width = input1_shape.Dims(2);
4229   const int input_height = input1_shape.Dims(1);
4230   const int input_batch_size = input1_shape.Dims(0);
4231 
4232   const int block_shape_height = block_shape_data[0];
4233   const int block_shape_width =
4234       unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
4235   const int crops_top = crops_data[0];
4236   const int crops_left =
4237       unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
4238 
4239   for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
4240     const int out_batch = in_batch % output_batch_size;
4241     const int spatial_offset = in_batch / output_batch_size;
4242 
4243     int in_h_start = 0;
4244     int in_h_end = 0;
4245     // GetIndexRange clamps the range so every out_h computed below lies in [0, output_height).
4246     GetIndexRange(spatial_offset / block_shape_width - crops_top,
4247                   block_shape_height, input_height, output_height, &in_h_start,
4248                   &in_h_end);
4249 
4250     for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
4251       const int out_h = in_h * block_shape_height +
4252                         spatial_offset / block_shape_width - crops_top;
4253       TFLITE_DCHECK_GE(out_h, 0);
4254       TFLITE_DCHECK_LT(out_h, output_height);
4255 
4256       int in_w_start = 0;
4257       int in_w_end = 0;
4258       // GetIndexRange clamps the range so every out_w computed below lies in [0, output_width).
4259       GetIndexRange(spatial_offset % block_shape_width - crops_left,
4260                     block_shape_width, input_width, output_width, &in_w_start,
4261                     &in_w_end);
4262 
4263       for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
4264         const int out_w = in_w * block_shape_width +
4265                           spatial_offset % block_shape_width - crops_left;
4266         TFLITE_DCHECK_GE(out_w, 0);
4267         TFLITE_DCHECK_LT(out_w, output_width);
4268         T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
4269         const T* in =
4270             input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
4271         memcpy(out, in, depth * sizeof(T));
4272       }
4273     }
4274   }
4275 }
4276 
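// Illustrative sketch (hypothetical sizes, not part of the upstream
// implementation) of how the loops above map input batches back to output
// pixels, assuming input_batch_size = 4, output_batch_size = 1, a 2x2 block
// shape and no crops:
//
//   // in_batch = 3 decomposes into
//   //   out_batch      = 3 % 1 = 0
//   //   spatial_offset = 3 / 1 = 3
//   // which contributes row offset 3 / 2 = 1 and column offset 3 % 2 = 1,
//   // so input pixel (in_h, in_w) is copied to output pixel
//   //   (in_h * 2 + 1, in_w * 2 + 1).
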
4277 template <typename T>
4278 void TypedMemset(void* ptr, T value, size_t num) {
4279   // Optimization for common cases where memset() will suffice.
4280   if (value == 0 || std::is_same<T, uint8_t>::value) {
4281     memset(ptr, value, num * sizeof(T));
4282   } else {
4283     // Default implementation for cases where memset() will not preserve the
4284     // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
4285     char* pos = static_cast<char*>(ptr);
4286     for (size_t i = 0; i < num; ++i) {
4287       memcpy(pos, &value, sizeof(T));
4288       pos = pos + sizeof(T);
4289     }
4290   }
4291 }
4292 
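// Illustrative sketch (hypothetical value, not part of the upstream
// implementation) of why the generic branch above is needed:
//
//   float buf[4];
//   // memset(buf, 1, sizeof(buf)) would fill every *byte* with 0x01, giving
//   // four floats with bit pattern 0x01010101 (~2.4e-38), not 1.0f.
//   TypedMemset<float>(buf, 1.0f, 4);  // Copies the 4-byte pattern of 1.0f
//                                      // element by element via memcpy.
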
4293 // This makes heavy use of Offset, along with conditional branches. There may be
4294 // opportunities for improvement.
4295 //
4296 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
4297 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
4298 // equivalent to a simple input1_data.  For Pad, it should point to a zero
4299 // value.
4300 //
4301 // Note that two typenames are required, so that T=P=int32 is considered a
4302 // specialization distinct from P=int32.
4303 template <typename T, typename P>
4304 inline void PadImpl(const tflite::PadParams& op_params,
4305                     const RuntimeShape& input_shape, const T* input_data,
4306                     const P* pad_value_ptr, const RuntimeShape& output_shape,
4307                     T* output_data) {
4308   ruy::profiler::ScopeLabel label("PadImpl");
4309   const int max_supported_dims = 5;
4310   const RuntimeShape ext_input_shape =
4311       RuntimeShape::ExtendedShape(max_supported_dims, input_shape);
4312   const RuntimeShape ext_output_shape =
4313       RuntimeShape::ExtendedShape(max_supported_dims, output_shape);
4314   TFLITE_DCHECK_LE(op_params.left_padding_count, max_supported_dims);
4315   TFLITE_DCHECK_LE(op_params.right_padding_count, max_supported_dims);
4316 
4317   // Pad kernels are limited to max 5 dimensions. Copy inputs so we can pad them
4318   // to 5 dims (yes, we are "padding the padding").
4319   std::vector<int> left_padding_copy(max_supported_dims, 0);
4320   const int left_padding_extend =
4321       max_supported_dims - op_params.left_padding_count;
4322   for (int i = 0; i < op_params.left_padding_count; ++i) {
4323     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4324   }
4325   std::vector<int> right_padding_copy(max_supported_dims, 0);
4326   const int right_padding_extend =
4327       max_supported_dims - op_params.right_padding_count;
4328   for (int i = 0; i < op_params.right_padding_count; ++i) {
4329     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4330   }
4331 
4332   const int output_batch = ext_output_shape.Dims(0);
4333   const int output_spatial_dim1 = ext_output_shape.Dims(1);
4334   const int output_spatial_dim2 = ext_output_shape.Dims(2);
4335   const int output_spatial_dim3 = ext_output_shape.Dims(3);
4336   const int output_channel = ext_output_shape.Dims(4);
4337 
4338   const int left_b_padding = left_padding_copy[0];
4339   const int left_s1_padding = left_padding_copy[1];
4340   const int left_s2_padding = left_padding_copy[2];
4341   const int left_s3_padding = left_padding_copy[3];
4342   const int left_c_padding = left_padding_copy[4];
4343 
4344   const int right_b_padding = right_padding_copy[0];
4345   const int right_s1_padding = right_padding_copy[1];
4346   const int right_s2_padding = right_padding_copy[2];
4347   const int right_s3_padding = right_padding_copy[3];
4348   const int right_c_padding = right_padding_copy[4];
4349 
4350   const int input_depth = ext_input_shape.Dims(4);
4351   const T pad_value = *pad_value_ptr;
4352 
4353   if (left_b_padding != 0) {
4354     TypedMemset<T>(output_data, pad_value,
4355                    left_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4356                        output_spatial_dim3 * output_channel);
4357   }
4358   for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
4359        ++out_b) {
4360     if (left_s1_padding != 0) {
4361       TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0, 0),
4362                      pad_value,
4363                      left_s1_padding * output_spatial_dim2 *
4364                          output_spatial_dim3 * output_channel);
4365     }
4366     for (int out_p = left_s1_padding;
4367          out_p < output_spatial_dim1 - right_s1_padding; ++out_p) {
4368       if (left_s2_padding != 0) {
4369         TypedMemset<T>(
4370             output_data + Offset(ext_output_shape, out_b, out_p, 0, 0, 0),
4371             pad_value, left_s2_padding * output_spatial_dim3 * output_channel);
4372       }
4373       for (int out_h = left_s2_padding;
4374            out_h < output_spatial_dim2 - right_s2_padding; ++out_h) {
4375         if (left_s3_padding != 0) {
4376           TypedMemset<T>(
4377               output_data + Offset(ext_output_shape, out_b, out_p, out_h, 0, 0),
4378               pad_value, left_s3_padding * output_channel);
4379         }
4380         for (int out_w = left_s3_padding;
4381              out_w < output_spatial_dim3 - right_s3_padding; ++out_w) {
4382           if (left_c_padding != 0) {
4383             TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_p,
4384                                                 out_h, out_w, 0),
4385                            pad_value, left_c_padding);
4386           }
4387 
4388           T* out = output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4389                                         out_w, left_c_padding);
4390           const T* in = input_data +
4391                         Offset(ext_input_shape, out_b - left_b_padding,
4392                                out_p - left_s1_padding, out_h - left_s2_padding,
4393                                out_w - left_s3_padding, 0);
4394           memcpy(out, in, input_depth * sizeof(T));
4395 
4396           if (right_c_padding != 0) {
4397             TypedMemset<T>(
4398                 output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4399                                      out_w, output_channel - right_c_padding),
4400                 pad_value, right_c_padding);
4401           }
4402         }
4403         if (right_s3_padding != 0) {
4404           TypedMemset<T>(
4405               output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4406                                    output_spatial_dim3 - right_s3_padding, 0),
4407               pad_value, right_s3_padding * output_channel);
4408         }
4409       }
4410       if (right_s2_padding != 0) {
4411         TypedMemset<T>(
4412             output_data + Offset(ext_output_shape, out_b, out_p,
4413                                  output_spatial_dim2 - right_s2_padding, 0, 0),
4414             pad_value, right_s2_padding * output_spatial_dim3 * output_channel);
4415       }
4416     }
4417     if (right_s1_padding != 0) {
4418       TypedMemset<T>(
4419           output_data + Offset(ext_output_shape, out_b,
4420                                output_spatial_dim1 - right_s1_padding, 0, 0, 0),
4421           pad_value,
4422           right_s1_padding * output_spatial_dim2 * output_spatial_dim3 *
4423               output_channel);
4424     }
4425   }
4426   if (right_b_padding != 0) {
4427     TypedMemset<T>(
4428         output_data + Offset(ext_output_shape, output_batch - right_b_padding,
4429                              0, 0, 0, 0),
4430         pad_value,
4431         right_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4432             output_spatial_dim3 * output_channel);
4433   }
4434 }
4435 
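// Illustrative usage sketch (hypothetical tensors, not part of the upstream
// implementation): padding an NHWC float tensor by one pixel on each spatial
// edge with zeros, where input_shape/input_data and output_shape/output_data
// are assumed consistent (output dims = input dims + left + right padding).
//
//   tflite::PadParams op_params;
//   op_params.left_padding_count = 4;
//   op_params.right_padding_count = 4;
//   const int32 paddings[4] = {0, 1, 1, 0};  // {batch, height, width, channel}
//   for (int i = 0; i < 4; ++i) {
//     op_params.left_padding[i] = paddings[i];
//     op_params.right_padding[i] = paddings[i];
//   }
//   const float pad_value = 0.0f;
//   Pad(op_params, input_shape, input_data, &pad_value, output_shape,
//       output_data);
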
4436 template <typename T, typename P>
4437 inline void Pad(const tflite::PadParams& op_params,
4438                 const RuntimeShape& input_shape, const T* input_data,
4439                 const P* pad_value_ptr, const RuntimeShape& output_shape,
4440                 T* output_data) {
4441   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4442           output_data);
4443 }
4444 
4445 // The second (pad-value) input can be int32 when, say, the first is uint8.
4446 template <typename T>
4447 inline void Pad(const tflite::PadParams& op_params,
4448                 const RuntimeShape& input_shape, const T* input_data,
4449                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4450                 T* output_data) {
4451   const T converted_pad_value = static_cast<T>(*pad_value_ptr);
4452   PadImpl(op_params, input_shape, input_data, &converted_pad_value,
4453           output_shape, output_data);
4454 }
4455 
4456 // This version avoids conflicting template matching.
4457 template <>
4458 inline void Pad(const tflite::PadParams& op_params,
4459                 const RuntimeShape& input_shape, const int32* input_data,
4460                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4461                 int32* output_data) {
4462   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4463           output_data);
4464 }
4465 
4466 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
4467 //
4468 // This pad requires that (a) left and right paddings are in the 4D patterns
4469 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
4470 // T is uint8.
4471 //
4472 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
4473 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
4474 // equivalent to a simple input1_data.  For Pad, it should point to a zero
4475 // value.
4476 //
4477 // Note that two typenames are required, so that T=P=int32 is considered a
4478 // specialization distinct from P=int32.
4479 template <typename T, typename P>
4480 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
4481                                 const RuntimeShape& input_shape,
4482                                 const T* input_data, const P* pad_value_ptr,
4483                                 const RuntimeShape& output_shape,
4484                                 T* output_data) {
4485   ruy::profiler::ScopeLabel label("PadImageStyle");
4486   const RuntimeShape ext_input_shape =
4487       RuntimeShape::ExtendedShape(4, input_shape);
4488   const RuntimeShape ext_output_shape =
4489       RuntimeShape::ExtendedShape(4, output_shape);
4490   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
4491   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
4492 
4493   // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
4494   // to 4 dims (yes, we are "padding the padding").
4495   std::vector<int> left_padding_copy(4, 0);
4496   const int left_padding_extend = 4 - op_params.left_padding_count;
4497   for (int i = 0; i < op_params.left_padding_count; ++i) {
4498     left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4499   }
4500   std::vector<int> right_padding_copy(4, 0);
4501   const int right_padding_extend = 4 - op_params.right_padding_count;
4502   for (int i = 0; i < op_params.right_padding_count; ++i) {
4503     right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4504   }
4505   // The following padding restrictions are contractual requirements, and
4506   // embody what it means for a padding op to be "image-style".
4507   TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
4508   TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
4509   TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
4510   TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
4511 
4512   const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
4513   const int output_height = ext_output_shape.Dims(1);
4514   const int output_width = ext_output_shape.Dims(2);
4515   const int input_height = ext_input_shape.Dims(1);
4516   const int input_width = ext_input_shape.Dims(2);
4517   const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
4518 
4519   const int left_h_padding = left_padding_copy[1];
4520   const int left_w_padding = left_padding_copy[2];
4521   const int right_h_padding = right_padding_copy[1];
4522   const int right_w_padding = right_padding_copy[2];
4523 
4524   TFLITE_DCHECK_EQ(output_height,
4525                    input_height + left_h_padding + right_h_padding);
4526   TFLITE_DCHECK_EQ(output_width,
4527                    input_width + left_w_padding + right_w_padding);
4528 
4529   const T pad_value = *pad_value_ptr;
4530   const int top_block_size = left_h_padding * output_width * depth;
4531   const size_t num_top_block_bytes = top_block_size * sizeof(T);
4532   const int bottom_block_size = right_h_padding * output_width * depth;
4533   const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
4534   const int left_blocks_size = left_w_padding * depth;
4535   const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
4536   const int right_blocks_size = right_w_padding * depth;
4537   const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
4538   const int inner_line_size = input_width * depth;
4539   const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
4540 
4541   if (input_height == 0) {
4542     memset(output_data, pad_value,
4543            num_top_block_bytes + num_bottom_block_bytes);
4544   } else {
4545     for (int i = 0; i < batch; ++i) {
4546       // For each image in the batch, apply the top padding, then iterate
4547       // through rows, then apply the bottom padding.
4548       //
4549       // By unwinding one iteration, we can combine the first left-margin
4550       // padding with the top padding, and the last right-margin padding with
4551       // the bottom padding.
4552       memset(output_data, pad_value,
4553              num_top_block_bytes + num_left_block_bytes);
4554       output_data += top_block_size + left_blocks_size;
4555       memcpy(output_data, input_data, num_inner_line_bytes);
4556       input_data += inner_line_size;
4557       output_data += inner_line_size;
4558       // One iteration unwound.
4559       // Unwinding this loop affords the opportunity to reorder the loop work
4560       // and hence combine memset() calls.
4561       //
4562       // Before unwinding:
4563       // for (int j = 0; j < input_height; ++j) {
4564       //   // Pad on left, copy central data, pad on right.
4565       //   memset(output_data, pad_value, num_left_block_bytes);
4566       //   output_data += left_blocks_size;
4567       //   memcpy(output_data, input_data, num_inner_line_bytes);
4568       //   input_data += inner_line_size;
4569       //   output_data += inner_line_size;
4570       //   memset(output_data, pad_value, num_right_block_bytes);
4571       //   output_data += right_blocks_size;
4572       // }
4573       for (int j = 1; j < input_height; ++j) {
4574         memset(output_data, pad_value,
4575                num_right_block_bytes + num_left_block_bytes);
4576         output_data += right_blocks_size + left_blocks_size;
4577         memcpy(output_data, input_data, num_inner_line_bytes);
4578         input_data += inner_line_size;
4579         output_data += inner_line_size;
4580       }
4581       memset(output_data, pad_value,
4582              num_right_block_bytes + num_bottom_block_bytes);
4583       output_data += right_blocks_size + bottom_block_size;
4584     }
4585   }
4586 }
4587 
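// Worked example (hypothetical sizes, not part of the upstream implementation)
// of the block sizes computed above: uint8 input of 2x3 (HxW), depth 1,
// left/right padding {0, 1, 2, 0}.
//
//   output_width      = 3 + 2 + 2 = 7
//   top_block_size    = 1 * 7 * 1 = 7 elements  (one full padded row)
//   left_blocks_size  = 2 * 1     = 2 elements  (left margin of one row)
//   inner_line_size   = 3 * 1     = 3 elements  (one copied input row)
//   right_blocks_size = 2 * 1     = 2 elements
//   bottom_block_size = 1 * 7 * 1 = 7 elements
// The unwound loop then issues memsets of 7+2, 2+2 and 2+7 elements around the
// two memcpy'd input rows instead of three separate memsets per row.
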
4588 template <typename T, typename P>
4589 inline void PadImageStyle(const tflite::PadParams& op_params,
4590                           const RuntimeShape& input_shape, const T* input_data,
4591                           const P* pad_value_ptr,
4592                           const RuntimeShape& output_shape, T* output_data) {
4593   reference_ops::PadImageStyle(op_params, input_shape, input_data,
4594                                pad_value_ptr, output_shape, output_data);
4595 }
4596 
4597 template <typename P>
4598 inline void PadImageStyle(const tflite::PadParams& op_params,
4599                           const RuntimeShape& input_shape,
4600                           const uint8* input_data, const P* pad_value_ptr,
4601                           const RuntimeShape& output_shape,
4602                           uint8* output_data) {
4603   PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4604                       output_shape, output_data);
4605 }
4606 
4607 template <typename P>
4608 inline void PadImageStyle(const tflite::PadParams& op_params,
4609                           const RuntimeShape& input_shape,
4610                           const float* input_data, const P* pad_value_ptr,
4611                           const RuntimeShape& output_shape,
4612                           float* output_data) {
4613   const float converted_pad_value = static_cast<float>(*pad_value_ptr);
4614   if (converted_pad_value == 0.0f) {
4615     PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4616                         output_shape, output_data);
4617   } else {
4618     PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4619             output_data);
4620   }
4621 }
4622 
4623 template <typename T>
4624 inline void Slice(const tflite::SliceParams& op_params,
4625                   const RuntimeShape& input_shape,
4626                   const RuntimeShape& output_shape,
4627                   SequentialTensorWriter<T>* writer) {
4628   ruy::profiler::ScopeLabel label("Slice");
4629   const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
4630   TFLITE_DCHECK_LE(op_params.begin_count, 5);
4631   TFLITE_DCHECK_LE(op_params.size_count, 5);
4632   const int begin_count = op_params.begin_count;
4633   const int size_count = op_params.size_count;
4634   // We front-pad the begin and size vectors.
4635   std::array<int, 5> start;
4636   std::array<int, 5> stop;
4637   for (int i = 0; i < 5; ++i) {
4638     int padded_i = 5 - i;
4639     start[i] =
4640         begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
4641     stop[i] =
4642         (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
4643             ? ext_shape.Dims(i)
4644             : start[i] + op_params.size[size_count - padded_i];
4645   }
4646 
4647   for (int i0 = start[0]; i0 < stop[0]; ++i0) {
4648     for (int i1 = start[1]; i1 < stop[1]; ++i1) {
4649       for (int i2 = start[2]; i2 < stop[2]; ++i2) {
4650         for (int i3 = start[3]; i3 < stop[3]; ++i3) {
4651           const int len = stop[4] - start[4];
4652           if (len > 0)
4653             writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
4654         }
4655       }
4656     }
4657   }
4658 }
4659 
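// Worked example (hypothetical shapes, not part of the upstream
// implementation) of the front-padding above for a 3-D input of shape
// {2, 3, 4} with begin = {0, 1, 0} and size = {2, 2, -1}:
//
//   ext_shape = {1, 1, 2, 3, 4}
//   start     = {0, 0, 0, 1, 0}
//   stop      = {1, 1, 2, 3, 4}   // -1 in `size` means "to the end".
// The innermost loop then copies stop[4] - start[4] = 4 contiguous elements
// per slice via writer->WriteN().
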
4660 template <typename T>
4661 inline void Slice(const tflite::SliceParams& op_params,
4662                   const RuntimeShape& input_shape, const T* input_data,
4663                   const RuntimeShape& output_shape, T* output_data) {
4664   SequentialTensorWriter<T> writer(input_data, output_data);
4665   return Slice(op_params, input_shape, output_shape, &writer);
4666 }
4667 
4668 template <typename T>
4669 inline void Slice(const tflite::SliceParams& op_params,
4670                   const RuntimeShape& input_shape, const TfLiteTensor* input,
4671                   const RuntimeShape& output_shape, TfLiteTensor* output) {
4672   SequentialTensorWriter<T> writer(input, output);
4673   return Slice(op_params, input_shape, output_shape, &writer);
4674 }
4675 
4676 // Note: This implementation is only optimized for the case where the inner
4677 // stride == 1.
4678 template <typename T>
4679 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
4680                          const RuntimeShape& unextended_input_shape,
4681                          const RuntimeShape& unextended_output_shape,
4682                          SequentialTensorWriter<T>* writer) {
4683   using strided_slice::LoopCondition;
4684   using strided_slice::StartForAxis;
4685   using strided_slice::StopForAxis;
4686 
4687   ruy::profiler::ScopeLabel label("StridedSlice");
4688 
4689   // Note that the output_shape is not used herein.
4690   tflite::StridedSliceParams params_copy = op_params;
4691 
4692   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
4693   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
4694   const RuntimeShape input_shape =
4695       RuntimeShape::ExtendedShape(5, unextended_input_shape);
4696   const RuntimeShape output_shape =
4697       RuntimeShape::ExtendedShape(5, unextended_output_shape);
4698 
4699   // Reverse and pad to 5 dimensions because that is what the runtime code
4700   // requires (i.e. all shapes must be 5D and are given backwards).
4701   strided_slice::StridedSlicePadIndices(&params_copy, 5);
4702 
4703   const int start_0 = StartForAxis(params_copy, input_shape, 0);
4704   const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
4705   const int start_1 = StartForAxis(params_copy, input_shape, 1);
4706   const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
4707   const int start_2 = StartForAxis(params_copy, input_shape, 2);
4708   const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
4709   const int start_3 = StartForAxis(params_copy, input_shape, 3);
4710   const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
4711   const int start_4 = StartForAxis(params_copy, input_shape, 4);
4712   const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
4713   const bool inner_stride_is_1 = params_copy.strides[4] == 1;
4714 
4715   for (int offset_0 = start_0 * input_shape.Dims(1),
4716            end_0 = stop_0 * input_shape.Dims(1),
4717            step_0 = params_copy.strides[0] * input_shape.Dims(1);
4718        !LoopCondition(offset_0, end_0, params_copy.strides[0]);
4719        offset_0 += step_0) {
4720     for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
4721              end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
4722              step_1 = params_copy.strides[1] * input_shape.Dims(2);
4723          !LoopCondition(offset_1, end_1, params_copy.strides[1]);
4724          offset_1 += step_1) {
4725       for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
4726                end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
4727                step_2 = params_copy.strides[2] * input_shape.Dims(3);
4728            !LoopCondition(offset_2, end_2, params_copy.strides[2]);
4729            offset_2 += step_2) {
4730         for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
4731                  end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
4732                  step_3 = params_copy.strides[3] * input_shape.Dims(4);
4733              !LoopCondition(offset_3, end_3, params_copy.strides[3]);
4734              offset_3 += step_3) {
4735           // When the stride is 1, the inner loop is equivalent to the
4736           // optimized slice inner loop. Otherwise, it is identical to the
4737           // strided_slice reference implementation inner loop.
4738           if (inner_stride_is_1) {
4739             const int len = stop_4 - start_4;
4740             if (len > 0) {
4741               writer->WriteN(offset_3 + start_4, len);
4742             }
4743           } else {
4744             for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
4745                  !LoopCondition(offset_4, end_4, params_copy.strides[4]);
4746                  offset_4 += params_copy.strides[4]) {
4747               writer->Write(offset_4);
4748             }
4749           }
4750         }
4751       }
4752     }
4753   }
4754 }
4755 
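// Sketch (not part of the upstream implementation) of the flattened index
// arithmetic used by the loops above. With i0..i4 denoting the current
// per-dimension indices, the offsets compose as
//
//   offset_0 = i0 * Dims(1)
//   offset_1 = (offset_0 + i1) * Dims(2)
//   offset_2 = (offset_1 + i2) * Dims(3)
//   offset_3 = (offset_2 + i3) * Dims(4)
//   flat     = offset_3 + i4
//
// which equals Offset(input_shape, i0, i1, i2, i3, i4). When the inner stride
// is 1 the innermost loop therefore reduces to a single WriteN of length
// stop_4 - start_4, just like the optimized Slice above.
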
4756 template <typename T>
4757 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
4758                          const RuntimeShape& unextended_input_shape,
4759                          const T* input_data,
4760                          const RuntimeShape& unextended_output_shape,
4761                          T* output_data) {
4762   SequentialTensorWriter<T> writer(input_data, output_data);
4763   StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
4764                   &writer);
4765 }
4766 
4767 template <typename T>
4768 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
4769                          const RuntimeShape& unextended_input_shape,
4770                          const TfLiteTensor* input,
4771                          const RuntimeShape& unextended_output_shape,
4772                          TfLiteTensor* output) {
4773   SequentialTensorWriter<T> writer(input, output);
4774   StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
4775                   &writer);
4776 }
4777 
4778 template <typename T>
4779 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
4780              const T* input2_data, const RuntimeShape& output_shape,
4781              T* output_data) {
4782   ruy::profiler::ScopeLabel label("TensorFlowMinimum");
4783   auto input1_map = MapAsVector(input1_data, input1_shape);
4784   auto output_map = MapAsVector(output_data, output_shape);
4785   auto min_value = input2_data[0];
4786   output_map.array() = input1_map.array().min(min_value);
4787 }
4788 
4789 // Convenience version that allows, for example, generated-code calls to be
4790 // the same as other binary ops.
4791 template <typename T>
4792 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
4793                     const RuntimeShape&, const T* input2_data,
4794                     const RuntimeShape& output_shape, T* output_data) {
4795   // Drop shape of second input: not needed.
4796   Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
4797 }
4798 
4799 template <typename T>
4800 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
4801              const T* input2_data, const RuntimeShape& output_shape,
4802              T* output_data) {
4803   ruy::profiler::ScopeLabel label("TensorFlowMaximum");
4804   auto input1_map = MapAsVector(input1_data, input1_shape);
4805   auto output_map = MapAsVector(output_data, output_shape);
4806   auto max_value = input2_data[0];
4807   output_map.array() = input1_map.array().max(max_value);
4808 }
4809 
4810 // Convenience version that allows, for example, generated-code calls to be
4811 // the same as other binary ops.
4812 template <typename T>
4813 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
4814                     const RuntimeShape&, const T* input2_data,
4815                     const RuntimeShape& output_shape, T* output_data) {
4816   // Drop shape of second input: not needed.
4817   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
4818 }
4819 
4820 template <typename T>
4821 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
4822                      const RuntimeShape& input_shape, const T* input_data,
4823                      const RuntimeShape& filter_shape,
4824                      const RuntimeShape& output_shape, T* im2col_data) {
4825   ruy::profiler::ScopeLabel label("TransposeIm2col");
4826   const int stride_width = params.stride_width;
4827   const int stride_height = params.stride_height;
4828   const int pad_width = params.padding_values.width;
4829   const int pad_height = params.padding_values.height;
4830   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4831   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
4832   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4833   TFLITE_DCHECK(im2col_data);
4834 
4835   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4836   const int input_height = input_shape.Dims(1);
4837   const int input_width = input_shape.Dims(2);
4838   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
4839   const int filter_height = filter_shape.Dims(1);
4840   const int filter_width = filter_shape.Dims(2);
4841   const int output_height = output_shape.Dims(1);
4842   const int output_width = output_shape.Dims(2);
4843   MatchingDim(output_shape, 3, filter_shape, 0);  // output_depth
4844 
4845   // Construct the MxN sized im2col matrix.
4846   // The rows, M, are sub-ordered B x H x W
4847   const RuntimeShape row_shape({1, batches, output_height, output_width});
4848   // The columns, N, are sub-ordered Kh x Kw x Din
4849   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
4850   // Use dimensions M and N to construct dims for indexing directly into im2col
4851   const RuntimeShape im2col_shape(
4852       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
4853 
4854   // Build the im2col matrix by looping through all the input pixels,
4855   // computing their influence on the output, rather than looping through all
4856   // the output pixels. We therefore must initialize the im2col array to zero.
4857   // This is potentially inefficient because we subsequently overwrite bytes
4858   // set here. However, in practice memset is very fast and its cost is negligible.
4859   memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
4860 
4861   // Loop through the output batches
4862   for (int batch = 0; batch < batches; ++batch) {
4863     // Loop through input pixels one at a time.
4864     for (int in_y = 0; in_y < input_height; ++in_y) {
4865       for (int in_x = 0; in_x < input_width; ++in_x) {
4866         // Loop through the output pixels it will influence
4867         const int out_x_origin = (in_x * stride_width) - pad_width;
4868         const int out_y_origin = (in_y * stride_height) - pad_height;
4869         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
4870           const int out_y = out_y_origin + filter_y;
4871           // Is output pixel within height bounds?
4872           if ((out_y >= 0) && (out_y < output_height)) {
4873             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
4874               const int out_x = out_x_origin + filter_x;
4875               // Is output pixel within width bounds?
4876               if ((out_x >= 0) && (out_x < output_width)) {
4877                 // Copy the input elements of this pixel
4878                 T const* src =
4879                     input_data + Offset(input_shape, batch, in_y, in_x, 0);
4880                 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
4881                 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
4882                 T* dst = im2col_data +
4883                          Offset(im2col_shape, 0, 0, row_offset, col_offset);
4884                 memcpy(dst, src, input_depth * sizeof(T));
4885               }
4886             }
4887           }
4888         }
4889       }
4890     }
4891   }
4892 }
4893 
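// Worked example (hypothetical sizes, not part of the upstream implementation)
// of the im2col geometry built above: batches = 1, output 4x4, filter 3x3,
// input_depth = 8.
//
//   row_shape    = {1, 1, 4, 4}     // M = 1 * 4 * 4 = 16 output pixels
//   col_shape    = {1, 3, 3, 8}     // N = 3 * 3 * 8 = 72 patch values
//   im2col_shape = {1, 1, 16, 72}   // one patch row per output pixel
// Each input pixel scatters its 8-deep slice into every (row, col) cell whose
// output pixel it influences; cells it never reaches keep `zero_byte`.
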
4894 // Returns in 'im_data' (assumed to be zero-initialized) the image patch in storage
4895 // order (height, width, depth), constructed from patches in 'col_data', which
4896 // is required to be in storage order (out_height * out_width, filter_height,
4897 // filter_width, in_depth).  Implementation by Yangqing Jia (jiayq).
4898 // Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
4899 template <typename T>
4900 void Col2im(const T* col_data, const int depth, const int height,
4901             const int width, const int filter_h, const int filter_w,
4902             const int pad_t, const int pad_l, const int pad_b, const int pad_r,
4903             const int stride_h, const int stride_w, T* im_data) {
4904   ruy::profiler::ScopeLabel label("Col2im");
4905   int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
4906   int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
4907   int h_pad = -pad_t;
4908   for (int h = 0; h < height_col; ++h) {
4909     int w_pad = -pad_l;
4910     for (int w = 0; w < width_col; ++w) {
4911       T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
4912       for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
4913         for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
4914           if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
4915             // TODO(andydavis) Vectorize this loop (if compiler does not).
4916             for (int i = 0; i < depth; ++i) {
4917               im_patch_data[i] += col_data[i];
4918             }
4919           }
4920           im_patch_data += depth;
4921           col_data += depth;
4922         }
4923         // Advance past the remaining (width - filter_w) columns of this image row.
4924         im_patch_data += depth * (width - filter_w);
4925       }
4926       w_pad += stride_w;
4927     }
4928     h_pad += stride_h;
4929   }
4930 }
4931 
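// Worked example (hypothetical sizes, not part of the upstream implementation)
// of the patch-count arithmetic in Col2im above: height = width = 5,
// filter 3x3, pad_t = pad_b = pad_l = pad_r = 1, stride 2.
//
//   height_col = (5 + 1 + 1 - 3) / 2 + 1 = 3
//   width_col  = (5 + 1 + 1 - 3) / 2 + 1 = 3
// so col_data holds 3 * 3 = 9 patches of 3 * 3 * depth values each, and the
// loops accumulate (+=) every patch into the overlapping region of im_data,
// skipping positions that fall into the padding.
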
4932 // TODO(b/188008864) Optimize this function by combining outer loops.
4933 template <typename T>
4934 void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
4935              const int height, const int width, const int depth) {
4936   if (bias_data) {
4937     for (int n = 0; n < batch_size; ++n) {
4938       for (int h = 0; h < height; ++h) {
4939         for (int w = 0; w < width; ++w) {
4940           for (int d = 0; d < depth; ++d) {
4941             im_data[d] += bias_data[d];
4942           }
4943           im_data += depth;
4944         }
4945       }
4946     }
4947   }
4948 }
4949 
4950 // TransposeConvV2 expects the weights in HWOI order.
4951 inline void TransposeConvV2(
4952     const ConvParams& params, const RuntimeShape& input_shape,
4953     const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
4954     const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
4955     const float* bias_data, const RuntimeShape& output_shape,
4956     float* const output_data, const RuntimeShape& col2im_shape,
4957     float* col2im_data, CpuBackendContext* cpu_backend_context) {
4958   ruy::profiler::ScopeLabel label("TransposeConvV2/float");
4959   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4960   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
4961   TFLITE_DCHECK(col2im_data);
4962   TFLITE_DCHECK(hwoi_ordered_filter_data);
4963 
4964   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
4965   const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
4966   const int output_height = output_shape.Dims(1);
4967   const int output_width = output_shape.Dims(2);
4968   const int output_image_size = output_height * output_width;
4969   const int input_depth =
4970       MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
4971   const int output_depth =
4972       MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
4973   const int input_offset = input_image_size * input_depth;
4974   const int output_offset = output_image_size * output_depth;
4975 
4976   const int filter_height = hwoi_ordered_filter_shape.Dims(0);
4977   const int filter_width = hwoi_ordered_filter_shape.Dims(1);
4978   const int padding_top = params.padding_values.height;
4979   const int padding_bottom =
4980       params.padding_values.height + params.padding_values.height_offset;
4981   const int padding_left = params.padding_values.width;
4982   const int padding_right =
4983       params.padding_values.width + params.padding_values.width_offset;
4984   const int stride_height = params.stride_height;
4985   const int stride_width = params.stride_width;
4986 
4987   const int hwoi_ordered_filter_total_size =
4988       filter_height * filter_width * output_depth;
4989 
4990   cpu_backend_gemm::MatrixParams<float> lhs_params;
4991   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
4992   lhs_params.rows = hwoi_ordered_filter_total_size;
4993   lhs_params.cols = input_depth;
4994   float* output_data_p = output_data;
4995   std::fill_n(output_data, output_offset * batch_size, 0.0f);
4996   for (int i = 0; i < batch_size; ++i) {
4997     cpu_backend_gemm::MatrixParams<float> rhs_params;
4998     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
4999     rhs_params.rows = input_depth;
5000     rhs_params.cols = input_image_size;
5001     cpu_backend_gemm::MatrixParams<float> dst_params;
5002     dst_params.order = cpu_backend_gemm::Order::kColMajor;
5003     dst_params.rows = hwoi_ordered_filter_total_size;
5004     dst_params.cols = input_image_size;
5005     cpu_backend_gemm::GemmParams<float, float> gemm_params;
5006     cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5007                            input_data + input_offset * i, dst_params,
5008                            col2im_data, gemm_params, cpu_backend_context);
5009 
5010     Col2im(col2im_data, output_depth, output_height, output_width,
5011            filter_height, filter_width, padding_top, padding_left,
5012            padding_bottom, padding_right, stride_height, stride_width,
5013            output_data_p);
5014     output_data_p += output_offset;
5015   }
5016   output_data_p = output_data;
5017   BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
5018           output_depth);
5019 }
5020 
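// Sketch (hypothetical sizes, not part of the upstream implementation) of the
// per-batch GEMM issued above, for an 8x8x16 (HxWxC) input and a 3x3 filter in
// HWOI order with 32 output channels:
//
//   lhs: filters as [3*3*32, 16]      (hwoi_ordered_filter_total_size x
//                                      input_depth, row-major)
//   rhs: one image as [16, 8*8]       (input_depth x input_image_size,
//                                      col-major)
//   dst: col2im buffer [3*3*32, 8*8]  (col-major)
// Col2im then scatters/accumulates each of the 64 patch columns into the
// zero-initialized output image, and BiasAdd adds the per-channel bias.
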
5021 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
5022                      int32_t output_zp, int32_t* scratch, uint8_t* output) {
5023   ruy::profiler::ScopeLabel label("Quantize/uint8");
5024   int i = 0;
5025   const int32_t output_min = std::numeric_limits<uint8_t>::min();
5026   const int32_t output_max = std::numeric_limits<uint8_t>::max();
5027 
5028 #ifdef USE_NEON
5029   const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
5030   const int32x4_t max_val_dup = vdupq_n_s32(output_max);
5031   const int32x4_t min_val_dup = vdupq_n_s32(output_min);
5032 
5033   using gemmlowp::RoundingDivideByPOT;
5034   using gemmlowp::SaturatingRoundingDoublingHighMul;
5035 
5036   for (; i <= total_size - 16; i += 16) {
5037     int32x4x4_t scratch_val;
5038     scratch_val.val[0] = vld1q_s32(scratch + i);
5039     scratch_val.val[1] = vld1q_s32(scratch + i + 4);
5040     scratch_val.val[2] = vld1q_s32(scratch + i + 8);
5041     scratch_val.val[3] = vld1q_s32(scratch + i + 12);
5042 
5043     int32x4x4_t temp_val =
5044         MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
5045 
5046     temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
5047     temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
5048     temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
5049     temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
5050 
5051     temp_val.val[0] =
5052         vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
5053     temp_val.val[1] =
5054         vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
5055     temp_val.val[2] =
5056         vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
5057     temp_val.val[3] =
5058         vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
5059 
5060     const uint16x8_t result_1 =
5061         vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[0])),
5062                      vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[1])));
5063     const uint16x8_t result_2 =
5064         vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[2])),
5065                      vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[3])));
5066     const uint8x16_t result =
5067         vcombine_u8(vqmovn_u16(result_1), vqmovn_u16(result_2));
5068     vst1q_u8(output + i, result);
5069   }
5070 #endif
5071   for (; i < total_size; ++i) {
5072     int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
5073     temp += output_zp;
5074     if (temp > output_max) {
5075       temp = output_max;
5076     }
5077     if (temp < output_min) {
5078       temp = output_min;
5079     }
5080     output[i] = static_cast<uint8_t>(temp);
5081   }
5082 }
5083 
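// Illustration (hypothetical values, not part of the upstream implementation)
// of the scalar requantization path above: with a quantized multiplier of
// roughly 0.5 * 2^31 (1073741824) and shift = 0, an accumulator of 100 becomes
//
//   MultiplyByQuantizedMultiplier(100, 1073741824, 0) ~= 50
//   output = clamp(50 + output_zp, 0, 255)
//
// and the NEON block performs the same computation on 16 values per iteration.
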
5084 // Single-rounding MultiplyByQuantizedMultiplier
5085 #if TFLITE_SINGLE_ROUNDING
5086 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5087                      int32_t channel_size, int32_t total_size,
5088                      int32_t output_zp, int32_t output_min, int32_t output_max,
5089                      int32_t* scratch, int8_t* output) {
5090   ruy::profiler::ScopeLabel label("Quantize/int8");
5091 
5092   // Here we're trying to quantize the raw accumulators:
5093   //        output_channels
5094   //       data data data data data
5095   // rows  data data data data data
5096   //       data data data data data
5097   //          ....
5098   //
5099   // To minimize reloads of the multipliers & shifts, we load them once per
5100   // group of channels and then quantize the raw accumulators for every row
5101   // before moving to the next group.
5102 #ifdef USE_NEON
5103   const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5104   const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5105   const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5106   const int32x4_t minus_ones = vdupq_n_s32(-1);
5107 #endif
5108 
5109   TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5110   const int32_t rows = total_size / channel_size;
5111 
5112   int c = 0;
5113 
5114 #ifdef USE_NEON
5115   for (; c <= channel_size - 8; c += 8) {
5116     int32x4_t out_shift_1 = vld1q_s32(shift + c);
5117     int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5118 
5119     int32x4_t right_shift_1 = vminq_s32(out_shift_1, minus_ones);
5120     int32x4_t right_shift_2 = vminq_s32(out_shift_2, minus_ones);
5121 
5122     int32x4_t left_shift_1 = vsubq_s32(out_shift_1, right_shift_1);
5123     int32x4_t left_shift_2 = vsubq_s32(out_shift_2, right_shift_2);
5124 
5125     int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5126     int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5127     for (int n = 0; n < rows; ++n) {
5128       int loc = n * channel_size + c;
5129       int32x4_t acc_1 = vld1q_s32(scratch + loc);
5130       int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5131 
5132       // Saturating Doubling High Mul.
5133       acc_1 = vshlq_s32(acc_1, left_shift_1);
5134       acc_1 = vqdmulhq_s32(acc_1, out_mul_1);
5135       acc_2 = vshlq_s32(acc_2, left_shift_2);
5136       acc_2 = vqdmulhq_s32(acc_2, out_mul_2);
5137 
5138       // Rounding Dividing By POT.
5139       acc_1 = vrshlq_s32(acc_1, right_shift_1);
5140       acc_2 = vrshlq_s32(acc_2, right_shift_2);
5141 
5142       // Add the output offset.
5143       acc_1 = vaddq_s32(acc_1, output_offset_vec);
5144       acc_2 = vaddq_s32(acc_2, output_offset_vec);
5145 
5146       // Apply the activation function.
5147       acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5148       acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5149       acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5150       acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5151 
5152       // Saturating cast to int8 and store to destination.
5153       const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5154       const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5155       const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5156       const int8x8_t res_s8 = vqmovn_s16(res_s16);
5157       vst1_s8(output + loc, res_s8);
5158     }
5159   }
5160 
5161 #endif  // USE_NEON
5162   // Handle leftover values, one by one. This is very slow.
5163   for (; c < channel_size; c++) {
5164     for (int n = 0; n < rows; ++n) {
5165       int loc = n * channel_size + c;
5166       int32 acc = scratch[loc];
5167       acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5168       acc += output_zp;
5169       acc = std::max(acc, output_min);
5170       acc = std::min(acc, output_max);
5171       output[loc] = static_cast<int8>(acc);
5172     }
5173   }
5174 }
5175 
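// Worked example (not part of the upstream implementation) of the NEON shift
// decomposition used in the single-rounding path above, where
// right_shift = min(shift, -1) and left_shift = shift - right_shift:
//
//   shift = +3  ->  right_shift = -1, left_shift = +4
//   shift = -2  ->  right_shift = -2, left_shift =  0
//
// The accumulator is shifted left by left_shift, multiplied with vqdmulhq_s32,
// and then rounding-shifted right via vrshlq_s32 with the negative
// right_shift.
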
5176 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5177                      int32_t channel_size, int32_t total_size,
5178                      int32_t output_zp, int32_t output_min, int32_t output_max,
5179                      int32_t* scratch, int16_t* output) {
5180   ruy::profiler::ScopeLabel label("Quantize(Single-rounding)/int16");
5181 
5182   // Here we're trying to quantize the raw accumulators:
5183   //        output_channels
5184   //       data data data data data
5185   // rows  data data data data data
5186   //       data data data data data
5187   //          ....
5188   //
5189   // To minimize reloads of the multipliers & shifts, we load them once per
5190   // group of channels and then quantize the raw accumulators for every row
5191   // before moving to the next group.
5192 #ifdef USE_NEON
5193   const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5194   const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5195   const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5196   const int32x4_t minus_ones = vdupq_n_s32(-1);
5197 #endif
5198 
5199   TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5200   const int32_t rows = total_size / channel_size;
5201 
5202   int c = 0;
5203 
5204 #ifdef USE_NEON
5205   for (; c <= channel_size - 8; c += 8) {
5206     int32x4_t out_shift_1 = vld1q_s32(shift + c);
5207     int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5208 
5209     int32x4_t right_shift_1 = vminq_s32(out_shift_1, minus_ones);
5210     int32x4_t right_shift_2 = vminq_s32(out_shift_2, minus_ones);
5211 
5212     int32x4_t left_shift_1 = vsubq_s32(out_shift_1, right_shift_1);
5213     int32x4_t left_shift_2 = vsubq_s32(out_shift_2, right_shift_2);
5214 
5215     int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5216     int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5217     for (int n = 0; n < rows; ++n) {
5218       int loc = n * channel_size + c;
5219       int32x4_t acc_1 = vld1q_s32(scratch + loc);
5220       int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5221 
5222       // Saturating Doubling High Mul.
5223       acc_1 = vshlq_s32(acc_1, left_shift_1);
5224       acc_1 = vqdmulhq_s32(acc_1, out_mul_1);
5225       acc_2 = vshlq_s32(acc_2, left_shift_2);
5226       acc_2 = vqdmulhq_s32(acc_2, out_mul_2);
5227 
5228       // Rounding Dividing By POT.
5229       acc_1 = vrshlq_s32(acc_1, right_shift_1);
5230       acc_2 = vrshlq_s32(acc_2, right_shift_2);
5231 
5232       // Add the output offset.
5233       acc_1 = vaddq_s32(acc_1, output_offset_vec);
5234       acc_2 = vaddq_s32(acc_2, output_offset_vec);
5235 
5236       // Apply the activation function.
5237       acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5238       acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5239       acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5240       acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5241 
5242       // Saturating cast to int16 and store to destination.
5243       const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5244       const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5245       vst1_s16(reinterpret_cast<int16_t*>(output) + loc, acc_s16_1);
5246       vst1_s16(reinterpret_cast<int16_t*>(output) + loc + 4, acc_s16_2);
5247     }
5248   }
5249 
5250 #endif  // USE_NEON
5251   // Handle leftover values, one by one. This is very slow.
5252   for (; c < channel_size; c++) {
5253     for (int n = 0; n < rows; ++n) {
5254       int loc = n * channel_size + c;
5255       int32 acc = scratch[loc];
5256       acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5257       acc += output_zp;
5258       acc = std::max(acc, output_min);
5259       acc = std::min(acc, output_max);
5260       output[loc] = static_cast<int16>(acc);
5261     }
5262   }
5263 }
5264 // Double-rounding MultiplyByQuantizedMultiplier
5265 #else
5266 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5267                      int32_t channel_size, int32_t total_size,
5268                      int32_t output_zp, int32_t output_min, int32_t output_max,
5269                      int32_t* scratch, int8_t* output) {
5270   ruy::profiler::ScopeLabel label("Quantize/int8");
5271 
5272   // Here we're trying to quantize the raw accumulators:
5273   //        output_channels
5274   //       data data data data data
5275   // rows  data data data data data
5276   //       data data data data data
5277   //          ....
5278   //
5279   // To minimize reloads of the multipliers & shifts, we load them once per
5280   // group of channels and then quantize the raw accumulators for every row
5281   // before moving to the next group.
5282 #ifdef USE_NEON
5283   const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5284   const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5285   const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5286   const int32x4_t zeros = vdupq_n_s32(0);
5287 #endif
5288 
5289   TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5290   const int32_t rows = total_size / channel_size;
5291 
5292   int c = 0;
5293 
5294 #ifdef USE_NEON
5295   using gemmlowp::RoundingDivideByPOT;
5296   for (; c <= channel_size - 8; c += 8) {
5297     int32x4_t out_shift_1 = vld1q_s32(shift + c);
5298     int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5299     int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5300     int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5301 
5302     // Right shift will be performed as left shift with negative values.
5303     int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5304     int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5305 
5306     int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5307     int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5308     for (int n = 0; n < rows; ++n) {
5309       int loc = n * channel_size + c;
5310       int32x4_t acc_1 = vld1q_s32(scratch + loc);
5311       int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5312 
5313       // Saturating Rounding Doubling High Mul.
5314       acc_1 = vshlq_s32(acc_1, left_shift_1);
5315       acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5316       acc_2 = vshlq_s32(acc_2, left_shift_2);
5317       acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5318 
5319       // Rounding Dividing By POT.
5320       acc_1 = vrshlq_s32(acc_1, right_shift_1);
5321       acc_2 = vrshlq_s32(acc_2, right_shift_2);
5322 
5323       // Add the output offset.
5324       acc_1 = vaddq_s32(acc_1, output_offset_vec);
5325       acc_2 = vaddq_s32(acc_2, output_offset_vec);
5326 
5327       // Apply the activation function.
5328       acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5329       acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5330       acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5331       acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5332 
5333       // Saturating cast to int8 and store to destination.
5334       const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5335       const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5336       const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5337       const int8x8_t res_s8 = vqmovn_s16(res_s16);
5338       vst1_s8(output + loc, res_s8);
5339     }
5340   }
5341 
5342 #endif  // USE_NEON
5343   // Handle leftover values, one by one. This is very slow.
5344   for (; c < channel_size; c++) {
5345     for (int n = 0; n < rows; ++n) {
5346       int loc = n * channel_size + c;
5347       int32 acc = scratch[loc];
5348       acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5349       acc += output_zp;
5350       acc = std::max(acc, output_min);
5351       acc = std::min(acc, output_max);
5352       output[loc] = static_cast<int8>(acc);
5353     }
5354   }
5355 }
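// Illustrative usage of the per-channel Quantize above (a sketch only; the
// buffer names and sizes below are hypothetical and not defined in this file):
//
//   // scratch holds rows * out_channels int32 accumulators laid out
//   // row-major as [rows, out_channels]; output_int8 receives int8 values.
//   Quantize(per_channel_multiplier, per_channel_shift,
//            /*channel_size=*/out_channels,
//            /*total_size=*/rows * out_channels,
//            /*output_zp=*/out_zero_point, /*output_min=*/-128,
//            /*output_max=*/127, scratch, output_int8);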
5356 
5357 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5358                      int32_t channel_size, int32_t total_size,
5359                      int32_t output_zp, int32_t output_min, int32_t output_max,
5360                      int32_t* scratch, int16_t* output) {
5361   ruy::profiler::ScopeLabel label("Quantize(Double-rounding)/int16");
5362 
5363   // Here we're trying to quantize the raw accumulators:
5364   //        output_channels
5365   //       data data data data data
5366   // rows  data data data data data
5367   //       data data data data data
5368   //          ....
5369   //
5370   // To minimize reloading of the multipliers & shifts, we load them once for
5371   // a group of channels and then load & quantize the raw accumulators for
5372   // every row of that group.
5373 #ifdef USE_NEON
5374   const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5375   const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5376   const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5377   const int32x4_t zeros = vdupq_n_s32(0);
5378 #endif
5379 
5380   TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5381   const int32_t rows = total_size / channel_size;
5382 
5383   int c = 0;
5384 
5385 #ifdef USE_NEON
5386   using gemmlowp::RoundingDivideByPOT;
5387   for (; c <= channel_size - 8; c += 8) {
5388     int32x4_t out_shift_1 = vld1q_s32(shift + c);
5389     int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5390     int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5391     int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5392 
5393     // Right shift will be performed as left shift with negative values.
5394     int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5395     int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5396 
5397     int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5398     int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5399     for (int n = 0; n < rows; ++n) {
5400       int loc = n * channel_size + c;
5401       int32x4_t acc_1 = vld1q_s32(scratch + loc);
5402       int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5403 
5404       // Saturating Rounding Doubling High Mul.
5405       acc_1 = vshlq_s32(acc_1, left_shift_1);
5406       acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5407       acc_2 = vshlq_s32(acc_2, left_shift_2);
5408       acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5409 
5410       // Rounding Dividing By POT.
5411       acc_1 = vrshlq_s32(acc_1, right_shift_1);
5412       acc_2 = vrshlq_s32(acc_2, right_shift_2);
5413 
5414       // Add the output offset.
5415       acc_1 = vaddq_s32(acc_1, output_offset_vec);
5416       acc_2 = vaddq_s32(acc_2, output_offset_vec);
5417 
5418       // Apply the activation function.
5419       acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5420       acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5421       acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5422       acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5423 
5424       // Saturating cast to int16 and store to destination.
5425       const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5426       const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5427       vst1_s16(reinterpret_cast<int16_t*>(output) + loc, acc_s16_1);
5428       vst1_s16(reinterpret_cast<int16_t*>(output) + loc + 4, acc_s16_2);
5429     }
5430   }
5431 
5432 #endif  // USE_NEON
5433   // Handle leftover values, one by one. This is very slow.
5434   for (; c < channel_size; c++) {
5435     for (int n = 0; n < rows; ++n) {
5436       int loc = n * channel_size + c;
5437       int32 acc = scratch[loc];
5438       acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5439       acc += output_zp;
5440       acc = std::max(acc, output_min);
5441       acc = std::min(acc, output_max);
5442       output[loc] = static_cast<int16>(acc);
5443     }
5444   }
5445 }
5446 #endif  // TFLITE_SINGLE_ROUNDING
5447 
5448 // TransposeConvV2 expects the weights in HWOI order.
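// Per batch, the implementation below runs a GEMM of the HWOI-ordered filter
// against the input, scatters the GEMM output back to image space with
// Col2im, then adds the bias and requantizes the int32 scratch accumulator to
// uint8 with Quantize.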
5449 inline void TransposeConvV2(
5450     const ConvParams& params, const RuntimeShape& input_shape,
5451     const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5452     const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5453     const int32* bias_data, const RuntimeShape& output_shape,
5454     uint8_t* output_data, const RuntimeShape& col2im_shape,
5455     int32_t* col2im_data, int32_t* scratch_data,
5456     CpuBackendContext* cpu_backend_context) {
5457   ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
5458   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5459   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5460   TFLITE_DCHECK(col2im_data);
5461   TFLITE_DCHECK(hwoi_ordered_filter_data);
5462 
5463   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5464   const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5465   const int output_height = output_shape.Dims(1);
5466   const int output_width = output_shape.Dims(2);
5467   const int output_image_size = output_height * output_width;
5468   const int input_depth =
5469       MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5470   const int output_depth =
5471       MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5472   const int input_offset = input_image_size * input_depth;
5473   const int output_offset = output_image_size * output_depth;
5474 
5475   const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5476   const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5477   const int padding_top = params.padding_values.height;
5478   const int padding_bottom =
5479       params.padding_values.height + params.padding_values.height_offset;
5480   const int padding_left = params.padding_values.width;
5481   const int padding_right =
5482       params.padding_values.width + params.padding_values.width_offset;
5483   const int stride_height = params.stride_height;
5484   const int stride_width = params.stride_width;
5485 
5486   const int hwoi_ordered_filter_total_size =
5487       filter_height * filter_width * output_depth;
5488 
5489   cpu_backend_gemm::MatrixParams<uint8_t> lhs_params;
5490   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5491   lhs_params.rows = hwoi_ordered_filter_total_size;
5492   lhs_params.cols = input_depth;
5493   lhs_params.zero_point = -params.weights_offset;
5494 
5495   int32_t* scratch_data_p = scratch_data;
5496   std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
5497   for (int i = 0; i < batch_size; ++i) {
5498     cpu_backend_gemm::MatrixParams<uint8_t> rhs_params;
5499     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5500     rhs_params.rows = input_depth;
5501     rhs_params.cols = input_image_size;
5502     rhs_params.zero_point = -params.input_offset;
5503 
5504     cpu_backend_gemm::MatrixParams<int32_t> dst_params;
5505     dst_params.order = cpu_backend_gemm::Order::kColMajor;
5506     dst_params.rows = hwoi_ordered_filter_total_size;
5507     dst_params.cols = input_image_size;
5508 
5509     cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
5510     cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5511                            input_data + input_offset * i, dst_params,
5512                            col2im_data, gemm_params, cpu_backend_context);
5513 
5514     Col2im(col2im_data, output_depth, output_height, output_width,
5515            filter_height, filter_width, padding_top, padding_left,
5516            padding_bottom, padding_right, stride_height, stride_width,
5517            scratch_data_p);
5518 
5519     scratch_data_p += output_offset;
5520   }
5521   scratch_data_p = scratch_data;
5522   BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
5523           output_depth);
5524 
5525   Quantize(params.output_multiplier, params.output_shift,
5526            output_shape.FlatSize(), params.output_offset, scratch_data,
5527            output_data);
5528 }
5529 
5530 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
5531 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
5532 // reference version. Debug checks are in place to test if this occurs.
5533 // NOTE: If align_corners or half_pixel_centers is true, we use the reference
5534 // version.
5535 inline void ResizeNearestNeighbor(
5536     const tflite::ResizeNearestNeighborParams& op_params,
5537     const RuntimeShape& unextended_input_shape, const uint8* input_data,
5538     const RuntimeShape& output_size_shape, const int32* output_size_data,
5539     const RuntimeShape& unextended_output_shape, uint8* output_data) {
5540   if (op_params.align_corners || op_params.half_pixel_centers) {
5541     // TODO(b/149823713): Add support for align_corners & half_pixel_centers in
5542     // this kernel.
5543     reference_ops::ResizeNearestNeighbor(
5544         op_params, unextended_input_shape, input_data, output_size_shape,
5545         output_size_data, unextended_output_shape, output_data);
5546     return;
5547   }
5548   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5549   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5550 
5551   const RuntimeShape input_shape =
5552       RuntimeShape::ExtendedShape(4, unextended_input_shape);
5553   const RuntimeShape output_shape =
5554       RuntimeShape::ExtendedShape(4, unextended_output_shape);
5555 
5556   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5557   int32 input_height = input_shape.Dims(1);
5558   int32 input_width = input_shape.Dims(2);
5559   int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5560 
5561   // The TensorFlow version of this op allows resizing on the width and height
5562   // axes only.
5563   TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5564   int32 output_height = output_size_data[0];
5565   int32 output_width = output_size_data[1];
5566 
5567   // Convert scales to 16.16 fixed-point. We add 1 to compensate for the
5568   // truncation of the fixed-point division and to avoid zero scales. For
5569   // example, with input_height = 1 and output_height = 3, the exact scale is
5570   // 1/3, which truncation would otherwise under-represent.
5571   int32 height_scale = (input_height << 16) / output_height + 1;
5572   int32 width_scale = (input_width << 16) / output_width + 1;
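  // Worked example (illustrative): input_height = 2, output_height = 5 gives
  // height_scale = (2 << 16) / 5 + 1 = 26215, so (y * height_scale) >> 16 for
  // y = 0..4 yields 0, 0, 0, 1, 1, matching floor(y * 2 / 5).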
5573 
5574   const int col_offset = input_shape.Dims(3);
5575   const int row_offset = input_shape.Dims(2) * col_offset;
5576   const int batch_offset = input_shape.Dims(1) * row_offset;
5577 
5578   const uint8* input_ptr = input_data;
5579   uint8* output_ptr = output_data;
5580   for (int b = 0; b < batches; ++b) {
5581     for (int y = 0; y < output_height; ++y) {
5582       int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
5583       // Check offset calculation is the same as the reference version. See
5584       // function comment for details. We check using a non-float version of:
5585       // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
5586       //                                            / output_height)));
5587       TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
5588       TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
5589       const uint8* y_input_ptr = input_ptr + in_y * row_offset;
5590       for (int x = 0; x < output_width; ++x) {
5591         int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
5592         // Check offset calculation is the same as the reference version. See
5593         // function comment for details. We check using a non-float version of:
5594         // TFLITE_DCHECK_EQ(in_x,
5595         //                  std::floor(x * (static_cast<float>(input_width)
5596         //                                      / output_width)));
5597         TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
5598         TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
5599         const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
5600         memcpy(output_ptr, x_input_ptr, depth);
5601         output_ptr += depth;
5602       }
5603     }
5604     input_ptr += batch_offset;
5605   }
5606 }
5607 
5608 template <typename input_type, typename output_type>
5609 inline void Requantize(const input_type* input_data, int32_t size,
5610                        int32_t effective_scale_multiplier,
5611                        int32_t effective_scale_shift, int32_t input_zeropoint,
5612                        int32_t output_zeropoint, output_type* output_data) {
5613   reference_ops::Requantize(input_data, size, effective_scale_multiplier,
5614                             effective_scale_shift, input_zeropoint,
5615                             output_zeropoint, output_data);
5616 }
5617 
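// Each Requantize specialization below computes, per element,
//   output = MultiplyByQuantizedMultiplier(input - input_zeropoint,
//                                          effective_scale_multiplier,
//                                          effective_scale_shift) +
//            output_zeropoint,
// clamped to the output type's range. For example (illustrative numbers), an
// effective scale of 0.5 with input_zeropoint = 0 and output_zeropoint = 128
// maps an input of 10 to 5 + 128 = 133 before clamping.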
5618 template <>
5619 inline void Requantize<int8_t, uint8_t>(const int8_t* input_data, int32_t size,
5620                                         int32_t effective_scale_multiplier,
5621                                         int32_t effective_scale_shift,
5622                                         int32_t input_zeropoint,
5623                                         int32_t output_zeropoint,
5624                                         uint8_t* output_data) {
5625   ruy::profiler::ScopeLabel label("Requantize/Int8ToUint8");
5626 
5627   static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5628   static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5629 
5630   int i = 0;
5631 #ifdef USE_NEON
5632   // Constants.
5633   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5634   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5635   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5636   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5637 
5638   for (; i <= size - 16; i += 16) {
5639     const int8x16_t input_vec = vld1q_s8(input_data + i);
5640     const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5641     const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5642     int32x4x4_t input;
5643     input.val[0] = vmovl_s16(vget_low_s16(first_half));
5644     input.val[1] = vmovl_s16(vget_high_s16(first_half));
5645     input.val[2] = vmovl_s16(vget_low_s16(second_half));
5646     input.val[3] = vmovl_s16(vget_high_s16(second_half));
5647     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5648     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5649     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5650     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5651 
5652     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5653         input, effective_scale_multiplier, effective_scale_shift);
5654 
5655     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5656     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5657     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5658     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5659     result.val[0] =
5660         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5661     result.val[1] =
5662         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5663     result.val[2] =
5664         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5665     result.val[3] =
5666         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5667 
5668     const uint32x4_t result_val_1_unsigned =
5669         vreinterpretq_u32_s32(result.val[0]);
5670     const uint32x4_t result_val_2_unsigned =
5671         vreinterpretq_u32_s32(result.val[1]);
5672     const uint32x4_t result_val_3_unsigned =
5673         vreinterpretq_u32_s32(result.val[2]);
5674     const uint32x4_t result_val_4_unsigned =
5675         vreinterpretq_u32_s32(result.val[3]);
5676 
5677     const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5678     const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5679     const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5680     const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5681     const uint16x8_t output_first_half =
5682         vcombine_u16(narrowed_val_1, narrowed_val_2);
5683     const uint16x8_t output_second_half =
5684         vcombine_u16(narrowed_val_3, narrowed_val_4);
5685     const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5686     const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5687     const uint8x16_t narrowed_result =
5688         vcombine_u8(narrowed_first_half, narrowed_second_half);
5689     vst1q_u8(output_data + i, narrowed_result);
5690   }
5691 
5692 #endif
5693   for (; i < size; ++i) {
5694     const int32_t input = input_data[i] - input_zeropoint;
5695     const int32_t output =
5696         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5697                                       effective_scale_shift) +
5698         output_zeropoint;
5699     const int32_t clamped_output =
5700         std::max(std::min(output, kMaxOutput), kMinOutput);
5701     output_data[i] = static_cast<uint8_t>(clamped_output);
5702   }
5703 }
5704 
5705 template <>
5706 inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
5707                                         int32_t effective_scale_multiplier,
5708                                         int32_t effective_scale_shift,
5709                                         int32_t input_zeropoint,
5710                                         int32_t output_zeropoint,
5711                                         int8_t* output_data) {
5712   ruy::profiler::ScopeLabel label("Requantize/Uint8ToInt8");
5713 
5714   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5715   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5716 
5717   int i = 0;
5718 #ifdef USE_NEON
5719   // Constants.
5720   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5721   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5722   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5723   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5724 
5725   for (; i <= size - 16; i += 16) {
5726     const uint8x16_t input_vec = vld1q_u8(input_data + i);
5727     const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5728     const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5729     int32x4x4_t input;
5730     input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5731     input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5732     input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5733     input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5734     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5735     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5736     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5737     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5738 
5739     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5740         input, effective_scale_multiplier, effective_scale_shift);
5741 
5742     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5743     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5744     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5745     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5746     result.val[0] =
5747         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5748     result.val[1] =
5749         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5750     result.val[2] =
5751         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5752     result.val[3] =
5753         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5754 
5755     const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5756     const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5757     const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5758     const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5759     const int16x8_t output_first_half =
5760         vcombine_s16(narrowed_val_1, narrowed_val_2);
5761     const int16x8_t output_second_half =
5762         vcombine_s16(narrowed_val_3, narrowed_val_4);
5763     const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5764     const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5765     const int8x16_t narrowed_result =
5766         vcombine_s8(narrowed_first_half, narrowed_second_half);
5767     vst1q_s8(output_data + i, narrowed_result);
5768   }
5769 
5770 #endif
5771   for (; i < size; ++i) {
5772     const int32_t input = input_data[i] - input_zeropoint;
5773     const int32_t output =
5774         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5775                                       effective_scale_shift) +
5776         output_zeropoint;
5777     const int32_t clamped_output =
5778         std::max(std::min(output, kMaxOutput), kMinOutput);
5779     output_data[i] = static_cast<int8_t>(clamped_output);
5780   }
5781 }
5782 
5783 template <>
5784 inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
5785                                        int32_t effective_scale_multiplier,
5786                                        int32_t effective_scale_shift,
5787                                        int32_t input_zeropoint,
5788                                        int32_t output_zeropoint,
5789                                        int8_t* output_data) {
5790   ruy::profiler::ScopeLabel label("Requantize/Int8ToInt8");
5791 
5792   static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5793   static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5794 
5795   int i = 0;
5796 #ifdef USE_NEON
5797   // Constants.
5798   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5799   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5800   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5801   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5802 
5803   for (; i <= size - 16; i += 16) {
5804     const int8x16_t input_vec = vld1q_s8(input_data + i);
5805     const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5806     const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5807     int32x4x4_t input;
5808     input.val[0] = vmovl_s16(vget_low_s16(first_half));
5809     input.val[1] = vmovl_s16(vget_high_s16(first_half));
5810     input.val[2] = vmovl_s16(vget_low_s16(second_half));
5811     input.val[3] = vmovl_s16(vget_high_s16(second_half));
5812 
5813     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5814     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5815     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5816     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5817 
5818     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5819         input, effective_scale_multiplier, effective_scale_shift);
5820 
5821     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5822     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5823     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5824     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5825     result.val[0] =
5826         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5827     result.val[1] =
5828         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5829     result.val[2] =
5830         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5831     result.val[3] =
5832         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5833 
5834     const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5835     const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5836     const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5837     const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5838     const int16x8_t output_first_half =
5839         vcombine_s16(narrowed_val_1, narrowed_val_2);
5840     const int16x8_t output_second_half =
5841         vcombine_s16(narrowed_val_3, narrowed_val_4);
5842     const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5843     const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5844     const int8x16_t narrowed_result =
5845         vcombine_s8(narrowed_first_half, narrowed_second_half);
5846     vst1q_s8(output_data + i, narrowed_result);
5847   }
5848 
5849 #endif
5850   for (; i < size; ++i) {
5851     const int32_t input = input_data[i] - input_zeropoint;
5852     const int32_t output =
5853         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5854                                       effective_scale_shift) +
5855         output_zeropoint;
5856     const int32_t clamped_output =
5857         std::max(std::min(output, kMaxOutput), kMinOutput);
5858     output_data[i] = static_cast<int8_t>(clamped_output);
5859   }
5860 }
5861 
5862 template <>
5863 inline void Requantize<uint8_t, uint8_t>(
5864     const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
5865     int32_t effective_scale_shift, int32_t input_zeropoint,
5866     int32_t output_zeropoint, uint8_t* output_data) {
5867   ruy::profiler::ScopeLabel label("Requantize/Uint8ToUint8");
5868 
5869   static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5870   static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5871 
5872   int i = 0;
5873 #ifdef USE_NEON
5874   // Constants.
5875   const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5876   const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5877   const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5878   const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5879 
5880   for (; i <= size - 16; i += 16) {
5881     const uint8x16_t input_vec = vld1q_u8(input_data + i);
5882     const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5883     const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5884     int32x4x4_t input;
5885     input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5886     input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5887     input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5888     input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5889     input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5890     input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5891     input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5892     input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5893 
5894     int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5895         input, effective_scale_multiplier, effective_scale_shift);
5896 
5897     result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5898     result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5899     result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5900     result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5901     result.val[0] =
5902         vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5903     result.val[1] =
5904         vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5905     result.val[2] =
5906         vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5907     result.val[3] =
5908         vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5909 
5910     const uint32x4_t result_val_1_unsigned =
5911         vreinterpretq_u32_s32(result.val[0]);
5912     const uint32x4_t result_val_2_unsigned =
5913         vreinterpretq_u32_s32(result.val[1]);
5914     const uint32x4_t result_val_3_unsigned =
5915         vreinterpretq_u32_s32(result.val[2]);
5916     const uint32x4_t result_val_4_unsigned =
5917         vreinterpretq_u32_s32(result.val[3]);
5918 
5919     const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5920     const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5921     const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5922     const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5923     const uint16x8_t output_first_half =
5924         vcombine_u16(narrowed_val_1, narrowed_val_2);
5925     const uint16x8_t output_second_half =
5926         vcombine_u16(narrowed_val_3, narrowed_val_4);
5927     const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5928     const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5929     const uint8x16_t narrowed_result =
5930         vcombine_u8(narrowed_first_half, narrowed_second_half);
5931     vst1q_u8(output_data + i, narrowed_result);
5932   }
5933 
5934 #endif
5935   for (; i < size; ++i) {
5936     const int32_t input = input_data[i] - input_zeropoint;
5937     const int32_t output =
5938         MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5939                                       effective_scale_shift) +
5940         output_zeropoint;
5941     const int32_t clamped_output =
5942         std::max(std::min(output, kMaxOutput), kMinOutput);
5943     output_data[i] = static_cast<uint8_t>(clamped_output);
5944   }
5945 }
5946 
5947 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
5948                       const RuntimeShape& output_shape, float* output_data) {
5949   ruy::profiler::ScopeLabel label("HardSwish/Float");
5950   auto size = MatchingFlatSize(input_shape, output_shape);
5951   int i = 0;
5952 #ifdef USE_NEON
5953   const float32x4_t zero = vdupq_n_f32(0.0f);
5954   const float32x4_t three = vdupq_n_f32(3.0f);
5955   const float32x4_t six = vdupq_n_f32(6.0f);
5956   const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
5957 
5958   for (; i <= size - 16; i += 16) {
5959     // 4x partially unrolled version of the loop below. Refer to its comments.
5960     const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
5961     const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
5962     const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
5963     const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
5964     const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
5965     const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
5966     const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
5967     const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
5968     const float32x4_t in_reluish_0 =
5969         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
5970     const float32x4_t in_reluish_1 =
5971         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
5972     const float32x4_t in_reluish_2 =
5973         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
5974     const float32x4_t in_reluish_3 =
5975         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
5976     const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
5977     const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
5978     const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
5979     const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
5980     vst1q_f32(output_data + i + 0, product_0);
5981     vst1q_f32(output_data + i + 4, product_1);
5982     vst1q_f32(output_data + i + 8, product_2);
5983     vst1q_f32(output_data + i + 12, product_3);
5984   }
5985   for (; i <= size - 4; i += 4) {
5986     // The expression to be computed is:
5987     //   out = one_sixth * in * min(six, max(zero, (in + three)))
5988     // We structure the AST to have two roughly balanced, independent branches:
5989     //  - Multiplication: in_scaled = one_sixth * in.
5990     //  - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
5991     // Then the remaining multiplication at the root of the tree.
5992     const float32x4_t in = vld1q_f32(input_data + i);
5993     const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
5994     const float32x4_t in_reluish =
5995         vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
5996     const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
5997     vst1q_f32(output_data + i, product);
5998   }
5999 #endif
6000   for (; i < size; i++) {
6001     const float in = input_data[i];
6002     output_data[i] =
6003         in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
6004   }
6005 }
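// For reference, the kernel above computes hard_swish(x) = x * relu6(x + 3) / 6;
// e.g. (illustrative values) hard_swish(1.0f) = 1.0f * 4.0f / 6.0f = 2/3, and
// hard_swish(-4.0f) = 0.0f.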
6006 
6007 #ifdef USE_NEON
6008 inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
6009   // Narrow values down to 8 bit unsigned, saturating.
6010   uint8x8_t res8 = vqmovun_s16(src);
6011   // Store results to destination.
6012   vst1_u8(dst, res8);
6013 }
6014 
6015 inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
6016   // Narrow values down to 8 bit signed, saturating.
6017   int8x8_t res8 = vqmovn_s16(src);
6018   // Store results to destination.
6019   vst1_s8(dst, res8);
6020 }
6021 #endif
6022 
6023 template <typename T>
6024 inline void HardSwish(const HardSwishParams& params,
6025                       const RuntimeShape& input_shape, const T* input_data,
6026                       const RuntimeShape& output_shape, T* output_data) {
6027   ruy::profiler::ScopeLabel label("HardSwish/Quantized");
6028 
6029   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6030 
6031   int i = 0;
6032   // This code heavily uses NEON saturating left shifts (vqshl*) with shift
6033   // amounts that can be zero, in which case we rely on the correct behavior
6034   // of a left shift by zero returning just its first operand unmodified.
6035   // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
6036   // buggy in the case of zero shift amounts, see b/137199585. That is why
6037   // this NEON code path is restricted to true ARM NEON, excluding
6038   // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating
6039   // left shifts is slow scalar code, so there may not be much benefit in
6040   // running that over just plain reference code.
6041   //
6042   // TODO(b/137199585): revisit when this is fixed.
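  // High-level reading of the NEON path below (a summary, not a spec): it is a
  // fixed-point evaluation of the same function as the float kernel above,
  // out = in * relu6(in + 3) / 6, on the quantized domain. Inputs have the
  // zero point subtracted and are shifted up to a higher-resolution int16
  // scale, the saturating "reluish" factor corresponding to relu6(in + 3) / 6
  // is built from the precomputed reluish_multiplier_* params, and the product
  // is rescaled with the output_multiplier_* params before adding
  // output_zero_point. See reference_ops::HardSwish for the step-by-step
  // derivation.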
6043 #ifdef __ARM_NEON
6044   const int16x8_t positive_reluish_multiplier_exponent_minus_one =
6045       vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
6046   const int16x8_t positive_reluish_multiplier_exponent_last_bit =
6047       vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
6048   const int16x8_t negative_reluish_multiplier_exponent =
6049       vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
6050   const int16x8_t constant_32767 = vdupq_n_s16(32767);
6051   const int16x8_t output_multiplier_exponent =
6052       vdupq_n_s16(params.output_multiplier_exponent);
6053   const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
6054   // 4x unrolled version of the below NEON loop. Read that first.
6055   for (; i <= flat_size - 32; i += 32) {
6056     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6057     const int16x8x2_t input_value_0_1 =
6058         Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6059     const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
6060         input_data + i + 16, params.input_zero_point);
6061     const int16x8_t input_value_on_hires_input_scale_0 =
6062         vshlq_n_s16(input_value_0_1.val[0], 7);
6063     const int16x8_t input_value_on_hires_input_scale_1 =
6064         vshlq_n_s16(input_value_0_1.val[1], 7);
6065     const int16x8_t input_value_on_hires_input_scale_2 =
6066         vshlq_n_s16(input_value_2_3.val[0], 7);
6067     const int16x8_t input_value_on_hires_input_scale_3 =
6068         vshlq_n_s16(input_value_2_3.val[1], 7);
6069     const int16x8_t input_value_on_preshift_output_scale_0 =
6070         vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
6071                         params.output_multiplier_fixedpoint_int16);
6072     const int16x8_t input_value_on_preshift_output_scale_1 =
6073         vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
6074                         params.output_multiplier_fixedpoint_int16);
6075     const int16x8_t input_value_on_preshift_output_scale_2 =
6076         vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
6077                         params.output_multiplier_fixedpoint_int16);
6078     const int16x8_t input_value_on_preshift_output_scale_3 =
6079         vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
6080                         params.output_multiplier_fixedpoint_int16);
6081     int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
6082     int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
6083     int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
6084     int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
6085     reluish_value_0 = vqshlq_s16(
6086         reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
6087     reluish_value_1 = vqshlq_s16(
6088         reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
6089     reluish_value_2 = vqshlq_s16(
6090         reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
6091     reluish_value_3 = vqshlq_s16(
6092         reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
6093     reluish_value_0 = vqrdmulhq_n_s16(
6094         reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
6095     reluish_value_1 = vqrdmulhq_n_s16(
6096         reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
6097     reluish_value_2 = vqrdmulhq_n_s16(
6098         reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
6099     reluish_value_3 = vqrdmulhq_n_s16(
6100         reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
6101     reluish_value_0 = vqshlq_s16(reluish_value_0,
6102                                  positive_reluish_multiplier_exponent_last_bit);
6103     reluish_value_1 = vqshlq_s16(reluish_value_1,
6104                                  positive_reluish_multiplier_exponent_last_bit);
6105     reluish_value_2 = vqshlq_s16(reluish_value_2,
6106                                  positive_reluish_multiplier_exponent_last_bit);
6107     reluish_value_3 = vqshlq_s16(reluish_value_3,
6108                                  positive_reluish_multiplier_exponent_last_bit);
6109     reluish_value_0 =
6110         vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
6111     reluish_value_1 =
6112         vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
6113     reluish_value_2 =
6114         vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
6115     reluish_value_3 =
6116         vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
6117     reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
6118     reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
6119     reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
6120     reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
6121     const int16x8_t preshift_output_value_0 =
6122         vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
6123     const int16x8_t preshift_output_value_1 =
6124         vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
6125     const int16x8_t preshift_output_value_2 =
6126         vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
6127     const int16x8_t preshift_output_value_3 =
6128         vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
6129     int16x8_t output_value_0 =
6130         vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
6131     int16x8_t output_value_1 =
6132         vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
6133     int16x8_t output_value_2 =
6134         vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
6135     int16x8_t output_value_3 =
6136         vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
6137     output_value_0 = vaddq_s16(output_value_0, output_zero_point);
6138     output_value_1 = vaddq_s16(output_value_1, output_zero_point);
6139     output_value_2 = vaddq_s16(output_value_2, output_zero_point);
6140     output_value_3 = vaddq_s16(output_value_3, output_zero_point);
6141     SaturateAndStore(output_value_0, output_data + i);
6142     SaturateAndStore(output_value_1, output_data + i + 8);
6143     SaturateAndStore(output_value_2, output_data + i + 16);
6144     SaturateAndStore(output_value_3, output_data + i + 24);
6145   }
6146   // NEON version of reference_ops::HardSwish. Read that first.
6147   for (; i <= flat_size - 8; i += 8) {
6148     using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
6149     const int16x8_t input_value =
6150         Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6151     const int16x8_t input_value_on_hires_input_scale =
6152         vshlq_n_s16(input_value, 7);
6153     const int16x8_t input_value_on_preshift_output_scale =
6154         vqrdmulhq_n_s16(input_value_on_hires_input_scale,
6155                         params.output_multiplier_fixedpoint_int16);
6156     int16x8_t reluish_value = input_value_on_hires_input_scale;
6157     reluish_value = vqshlq_s16(reluish_value,
6158                                positive_reluish_multiplier_exponent_minus_one);
6159     reluish_value = vqrdmulhq_n_s16(reluish_value,
6160                                     params.reluish_multiplier_fixedpoint_int16);
6161     reluish_value = vqshlq_s16(reluish_value,
6162                                positive_reluish_multiplier_exponent_last_bit);
6163     reluish_value =
6164         vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
6165     reluish_value = vrhaddq_s16(reluish_value, constant_32767);
6166     const int16x8_t preshift_output_value =
6167         vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
6168     int16x8_t output_value =
6169         vrshlq_s16(preshift_output_value, output_multiplier_exponent);
6170     output_value = vaddq_s16(output_value, output_zero_point);
6171     SaturateAndStore(output_value, output_data + i);
6172   }
6173 #endif
6174   // TODO(b/137208495): revisit when unit tests cover reference code.
6175   // Fall back to reference_ops::HardSwish. In general we have preferred
6176   // to duplicate such scalar code rather than call reference code to handle
6177   // leftovers, thinking that code duplication was not a big concern.
6178   // However, most of our unit tests happen to test only optimized code,
6179   // and the quantized HardSwish implementation is nontrivial enough that
6180   // I really want test coverage for the reference code.
6181   if (i < flat_size) {
6182     const RuntimeShape leftover_shape{flat_size - i};
6183     reference_ops::HardSwish(params, leftover_shape, input_data + i,
6184                              leftover_shape, output_data + i);
6185   }
6186 }
6187 
6188 template <typename T>
6189 inline void IntegerExponentPow(const ArithmeticParams& params,
6190                                const RuntimeShape& unextended_base_shape,
6191                                const T* base_data, const int exponent,
6192                                const RuntimeShape& unextended_output_shape,
6193                                T* output_data) {
6194   TFLITE_DCHECK_GE(exponent, 1);
6195   if (exponent == 1) {
6196     // copy data over.
6197     std::memcpy(output_data, base_data,
6198                 unextended_base_shape.FlatSize() * sizeof(T));
6199   } else {
6200     IntegerExponentPow(params, unextended_base_shape, base_data, exponent / 2,
6201                        unextended_output_shape, output_data);
6202     Mul(params, unextended_base_shape, output_data, unextended_base_shape,
6203         output_data, unextended_output_shape, output_data);
6204     if (exponent % 2 == 1) {
6205       Mul(params, unextended_base_shape, base_data, unextended_base_shape,
6206           output_data, unextended_output_shape, output_data);
6207     }
6208   }
6209 }
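// Example trace (illustrative): IntegerExponentPow(base, 5) above recurses to
// exponent 2, which itself recurses to exponent 1 (a plain copy of base) and
// squares it to get base^2; the exponent-5 frame then squares that to base^4
// and, because 5 is odd, multiplies by base once more, producing base^5 in
// three Muls instead of four.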
6210 
6211 template <typename T>
6212 inline void BroadcastPow4D(const RuntimeShape& unextended_input1_shape,
6213                            const T* input1_data,
6214                            const RuntimeShape& unextended_input2_shape,
6215                            const T* input2_data,
6216                            const RuntimeShape& unextended_output_shape,
6217                            T* output_data) {
6218   ruy::profiler::ScopeLabel label("PowBroadcast");
6219 
6220   if (unextended_input2_shape.FlatSize() == 1) {
6221     static const float epsilon = 1e-5;
6222     const T exponent = input2_data[0];
6223     const int int_exponent = static_cast<int>(std::round(exponent));
6224     if ((std::abs(input2_data[0] - int_exponent) < epsilon) &&
6225         (int_exponent >= 1)) {
6226       ArithmeticParams params;
6227       if (std::is_same<T, float>::value) {
6228         params.float_activation_max = std::numeric_limits<float>::max();
6229         params.float_activation_min = std::numeric_limits<float>::lowest();
6230       } else if (std::is_same<T, int>::value) {
6231         params.quantized_activation_max = std::numeric_limits<int>::max();
6232         params.quantized_activation_min = std::numeric_limits<int>::lowest();
6233       }
6234       IntegerExponentPow(params, unextended_input1_shape, input1_data,
6235                          int_exponent, unextended_output_shape, output_data);
6236       return;
6237     }
6238   }
6239   reference_ops::BroadcastPow4DSlow(unextended_input1_shape, input1_data,
6240                                     unextended_input2_shape, input2_data,
6241                                     unextended_output_shape, output_data);
6242 }
6243 
6244 #ifdef USE_NEON
6245 
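// Given zero_times_scale_dup == -zero_point * scale (as the Dequantize callers
// below pass), the helper below computes (input - zero_point) * scale for a
// 4-lane vector, fusing the multiply-add into a single vfmaq_f32 when FMA is
// available.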
6246 inline void ScaleWithNewZeroPoint(const int32x4_t input,
6247                                   const float32x4_t scale_dup,
6248                                   const float32x4_t zero_times_scale_dup,
6249                                   float32x4_t* output) {
6250 #ifdef __ARM_FEATURE_FMA
6251   *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
6252 #else
6253   *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
6254                       zero_times_scale_dup);
6255 #endif
6256 }
6257 
6258 #endif  // USE_NEON
6259 
6260 inline void Dequantize(const tflite::DequantizationParams& op_params,
6261                        const RuntimeShape& input_shape,
6262                        const uint8_t* input_data,
6263                        const RuntimeShape& output_shape, float* output_data) {
6264   ruy::profiler::ScopeLabel label("Dequantize/Uint8");
6265   const int32 zero_point = op_params.zero_point;
6266   const double scale = op_params.scale;
6267   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6268 
6269   int i = 0;
6270 #ifdef USE_NEON
6271   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6272   const float32x4_t zero_times_scale_dup =
6273       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6274   for (; i <= flat_size - 8; i += 8) {
6275     const uint8x8_t input_u8 = vld1_u8(input_data + i);
6276     const uint16x8_t input_u16 = vmovl_u8(input_u8);
6277     const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
6278     const int16x4_t input_s16_low = vget_low_s16(input_s16);
6279     const int16x4_t input_s16_high = vget_high_s16(input_s16);
6280     const int32x4_t val_low = vmovl_s16(input_s16_low);
6281     const int32x4_t val_high = vmovl_s16(input_s16_high);
6282 
6283     float32x4_t result_low, result_high;
6284     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6285                           &result_low);
6286     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6287                           &result_high);
6288 
6289     vst1q_f32(output_data + i, result_low);
6290     vst1q_f32(output_data + i + 4, result_high);
6291   }
6292 #endif  // NEON
6293   for (; i < flat_size; ++i) {
6294     const int32 val = input_data[i];
6295     const float result = static_cast<float>(scale * (val - zero_point));
6296     output_data[i] = result;
6297   }
6298 }
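// For the uint8 Dequantize above, e.g. (illustrative values): zero_point = 128
// and scale = 0.5 map an input byte of 130 to (130 - 128) * 0.5 = 1.0f.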
6299 
6300 inline void Dequantize(const tflite::DequantizationParams& op_params,
6301                        const RuntimeShape& input_shape,
6302                        const int8_t* input_data,
6303                        const RuntimeShape& output_shape, float* output_data) {
6304   ruy::profiler::ScopeLabel label("Dequantize/Int8");
6305   const int32 zero_point = op_params.zero_point;
6306   const double scale = op_params.scale;
6307   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6308 
6309   int i = 0;
6310 #ifdef USE_NEON
6311   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6312   const float32x4_t zero_times_scale_dup =
6313       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6314   for (; i <= flat_size - 8; i += 8) {
6315     const int8x8_t input_s8 = vld1_s8(input_data + i);
6316     const int16x8_t input_s16 = vmovl_s8(input_s8);
6317     const int16x4_t input_s16_low = vget_low_s16(input_s16);
6318     const int16x4_t input_s16_high = vget_high_s16(input_s16);
6319     const int32x4_t val_low = vmovl_s16(input_s16_low);
6320     const int32x4_t val_high = vmovl_s16(input_s16_high);
6321 
6322     float32x4_t result_low, result_high;
6323     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6324                           &result_low);
6325     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6326                           &result_high);
6327 
6328     vst1q_f32(output_data + i, result_low);
6329     vst1q_f32(output_data + i + 4, result_high);
6330   }
6331 #endif  // NEON
6332   for (; i < flat_size; ++i) {
6333     const int32 val = input_data[i];
6334     const float result = static_cast<float>(scale * (val - zero_point));
6335     output_data[i] = result;
6336   }
6337 }
6338 
6339 inline void Dequantize(const tflite::DequantizationParams& op_params,
6340                        const RuntimeShape& input_shape,
6341                        const int16_t* input_data,
6342                        const RuntimeShape& output_shape, float* output_data) {
6343   ruy::profiler::ScopeLabel label("Dequantize/Int16");
6344   const int32 zero_point = op_params.zero_point;
6345   const double scale = op_params.scale;
6346   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6347 
6348   int i = 0;
6349 #ifdef USE_NEON
6350   const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6351   const float32x4_t zero_times_scale_dup =
6352       vdupq_n_f32(static_cast<float>(-zero_point * scale));
6353   for (; i <= flat_size - 8; i += 8) {
6354     const int16x4_t input_s16_low = vld1_s16(input_data + i);
6355     const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
6356     const int32x4_t val_low = vmovl_s16(input_s16_low);
6357     const int32x4_t val_high = vmovl_s16(input_s16_high);
6358 
6359     float32x4_t result_low, result_high;
6360     ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6361                           &result_low);
6362     ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6363                           &result_high);
6364 
6365     vst1q_f32(output_data + i, result_low);
6366     vst1q_f32(output_data + i + 4, result_high);
6367   }
6368 #endif  // USE_NEON
6369   for (; i < flat_size; ++i) {
6370     const int32 val = input_data[i];
6371     const float result = static_cast<float>(scale * (val - zero_point));
6372     output_data[i] = result;
6373   }
6374 }
6375 
6376 inline void Dequantize(const RuntimeShape& input_shape,
6377                        const Eigen::half* input_data,
6378                        const RuntimeShape& output_shape, float* output_data) {
6379   reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
6380 }
6381 
6382 template <typename T>
6383 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6384                            const RuntimeShape& input_shape,
6385                            const float* input_data,
6386                            const RuntimeShape& output_shape, T* output_data) {
6387   reference_ops::AffineQuantize(op_params, input_shape, input_data,
6388                                 output_shape, output_data);
6389 }
6390 
6391 template <>
6392 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6393                            const RuntimeShape& input_shape,
6394                            const float* input_data,
6395                            const RuntimeShape& output_shape,
6396                            int8_t* output_data) {
6397   ruy::profiler::ScopeLabel label("Quantize/Int8");
6398   const int32 zero_point = op_params.zero_point;
6399   const double scale = static_cast<double>(op_params.scale);
6400   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6401   static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
6402   static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
6403 
6404   int i = 0;
6405 #ifdef USE_NEON
6406   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6407   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6408   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6409   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6410 
6411   for (; i <= flat_size - 8; i += 8) {
6412     const float* src_data_ptr = input_data + i;
6413     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6414     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6415 
6416     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6417     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6418 
6419     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6420     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6421 
6422     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6423     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6424 
6425     // Clamp the values to fit the target type's range.
6426     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6427     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6428     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6429     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6430 
6431     const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6432     const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6433     const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
6434     const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
6435     vst1_s8(output_data + i, combined_val_narrowed);
6436   }
6437 #endif  // USE_NEON
6438 
6439   for (; i < flat_size; ++i) {
6440     const float val = input_data[i];
6441     const int32 unclamped =
6442         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6443     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6444     output_data[i] = clamped;
6445   }
6446 }
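// Worked example of the affine quantization above: with scale = 0.5 and
// zero_point = 10, an input of 3.2f maps to round(3.2 / 0.5) + 10 = 16, which
// is then clamped to the int8 range [-128, 127].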
6447 
6448 template <>
6449 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6450                            const RuntimeShape& input_shape,
6451                            const float* input_data,
6452                            const RuntimeShape& output_shape,
6453                            uint8_t* output_data) {
6454   ruy::profiler::ScopeLabel label("Quantize/Uint8");
6455   const int32 zero_point = op_params.zero_point;
6456   const double scale = static_cast<double>(op_params.scale);
6457   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6458   static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
6459   static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
6460 
6461   int i = 0;
6462 #ifdef USE_NEON
6463   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6464   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6465   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6466   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6467 
6468   for (; i <= flat_size - 8; i += 8) {
6469     const float* src_data_ptr = input_data + i;
6470     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6471     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6472 
6473     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6474     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6475 
6476     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6477     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6478 
6479     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6480     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6481 
6482     // Clamp the values to fit the target type's range.
6483     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6484     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6485     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6486     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6487 
6488     const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
6489     const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
6490     const uint16x8_t combined_val =
6491         vcombine_u16(narrowed_val_0, narrowed_val_1);
6492     const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
6493     vst1_u8(output_data + i, combined_val_narrowed);
6494   }
6495 #endif  // USE_NEON
6496 
6497   for (; i < flat_size; ++i) {
6498     const float val = input_data[i];
6499     const int32 unclamped =
6500         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6501     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6502     output_data[i] = clamped;
6503   }
6504 }
6505 
6506 template <>
6507 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6508                            const RuntimeShape& input_shape,
6509                            const float* input_data,
6510                            const RuntimeShape& output_shape,
6511                            int16_t* output_data) {
6512   ruy::profiler::ScopeLabel label("Quantize/Int16");
6513   const int32 zero_point = op_params.zero_point;
6514   const double scale = static_cast<double>(op_params.scale);
6515   const int flat_size = MatchingFlatSize(input_shape, output_shape);
6516   static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
6517   static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
6518 
6519   int i = 0;
6520 #ifdef USE_NEON
6521   const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6522   const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6523   const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6524   const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6525 
6526   for (; i <= flat_size - 8; i += 8) {
6527     const float* src_data_ptr = input_data + i;
6528     float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6529     float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6530 
6531     input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6532     input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6533 
6534     int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6535     int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6536 
6537     casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6538     casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6539 
6540     // Clamp the values to fit the target type's range.
6541     casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6542     casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6543     casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6544     casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6545 
6546     const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6547     const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6548     vst1_s16(output_data + i, narrowed_val_0);
6549     vst1_s16(output_data + i + 4, narrowed_val_1);
6550   }
6551 #endif  // USE_NEON
6552 
6553   for (; i < flat_size; ++i) {
6554     const float val = input_data[i];
6555     const int32 unclamped =
6556         static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6557     const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6558     output_data[i] = clamped;
6559   }
6560 }
6561 
6562 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6563 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6564 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6565 #ifdef GEMMLOWP_NEON
6566 
6567 inline int16x8x4_t SaturatingRounding(
6568     int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2,
6569     int16x8_t input_val_3, int input_left_shift, int input_multiplier) {
6570   // This performs what is expressed in the scalar code as
6571   // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6572   //      static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6573   //      static_cast<int16>(input_multiplier));
6574   const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift);
6575   const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup);
6576   const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup);
6577   const int16x8_t input_val_shifted_2 = vshlq_s16(input_val_2, left_shift_dup);
6578   const int16x8_t input_val_shifted_3 = vshlq_s16(input_val_3, left_shift_dup);
6579   int16x8x4_t result;
6580   result.val[0] = vqrdmulhq_n_s16(input_val_shifted_0, input_multiplier);
6581   result.val[1] = vqrdmulhq_n_s16(input_val_shifted_1, input_multiplier);
6582   result.val[2] = vqrdmulhq_n_s16(input_val_shifted_2, input_multiplier);
6583   result.val[3] = vqrdmulhq_n_s16(input_val_shifted_3, input_multiplier);
6584   return result;
6585 }
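// Note on the intrinsic used above: vqrdmulhq_n_s16 is the vector form of
// SaturatingRoundingDoublingHighMul, computing roughly
// (2 * a * b + (1 << 15)) >> 16 per lane, with saturation of the single
// overflow case a == b == -32768.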
6586 
6587 // 4-bit fixed point is enough for tanh since tanh(16) is essentially equal
6588 // to one when considering 7 digits after the decimal point.
6589 inline int16x8x4_t FixedPoint4Logistic(int16x8x4_t input_val) {
6590   // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
6591   using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6592   using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6593   const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6594   const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6595   const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6596   const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6597 
6598   // TODO(b/134622898) Implement a low accuracy version of logistic. In this
6599   // method, gemmlowp::logistic accounts for about 80% of the execution time.
6600   // The current implementation is roughly 12-bit accurate in the 16-bit
6601   // fixed-point case, so there is still room for improvement before the
6602   // error bound is reached.
6603   const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
6604   const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
6605   const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
6606   const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
6607 
6608   // Divide by 2^7 as in the scalar code
6609   int16x8x4_t result;
6610   result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 7);
6611   result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 7);
6612   result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 7);
6613   result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 7);
6614   return result;
6615 }
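// Scaling note for the shift above: gemmlowp::logistic returns a FixedPoint0
// value in [0, 1), i.e. a raw int16 in [0, 2^15), so the rounding shift right
// by 7 yields values in [0, 256], which the callers fold into the uint8 range.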
6616 
6617 // 4-bit fixed point is enough for tanh since tanh(16) is essentially equal
6618 // to one when considering at least 11 digits after the decimal point.
6619 inline int16x8x4_t FixedPoint4Tanh(int16x8x4_t input_val) {
6620   // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
6621   using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6622   using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6623   const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6624   const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6625   const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6626   const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6627 
6628   // TODO(b/134622898) Implement a low accuracy version of tanh. In this
6629   // method, gemmlowp::tanh accounts for about 80% of the execution time.
6630   // The current implementation is roughly 12-bit accurate in the 16-bit
6631   // fixed-point case, so there is still room for improvement before the
6632   // error bound is reached.
6633   const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
6634   const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
6635   const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
6636   const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
6637 
6638   // Divide by 2^8 as in the scalar code
6639   int16x8x4_t result;
6640   result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 8);
6641   result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 8);
6642   result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 8);
6643   result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 8);
6644   return result;
6645 }
6646 
6647 inline uint8x16x2_t CalculateUnsignedClampingWithRangeBitMasks(
6648     int16x8x2_t input_val, int16x8_t range_radius_dup,
6649     int16x8_t neg_range_radius_dup) {
6650   const uint16x8_t mask_rightclamp_0 =
6651       vcgtq_s16(input_val.val[0], range_radius_dup);
6652   const uint16x8_t mask_rightclamp_1 =
6653       vcgtq_s16(input_val.val[1], range_radius_dup);
6654 
6655   const uint16x8_t mask_leftclamp_0 =
6656       vcgeq_s16(input_val.val[0], neg_range_radius_dup);
6657   const uint16x8_t mask_leftclamp_1 =
6658       vcgeq_s16(input_val.val[1], neg_range_radius_dup);
6659 
6660   uint8x16x2_t result;
6661   result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6662                               vshrn_n_u16(mask_leftclamp_1, 8));
6663   result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6664                               vshrn_n_u16(mask_rightclamp_1, 8));
6665   return result;
6666 }
6667 
6668 inline uint8x16x2_t CalculateSignedClampingWithRangeBitMasks(
6669     int16x8x2_t input_val, int16x8_t range_radius_dup,
6670     int16x8_t neg_range_radius_dup) {
6671   const uint16x8_t mask_rightclamp_0 =
6672       vcgtq_s16(input_val.val[0], range_radius_dup);
6673   const uint16x8_t mask_rightclamp_1 =
6674       vcgtq_s16(input_val.val[1], range_radius_dup);
6675 
6676   const uint16x8_t mask_leftclamp_0 =
6677       vcltq_s16(input_val.val[0], neg_range_radius_dup);
6678   const uint16x8_t mask_leftclamp_1 =
6679       vcltq_s16(input_val.val[1], neg_range_radius_dup);
6680 
6681   uint8x16x2_t result;
6682   result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6683                               vshrn_n_u16(mask_leftclamp_1, 8));
6684   result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6685                               vshrn_n_u16(mask_rightclamp_1, 8));
6686   return result;
6687 }
6688 
6689 inline void ClampWithRangeAndStore(uint8_t* output_dst, uint8x16_t input_val,
6690                                    uint8x16x2_t masks_clamp) {
6691   // Store back to memory
6692   vst1q_u8(output_dst, vandq_u8(vorrq_u8(input_val, masks_clamp.val[1]),
6693                                 masks_clamp.val[0]));
6694 }
6695 
6696 inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val,
6697                                    uint8x16x2_t masks_clamp) {
6698   static const int8x16_t max_dup = vdupq_n_s8(127);
6699   static const int8x16_t min_dup = vdupq_n_s8(-128);
6700   // Store back to memory
6701   vst1q_s8(output_dst,
6702            vbslq_s8(masks_clamp.val[1], max_dup,
6703                     vbslq_s8(masks_clamp.val[0], min_dup, input_val)));
6704 }
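// How the masks above are consumed: each comparison in the mask helpers yields
// an all-ones or all-zeros 16-bit lane, and vshrn_n_u16(mask, 8) narrows that
// to a 0xFF/0x00 byte. In the unsigned store, ORing with the right-clamp mask
// forces saturated lanes to 255 and ANDing with the left-clamp mask forces
// below-range lanes to 0; the signed store instead uses vbslq_s8 to select
// 127 or -128 per lane.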
6705 
6706 #endif  // GEMMLOWP_NEON
6707 
6708 inline void Tanh16bitPrecision(const TanhParams& params,
6709                                const RuntimeShape& input_shape,
6710                                const uint8* input_data,
6711                                const RuntimeShape& output_shape,
6712                                uint8* output_data) {
6713   // Note that this is almost the exact same code as in Logistic().
6714   ruy::profiler::ScopeLabel label("Tanh/Uint8");
6715   const int32 input_zero_point = params.input_zero_point;
6716   const int32 input_range_radius = params.input_range_radius;
6717   const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6718   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6719   const int size = MatchingFlatSize(input_shape, output_shape);
6720 
6721   int c = 0;
6722   int16_t output_zero_point = 128;
6723 
6724 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6725 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6726 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6727 #ifdef GEMMLOWP_NEON
6728   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6729   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6730   const int16x8_t output_zero_point_s16 = vdupq_n_s16(output_zero_point);
6731 
6732   // Handle 32 values at a time
6733   for (; c <= size - 32; c += 32) {
6734     // Read input uint8 values, cast to int16 and subtract input_zero_point
6735     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6736     const int16x8x2_t input_val_centered_0_1 =
6737         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6738     const int16x8x2_t input_val_centered_2_3 =
6739         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6740 
6741     // Prepare the bit masks that we will use at the end to implement the logic
6742     // that was expressed in the scalar code with branching:
6743     //   if (input_val_centered < -input_range_radius) {
6744     //     output_val = 0;
6745     //   } else if (input_val_centered > input_range_radius) {
6746     //     output_val = 255;
6747     //   } else {
6748     //     ...
6749     uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6750         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6751     uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6752         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6753 
6754     int16x8x4_t input_val_rescaled = SaturatingRounding(
6755         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6756         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6757         input_left_shift, input_multiplier);
6758 
6759     int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6760 
6761     // Add the output zero point
6762     output_val_s16.val[0] =
6763         vaddq_s16(output_val_s16.val[0], output_zero_point_s16);
6764     output_val_s16.val[1] =
6765         vaddq_s16(output_val_s16.val[1], output_zero_point_s16);
6766     output_val_s16.val[2] =
6767         vaddq_s16(output_val_s16.val[2], output_zero_point_s16);
6768     output_val_s16.val[3] =
6769         vaddq_s16(output_val_s16.val[3], output_zero_point_s16);
6770 
6771     // Cast output values to uint8, saturating
6772     uint8x16_t output_val_u8_0_1 = vcombine_u8(
6773         vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
6774     uint8x16_t output_val_u8_2_3 = vcombine_u8(
6775         vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
6776 
6777     ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
6778     ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
6779                            masks_clamp_2_3);
6780   }
6781 #endif  // GEMMLOWP_NEON
6782   // Leftover loop: handle one value at a time with scalar code.
6783   for (; c < size; ++c) {
6784     const uint8 input_val_u8 = input_data[c];
6785     const int16 input_val_centered =
6786         static_cast<int16>(input_val_u8) - input_zero_point;
6787     uint8 output_val;
6788     if (input_val_centered < -input_range_radius) {
6789       output_val = 0;
6790     } else if (input_val_centered > input_range_radius) {
6791       output_val = 255;
6792     } else {
6793       using gemmlowp::SaturatingRoundingDoublingHighMul;
6794       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6795           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6796           static_cast<int16>(input_multiplier));
6797       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6798       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6799       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6800       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6801       using gemmlowp::RoundingDivideByPOT;
6802       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6803       output_val_s16 += output_zero_point;
6804       if (output_val_s16 == 256) {
6805         output_val_s16 = 255;
6806       }
6807       TFLITE_DCHECK_GE(output_val_s16, 0);
6808       TFLITE_DCHECK_LE(output_val_s16, 255);
6809       output_val = static_cast<uint8>(output_val_s16);
6810     }
6811     output_data[c] = output_val;
6812   }
6813 }
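// The output_zero_point of 128 used above shifts tanh's symmetric output
// (roughly [-128, 128] after the divide by 2^8) into the asymmetric uint8
// range [0, 255], with the 256 edge case folded down to 255.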
6814 
6815 inline void Tanh16bitPrecision(const TanhParams& params,
6816                                const RuntimeShape& input_shape,
6817                                const int8* input_data,
6818                                const RuntimeShape& output_shape,
6819                                int8* output_data) {
6820   // Note that this is almost the exact same code as in Logistic().
6821   ruy::profiler::ScopeLabel label("Tanh/Int8");
6822   const int32 input_zero_point = params.input_zero_point;
6823   const int32 input_range_radius = params.input_range_radius;
6824   const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6825   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6826   const int size = MatchingFlatSize(input_shape, output_shape);
6827 
6828   int c = 0;
6829 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6830 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6831 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6832 #ifdef GEMMLOWP_NEON
6833   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6834   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6835 
6836   // Handle 32 values at a time
6837   for (; c <= size - 32; c += 32) {
6838     // Read input int8 values, cast to int16 and subtract input_zero_point
6839     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6840     const int16x8x2_t input_val_centered_0_1 =
6841         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6842     const int16x8x2_t input_val_centered_2_3 =
6843         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6844 
6845     // Prepare the bit masks that we will use at the end to implement the logic
6846     // that was expressed in the scalar code with branching:
6847     //   if (input_val_centered < -input_range_radius) {
6848     //     output_val = -128;
6849     //   } else if (input_val_centered > input_range_radius) {
6850     //     output_val = 127;
6851     //   } else {
6852     //     ...
6853     uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
6854         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6855     uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
6856         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6857 
6858     int16x8x4_t input_val_rescaled = SaturatingRounding(
6859         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6860         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6861         input_left_shift, input_multiplier);
6862 
6863     int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6864 
6865     // Cast output values to int8, saturating
6866     int8x16_t output_val_s8_0_1 = vcombine_s8(
6867         vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
6868     int8x16_t output_val_s8_2_3 = vcombine_s8(
6869         vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
6870 
6871     ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
6872     ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
6873                            masks_clamp_2_3);
6874   }
6875 #endif  // GEMMLOWP_NEON
6876   // Leftover loop: handle one value at a time with scalar code.
6877   for (; c < size; ++c) {
6878     const int8 input_val_s8 = input_data[c];
6879     const int16 input_val_centered =
6880         static_cast<int16>(input_val_s8) - input_zero_point;
6881     int8 output_val;
6882     if (input_val_centered <= -input_range_radius) {
6883       output_val = -128;
6884     } else if (input_val_centered >= input_range_radius) {
6885       output_val = 127;
6886     } else {
6887       using gemmlowp::SaturatingRoundingDoublingHighMul;
6888       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6889           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6890           static_cast<int16>(input_multiplier));
6891       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6892       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6893       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6894       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6895       using gemmlowp::RoundingDivideByPOT;
6896       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6897       if (output_val_s16 == 128) {
6898         output_val_s16 = 127;
6899       }
6900       TFLITE_DCHECK_GE(output_val_s16, -128);
6901       TFLITE_DCHECK_LE(output_val_s16, 127);
6902       output_val = static_cast<int8>(output_val_s16);
6903     }
6904     output_data[c] = output_val;
6905   }
6906 }
6907 
6908 inline void Logistic16bitPrecision(const LogisticParams& params,
6909                                    const RuntimeShape& input_shape,
6910                                    const uint8* input_data,
6911                                    const RuntimeShape& output_shape,
6912                                    uint8* output_data) {
6913   ruy::profiler::ScopeLabel label("Logistic/Uint8");
6914   const int32 input_zero_point = params.input_zero_point;
6915   const int32 input_range_radius = params.input_range_radius;
6916   const int32 input_multiplier = params.input_multiplier;
6917   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6918   const int size = MatchingFlatSize(input_shape, output_shape);
6919 
6920   int c = 0;
6921 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6922 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6923 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6924 #ifdef GEMMLOWP_NEON
6925   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6926   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6927 
6928   // Handle 32 values at a time
6929   for (; c <= size - 32; c += 32) {
6930     // Read input uint8 values, cast to int16 and subtract input_zero_point
6931     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6932     const int16x8x2_t input_val_centered_0_1 =
6933         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6934     const int16x8x2_t input_val_centered_2_3 =
6935         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6936 
6937     // Prepare the bit masks that we will use at the end to implement the logic
6938     // that was expressed in the scalar code with branching:
6939     //   if (input_val_centered < -input_range_radius) {
6940     //     output_val = 0;
6941     //   } else if (input_val_centered > input_range_radius) {
6942     //     output_val = 255;
6943     //   } else {
6944     //     ...
6945     uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6946         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6947     uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6948         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6949 
6950     int16x8x4_t input_val_rescaled = SaturatingRounding(
6951         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6952         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6953         input_left_shift, input_multiplier);
6954 
6955     int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
6956 
6957     // Cast output values to uint8, saturating
6958     uint8x16_t output_val_u8_0_1 = vcombine_u8(
6959         vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
6960     uint8x16_t output_val_u8_2_3 = vcombine_u8(
6961         vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
6962 
6963     ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
6964     ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
6965                            masks_clamp_2_3);
6966   }
6967 #endif  // GEMMLOWP_NEON
6968   // Leftover loop: handle one value at a time with scalar code.
6969   for (; c < size; ++c) {
6970     const uint8 input_val_u8 = input_data[c];
6971     const int16 input_val_centered =
6972         static_cast<int16>(input_val_u8) - input_zero_point;
6973     uint8 output_val;
6974     if (input_val_centered < -input_range_radius) {
6975       output_val = 0;
6976     } else if (input_val_centered > input_range_radius) {
6977       output_val = 255;
6978     } else {
6979       using gemmlowp::SaturatingRoundingDoublingHighMul;
6980       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6981           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6982           static_cast<int16>(input_multiplier));
6983       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6984       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6985       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6986       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
6987       using gemmlowp::RoundingDivideByPOT;
6988       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
6989       if (output_val_s16 == 256) {
6990         output_val_s16 = 255;
6991       }
6992       TFLITE_DCHECK_GE(output_val_s16, 0);
6993       TFLITE_DCHECK_LE(output_val_s16, 255);
6994       output_val = static_cast<uint8>(output_val_s16);
6995     }
6996     output_data[c] = output_val;
6997   }
6998 }
6999 
7000 inline void Logistic16bitPrecision(const LogisticParams& params,
7001                                    const RuntimeShape& input_shape,
7002                                    const int8* input_data,
7003                                    const RuntimeShape& output_shape,
7004                                    int8* output_data) {
7005   ruy::profiler::ScopeLabel label("Logistic/Int8");
7006   const int32 input_zero_point = params.input_zero_point;
7007   const int32 input_range_radius = params.input_range_radius;
7008   const int32 input_multiplier = params.input_multiplier;
7009   const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7010   const int size = MatchingFlatSize(input_shape, output_shape);
7011 
7012   int c = 0;
7013   const int16 output_zero_point = 128;
7014 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7015 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7016 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7017 #ifdef GEMMLOWP_NEON
7018   const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7019   const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7020   const int16x8_t output_zero_point_dup = vdupq_n_s16(output_zero_point);
7021 
7022   // Handle 32 values at a time
7023   for (; c <= size - 32; c += 32) {
7024     // Read input int8 values, cast to int16 and subtract input_zero_point
7025     using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7026     const int16x8x2_t input_val_centered_0_1 =
7027         Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7028     const int16x8x2_t input_val_centered_2_3 =
7029         Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7030 
7031     // Prepare the bit masks that we will use at the end to implement the logic
7032     // that was expressed in the scalar code with branching:
7033     //   if (input_val_centered < -input_range_radius) {
7034     //     output_val = -128;
7035     //   } else if (input_val_centered > input_range_radius) {
7036     //     output_val = 127;
7037     //   } else {
7038     //     ...
7039     uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7040         input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7041     uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7042         input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7043 
7044     int16x8x4_t input_val_rescaled = SaturatingRounding(
7045         input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7046         input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7047         input_left_shift, input_multiplier);
7048 
7049     int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7050 
7051     // Subtract the output zero point.
7052     output_val_s16.val[0] =
7053         vsubq_s16(output_val_s16.val[0], output_zero_point_dup);
7054     output_val_s16.val[1] =
7055         vsubq_s16(output_val_s16.val[1], output_zero_point_dup);
7056     output_val_s16.val[2] =
7057         vsubq_s16(output_val_s16.val[2], output_zero_point_dup);
7058     output_val_s16.val[3] =
7059         vsubq_s16(output_val_s16.val[3], output_zero_point_dup);
7060 
7061     // Cast output values to int8, saturating
7062     int8x16_t output_val_s8_0_1 = vcombine_s8(
7063         vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7064     int8x16_t output_val_s8_2_3 = vcombine_s8(
7065         vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7066 
7067     ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7068     ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7069                            masks_clamp_2_3);
7070   }
7071 #endif  // GEMMLOWP_NEON
7072   // Leftover loop: handle one value at a time with scalar code.
7073   for (; c < size; ++c) {
7074     const int8 input_val_s8 = input_data[c];
7075     const int16 input_val_centered =
7076         static_cast<int16>(input_val_s8) - input_zero_point;
7077     int8 output_val;
7078     if (input_val_centered < -input_range_radius) {
7079       output_val = -128;
7080     } else if (input_val_centered > input_range_radius) {
7081       output_val = 127;
7082     } else {
7083       using gemmlowp::SaturatingRoundingDoublingHighMul;
7084       const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7085           static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7086           static_cast<int16>(input_multiplier));
7087       using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7088       using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7089       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7090       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7091       using gemmlowp::RoundingDivideByPOT;
7092       int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7093       output_val_s16 -= output_zero_point;
7094       if (output_val_s16 == 128) {
7095         output_val_s16 = 127;
7096       }
7097       TFLITE_DCHECK_GE(output_val_s16, -128);
7098       TFLITE_DCHECK_LE(output_val_s16, 127);
7099       output_val = static_cast<int8>(output_val_s16);
7100     }
7101     output_data[c] = output_val;
7102   }
7103 }
7104 
7105 // Transpose2D only handles typical 2D matrix transposes.
7106 // It transposes 4x4 blocks of the input, proceeding left to right across the
7107 // columns within a band of four rows, and then from top to bottom.
7108 template <typename T>
7109 inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
7110                         const RuntimeShape& output_shape, T* output_data) {
7111   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7112   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7113 
7114   const int d0 = input_shape.DimsData()[0];
7115   const int d1 = input_shape.DimsData()[1];
7116   const int kLines = 4;
7117   const int kSkipSize = (kLines - 1) * d1;
7118 
7119   const T* input = input_data;
7120 
7121   int i = 0;
7122   for (; i <= d0 - kLines; i += kLines) {
7123     T* output = output_data + i;
7124 
7125     const T* input_ptr = input;
7126     optimized_ops_preload_l1_keep(input_ptr);
7127     input_ptr += d1;
7128     optimized_ops_preload_l1_keep(input_ptr);
7129     input_ptr += d1;
7130     optimized_ops_preload_l1_keep(input_ptr);
7131     input_ptr += d1;
7132     optimized_ops_preload_l1_keep(input_ptr);
7133 
7134     int j = 0;
7135     for (; j <= d1 - kLines; j += kLines) {
7136       input_ptr = input;
7137       const T a00 = input_ptr[0];
7138       const T a01 = input_ptr[1];
7139       const T a02 = input_ptr[2];
7140       const T a03 = input_ptr[3];
7141       input_ptr += d1;
7142       const T a10 = input_ptr[0];
7143       const T a11 = input_ptr[1];
7144       const T a12 = input_ptr[2];
7145       const T a13 = input_ptr[3];
7146       input_ptr += d1;
7147       const T a20 = input_ptr[0];
7148       const T a21 = input_ptr[1];
7149       const T a22 = input_ptr[2];
7150       const T a23 = input_ptr[3];
7151       input_ptr += d1;
7152       const T a30 = input_ptr[0];
7153       const T a31 = input_ptr[1];
7154       const T a32 = input_ptr[2];
7155       const T a33 = input_ptr[3];
7156 
7157       output[0] = a00;
7158       output[1] = a10;
7159       output[2] = a20;
7160       output[3] = a30;
7161       output += d0;
7162 
7163       output[0] = a01;
7164       output[1] = a11;
7165       output[2] = a21;
7166       output[3] = a31;
7167       output += d0;
7168 
7169       output[0] = a02;
7170       output[1] = a12;
7171       output[2] = a22;
7172       output[3] = a32;
7173       output += d0;
7174 
7175       output[0] = a03;
7176       output[1] = a13;
7177       output[2] = a23;
7178       output[3] = a33;
7179       output += d0;
7180 
7181       input += kLines;
7182     }
7183     if (j == d1) {
7184       input += kSkipSize;
7185     } else {
7186       for (int p = 0; p < kLines; ++p) {
7187         for (int q = 0; q < d1 - j; ++q) {
7188           *(output + q * d0 + p) = *(input + p * d1 + q);
7189         }
7190       }
7191       input += (d1 - j) + kSkipSize;
7192     }
7193   }
7194   for (; i < d0; ++i) {
7195     T* output = output_data + i;
7196     for (int j = 0; j < d1; ++j) {
7197       *output = *input;
7198       output += d0;
7199       ++input;
7200     }
7201   }
7202 }
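// Worked example for Transpose2D above: with input_shape = {d0, d1}, the
// element at input offset r * d1 + c is written to output offset c * d0 + r,
// so the 2x3 matrix {{1, 2, 3}, {4, 5, 6}} becomes {{1, 4}, {2, 5}, {3, 6}}.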
7203 
7204 template <>
7205 inline void Transpose2D(const RuntimeShape& input_shape,
7206                         const int32_t* input_data,
7207                         const RuntimeShape& output_shape,
7208                         int32_t* output_data) {
7209   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7210   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7211 
7212   const int d0 = input_shape.DimsData()[0];
7213   const int d1 = input_shape.DimsData()[1];
7214 #ifdef USE_NEON
7215   const int kLines = 4;
7216   const int kSkipSize = (kLines - 1) * d1;
7217 #endif
7218 
7219   const int32_t* input = input_data;
7220 
7221   int i = 0;
7222 #ifdef USE_NEON
7223   for (; i <= d0 - kLines; i += kLines) {
7224     int32_t* output = output_data + i;
7225 
7226     const int32_t* input_ptr = input;
7227     optimized_ops_preload_l1_keep(input_ptr);
7228     input_ptr += d1;
7229     optimized_ops_preload_l1_keep(input_ptr);
7230     input_ptr += d1;
7231     optimized_ops_preload_l1_keep(input_ptr);
7232     input_ptr += d1;
7233     optimized_ops_preload_l1_keep(input_ptr);
7234 
7235     int j = 0;
7236     for (; j <= d1 - kLines; j += kLines) {
7237       input_ptr = input;
7238       int32x4_t a0 = vld1q_s32(input);
7239       input_ptr += d1;
7240       int32x4_t a1 = vld1q_s32(input_ptr);
7241       input_ptr += d1;
7242       int32x4_t a2 = vld1q_s32(input_ptr);
7243       input_ptr += d1;
7244       int32x4_t a3 = vld1q_s32(input_ptr);
7245 
7246       int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
7247       int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
7248       int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
7249       int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);
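      // At this point tmp3.val[0] holds column 0 of the 4x4 block
      // ({a00, a10, a20, a30}), tmp4.val[0] column 1, tmp3.val[1] column 2 and
      // tmp4.val[1] column 3, so the four stores below each write one row of
      // the transposed block.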
7250 
7251       vst1q_s32(output, tmp3.val[0]);
7252       output += d0;
7253       vst1q_s32(output, tmp4.val[0]);
7254       output += d0;
7255       vst1q_s32(output, tmp3.val[1]);
7256       output += d0;
7257       vst1q_s32(output, tmp4.val[1]);
7258       output += d0;
7259       input += kLines;
7260     }
7261     if (j == d1) {
7262       input += kSkipSize;
7263     } else {
7264       for (int p = 0; p < kLines; ++p) {
7265         for (int q = 0; q < d1 - j; ++q) {
7266           *(output + q * d0 + p) = *(input + p * d1 + q);
7267         }
7268       }
7269       input += (d1 - j) + kSkipSize;
7270     }
7271   }
7272 #endif
7273   for (; i < d0; ++i) {
7274     int32_t* output = output_data + i;
7275     for (int j = 0; j < d1; ++j) {
7276       *output = *input;
7277       output += d0;
7278       ++input;
7279     }
7280   }
7281 }
7282 
7283 // TODO(b/173718660): see if we can reduce the number
7284 // of lines of code in branching without affecting latency.
7285 template <typename T>
7286 inline void Transpose3D(const TransposeParams& params,
7287                         const RuntimeShape& input_shape, const T* input_data,
7288                         const RuntimeShape& output_shape, T* output_data) {
7289   int s1, s2, s3;
7290   s1 = input_shape.Dims(0);
7291   s2 = input_shape.Dims(1);
7292   s3 = input_shape.Dims(2);
7293 
7294   int p1, p2, p3;
7295   if (params.perm[0] == 2) {
7296     p1 = 1;
7297   } else if (params.perm[1] == 2) {
7298     p2 = 1;
7299   } else {
7300     p3 = 1;
7301   }
7302 
7303   if (params.perm[0] == 1) {
7304     p1 = s3;
7305   } else if (params.perm[1] == 1) {
7306     p2 = s3;
7307   } else {
7308     p3 = s3;
7309   }
7310 
7311   if (params.perm[0] == 0) {
7312     p1 = s2 * s3;
7313   } else if (params.perm[1] == 0) {
7314     p2 = s2 * s3;
7315   } else {
7316     p3 = s2 * s3;
7317   }
7318 
7319   int o_s[3];
7320   o_s[0] = input_shape.Dims(params.perm[0]);
7321   o_s[1] = input_shape.Dims(params.perm[1]);
7322   o_s[2] = input_shape.Dims(params.perm[2]);
7323 
7324   for (int i1 = 0; i1 < o_s[0]; ++i1) {
7325     for (int i2 = 0; i2 < o_s[1]; ++i2) {
7326       for (int i3 = 0; i3 < o_s[2]; ++i3) {
7327         const int i = i1 * p1 + i2 * p2 + i3 * p3;
7328         const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
7329         output_data[o] = input_data[i];
7330       }
7331     }
7332   }
7333 }
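// Worked example for Transpose3D above: for an input of shape (2, 3, 4) with
// perm = {2, 0, 1}, the strides become p1 = 1, p2 = s2 * s3 = 12 and
// p3 = s3 = 4, so output element (i1, i2, i3) reads input element
// (i2, i3, i1), i.e. input offset i2 * 12 + i3 * 4 + i1.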
7334 
7335 template <typename T, int N>
7336 void TransposeImpl(const TransposeParams& params,
7337                    const RuntimeShape& input_shape, const T* input_data,
7338                    const RuntimeShape& output_shape, T* output_data) {
7339   const int dims_cnt = input_shape.DimensionsCount();
7340 
7341   int dim0, dim1;
7342   if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
7343                                                &dim1)) {
7344     Transpose2D(RuntimeShape({dim0, dim1}), input_data,
7345                 RuntimeShape({dim1, dim0}), output_data);
7346     return;
7347   }
7348 
7349   // TODO(b/141217325): notably Eigen is better suited for
7350   // larger inputs whereas Transpose3D is generally
7351   // better for smaller ones.
7352   //
7353   // E.g. on Nexus 5, Eigen is better for size 96^3 and up
7354   // and Transpose3D is better for 72^3 and down.
7355   //
7356   // 96^3 is not mobile-friendly for certain usecases
7357   // (e.g. model used in beam search for seq2seq) but is in others.
7358   // Consider tradeoffs.
7359   if (dims_cnt == 3) {
7360     Transpose3D(params, input_shape, input_data, output_shape, output_data);
7361     return;
7362   }
7363 
7364   // Reroute to the reference version if an optimized method for the given data
7365   // is not available.
7366   reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
7367                                  output_data);
7368 }
7369 
7370 template <typename T, int N = 5>
7371 void Transpose(const TransposeParams& unshrinked_params,
7372                const RuntimeShape& unshrinked_input_shape, const T* input_data,
7373                const RuntimeShape& unshrinked_output_shape, T* output_data) {
7374   ruy::profiler::ScopeLabel label("Transpose");
7375 
7376   const int output_size = unshrinked_output_shape.DimensionsCount();
7377   TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
7378   TFLITE_DCHECK_LE(output_size, N);
7379   TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
7380 
7381   RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
7382   RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
7383   TransposeParams shrinked_params = unshrinked_params;
7384 
7385   // Remove any dimensions of size one. A lower-rank transpose usually
7386   // performs better since memory access patterns are improved.
7387   transpose_utils::RemoveOneSizeDimensions(
7388       &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
7389 
7390   // Handle identity cases.
7391   // TODO(b/140779653): Add an optimization pass in the conversion process to
7392   // remove transpose op nodes that do nothing, like the one below.
7393   bool identical = true;
7394   for (int i = 0; i < shrinked_params.perm_count; ++i) {
7395     if (shrinked_params.perm[i] != i) {
7396       identical = false;
7397       break;
7398     }
7399   }
7400   if (identical) {
7401     memcpy(output_data, input_data,
7402            unshrinked_input_shape.FlatSize() * sizeof(T));
7403     return;
7404   }
7405 
7406   // Reduce dimensions by flattening.
7407   if (shrinked_params.perm[0] == 0 && output_size >= 3) {
7408     RuntimeShape non_flatten_input_shape;
7409     RuntimeShape non_flatten_output_shape;
7410     TransposeParams non_flatten_params;
7411     const int total_size = shrinked_input_shape.FlatSize();
7412     const int non_flatten_size = transpose_utils::Flatten(
7413         shrinked_input_shape, shrinked_output_shape, shrinked_params,
7414         &non_flatten_input_shape, &non_flatten_output_shape,
7415         &non_flatten_params);
7416     TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
7417 
7418     for (int i = 0; i < total_size; i += non_flatten_size) {
7419       TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
7420                           input_data + i, non_flatten_output_shape,
7421                           output_data + i);
7422     }
7423     return;
7424   }
7425 
7426   // Call non-flattened case.
7427   TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
7428                       shrinked_output_shape, output_data);
7429 }
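// Minimal usage sketch for Transpose (illustrative only; `input` and `output`
// stand for caller-provided buffers): transposing a 2x3 float matrix.
//   TransposeParams params;
//   params.perm_count = 2;
//   params.perm[0] = 1;
//   params.perm[1] = 0;
//   Transpose<float>(params, RuntimeShape({2, 3}), input,
//                    RuntimeShape({3, 2}), output);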
7430 
7431 // Assume input1 & input2 have the same scale & zero point.
7432 inline void MaximumElementwise(int size, const ArithmeticParams& params,
7433                                const int8* input1_data, const int8* input2_data,
7434                                int8* output_data) {
7435   ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit");
7436   int i = 0;
7437 #ifdef USE_NEON
7438   for (; i <= size - 16; i += 16) {
7439     const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7440     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7441     const int8x16_t max_data =
7442         vmaxq_s8(input1_val_original, input2_val_original);
7443     vst1q_s8(output_data + i, max_data);
7444   }
7445 #endif  // USE_NEON
7446   for (; i < size; ++i) {
7447     const int8 input1_val = input1_data[i];
7448     const int8 input2_val = input2_data[i];
7449     output_data[i] = std::max(input1_val, input2_val);
7450   }
7451 }
7452 
7453 inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
7454                                    int8 input1_data, const int8* input2_data,
7455                                    int8* output_data) {
7456   ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit");
7457   int i = 0;
7458 
7459 #ifdef USE_NEON
7460   const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7461   for (; i <= size - 16; i += 16) {
7462     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7463     const int8x16_t max_data =
7464         vmaxq_s8(input1_val_original, input2_val_original);
7465     vst1q_s8(output_data + i, max_data);
7466   }
7467 #endif  // USE_NEON
7468   for (; i < size; ++i) {
7469     const int8 input2_val = input2_data[i];
7470     output_data[i] = std::max(input1_data, input2_val);
7471   }
7472 }
7473 
7474 // Assume input1 & input2 have the same scale & zero point.
7475 inline void MinimumElementwise(int size, const ArithmeticParams& params,
7476                                const int8* input1_data, const int8* input2_data,
7477                                int8* output_data) {
7478   ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit");
7479   int i = 0;
7480 #ifdef USE_NEON
7481   for (; i <= size - 16; i += 16) {
7482     const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7483     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7484     const int8x16_t min_data =
7485         vminq_s8(input1_val_original, input2_val_original);
7486     vst1q_s8(output_data + i, min_data);
7487   }
7488 #endif  // USE_NEON
7489   for (; i < size; ++i) {
7490     const int8 input1_val = input1_data[i];
7491     const int8 input2_val = input2_data[i];
7492     output_data[i] = std::min(input1_val, input2_val);
7493   }
7494 }
7495 
7496 inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
7497                                    int8 input1_data, const int8* input2_data,
7498                                    int8* output_data) {
7499   ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit");
7500   int i = 0;
7501 
7502 #ifdef USE_NEON
7503   const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7504   for (; i <= size - 16; i += 16) {
7505     const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7506     const int8x16_t min_data =
7507         vminq_s8(input1_val_original, input2_val_original);
7508     vst1q_s8(output_data + i, min_data);
7509   }
7510 #endif  // USE_NEON
7511   for (; i < size; ++i) {
7512     const int8 input2_val = input2_data[i];
7513     output_data[i] = std::min(input1_data, input2_val);
7514   }
7515 }
7516 
7517 template <typename Op>
7518 inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
7519                                      const RuntimeShape& input1_shape,
7520                                      const int8* input1_data,
7521                                      const RuntimeShape& input2_shape,
7522                                      const int8* input2_data,
7523                                      const RuntimeShape& output_shape,
7524                                      int8* output_data, Op op) {
7525   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7526     return reference_ops::MaximumMinimumBroadcastSlow(
7527         input1_shape, input1_data, input2_shape, input2_data, output_shape,
7528         output_data, op);
7529   }
7530 
7531   BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
7532                           input2_data, output_shape, output_data,
7533                           MaximumElementwise, MaximumScalarBroadcast);
7534 }
7535 
7536 template <typename Op>
7537 inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
7538                                      const RuntimeShape& input1_shape,
7539                                      const int8* input1_data,
7540                                      const RuntimeShape& input2_shape,
7541                                      const int8* input2_data,
7542                                      const RuntimeShape& output_shape,
7543                                      int8* output_data, Op op) {
7544   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7545     return reference_ops::MaximumMinimumBroadcastSlow(
7546         input1_shape, input1_data, input2_shape, input2_data, output_shape,
7547         output_data, op);
7548   }
7549 
7550   BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
7551                           input2_data, output_shape, output_data,
7552                           MinimumElementwise, MinimumScalarBroadcast);
7553 }
7554 
7555 template <typename T>
7556 void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
7557                 bool exclusive, bool reverse, T* output_data) {
7558   Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};
7559 
7560   for (int i = 0; i < axis; ++i) {
7561     dims[0] *= shape.Dims(i);
7562   }
7563   dims[1] = shape.Dims(axis);
7564   for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
7565     dims[2] *= shape.Dims(i);
7566   }
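  // Illustration: for shape {2, 3, 4} and axis == 1 this gives dims {2, 3, 4};
  // for axis == 2 it gives dims {6, 4, 1}. The cumulative sum always runs over
  // the middle dimension of this collapsed 3-D view.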
7567 
7568   typedef Eigen::TensorMap<
7569       Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
7570       Eigen::Aligned>
7571       ConstTensor;
7572   typedef Eigen::TensorMap<
7573       Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
7574       Tensor;
7575   ConstTensor input(input_data, dims);
7576   Tensor output(output_data, dims);
7577 
7578   if (reverse) {
7579     Eigen::array<bool, 3> reverse_idx = {false, true, false};
7580     output =
7581         input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
7582   } else {
7583     output = input.cumsum(1, exclusive);
7584   }
7585 }
7586 
7587 template <typename T>
7588 void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
7589             bool exclusive, bool reverse, T* output_data) {
7590   const int dim = shape.DimensionsCount();
7591   TFLITE_DCHECK_GE(dim, 1);
7592   CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
7593 }
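// Illustrative use of CumSum (hypothetical values): for input {1, 2, 3, 4, 5, 6}
// viewed with shape {2, 3}, axis == 1, exclusive == false and reverse == false,
// the output is {1, 3, 6, 4, 9, 15}.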
7594 
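// Computes PReLU with a single scalar alpha broadcast over the input:
// output[i] = input[i] >= 0 ? input[i] : alpha * input[i].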
7595 inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
7596                                  float alpha, const float* input_data,
7597                                  float* output_data) {
7598   ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
7599   int i = 0;
7600 
7601 #ifdef USE_NEON
7602   const float32x4_t zero_dup = vdupq_n_f32(0.0f);
7603   const float32x4_t alpha_dup = vdupq_n_f32(alpha);
7604   for (; i <= size - 16; i += 16) {
7605     const float32x4_t input1 = vld1q_f32(input_data + i);
7606     const float32x4_t input2 = vld1q_f32(input_data + i + 4);
7607     const float32x4_t input3 = vld1q_f32(input_data + i + 8);
7608     const float32x4_t input4 = vld1q_f32(input_data + i + 12);
7609 
7610     const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
7611     const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
7612     const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
7613     const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);
7614 
7615     const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
7616     const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
7617     const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
7618     const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
7619 
7620     const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
7621     vst1q_f32(output_data + i, result1);
7622     const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
7623     vst1q_f32(output_data + i + 4, result2);
7624     const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
7625     vst1q_f32(output_data + i + 8, result3);
7626     const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
7627     vst1q_f32(output_data + i + 12, result4);
7628   }
7629 
7630   for (; i <= size - 4; i += 4) {
7631     const float32x4_t input = vld1q_f32(input_data + i);
7632     const float32x4_t temp = vmulq_f32(input, alpha_dup);
7633     const uint32x4_t mask = vcgeq_f32(input, zero_dup);
7634     const float32x4_t result = vbslq_f32(mask, input, temp);
7635     vst1q_f32(output_data + i, result);
7636   }
7637 #endif  // USE_NEON
7638   for (; i < size; ++i) {
7639     const float input = input_data[i];
7640     output_data[i] = input >= 0.f ? input : input * alpha;
7641   }
7642 }
7643 
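// Computes PReLU with a per-element alpha:
// output[i] = input[i] >= 0 ? input[i] : alpha_data[i] * input[i].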
7644 inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
7645                              const float* alpha_data, const float* input_data,
7646                              float* output_data) {
7647   ruy::profiler::ScopeLabel label("PreluElementWise/float");
7648 
7649   int i = 0;
7650 #ifdef USE_NEON
7651   const float32x4_t zero_dup = vdupq_n_f32(0.0f);
7652   for (; i <= flat_size - 16; i += 16) {
7653     const float32x4_t input1 = vld1q_f32(input_data + i);
7654     const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
7655     const float32x4_t input2 = vld1q_f32(input_data + i + 4);
7656     const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
7657     const float32x4_t input3 = vld1q_f32(input_data + i + 8);
7658     const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
7659     const float32x4_t input4 = vld1q_f32(input_data + i + 12);
7660     const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);
7661 
7662     const float32x4_t temp1 = vmulq_f32(input1, alpha1);
7663     const float32x4_t temp2 = vmulq_f32(input2, alpha2);
7664     const float32x4_t temp3 = vmulq_f32(input3, alpha3);
7665     const float32x4_t temp4 = vmulq_f32(input4, alpha4);
7666 
7667     const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
7668     const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
7669     const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
7670     const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
7671 
7672     const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
7673     vst1q_f32(output_data + i, result1);
7674     const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
7675     vst1q_f32(output_data + i + 4, result2);
7676     const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
7677     vst1q_f32(output_data + i + 8, result3);
7678     const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
7679     vst1q_f32(output_data + i + 12, result4);
7680   }
7681 
7682   for (; i <= flat_size - 4; i += 4) {
7683     const float32x4_t input = vld1q_f32(input_data + i);
7684     const float32x4_t alpha = vld1q_f32(alpha_data + i);
7685 
7686     const float32x4_t temp = vmulq_f32(input, alpha);
7687     const uint32x4_t mask = vcgeq_f32(input, zero_dup);
7688     const float32x4_t result = vbslq_f32(mask, input, temp);
7689     vst1q_f32(output_data + i, result);
7690   }
7691 #endif  // USE_NEON
7692   for (; i < flat_size; ++i) {
7693     const float input = input_data[i];
7694     const float alpha = alpha_data[i];
7695     output_data[i] = input >= 0.f ? input : input * alpha;
7696   }
7697 }
7698 
7699 inline void BroadcastPReluDispatch(
7700     const ArithmeticParams& params, const RuntimeShape& input_shape,
7701     const float* input_data, const RuntimeShape& alpha_shape,
7702     const float* alpha_data, const RuntimeShape& output_shape,
7703     float* output_data, float (*func)(float, float)) {
7704   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7705     return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
7706         input_shape, input_data, alpha_shape, alpha_data, output_shape,
7707         output_data, func);
7708   }
7709 
7710   BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
7711                           alpha_data, output_shape, output_data,
7712                           PReluElementWise, PReluScalarBroadcast);
7713 }
7714 
7715 // Returns the index of the minimum value within `input_data`.
7716 // If there is a tie, returns the smallest such index.
7717 template <typename T>
7718 inline int ArgMinVector(const T* input_data, int size) {
7719   T min_value = input_data[0];
7720   int min_index = 0;
7721   for (int i = 1; i < size; ++i) {
7722     const T curr_value = input_data[i];
7723     if (curr_value < min_value) {
7724       min_value = curr_value;
7725       min_index = i;
7726     }
7727   }
7728   return min_index;
7729 }
7730 
7731 // Returns the index of the maximum value within `input_data`.
7732 // If there is a tie, returns the smallest such index.
7733 template <typename T>
7734 inline int ArgMaxVector(const T* input_data, int size) {
7735   T max_value = input_data[0];
7736   int max_index = 0;
7737   for (int i = 1; i < size; ++i) {
7738     const T curr_value = input_data[i];
7739     if (curr_value > max_value) {
7740       max_value = curr_value;
7741       max_index = i;
7742     }
7743   }
7744   return max_index;
7745 }
7746 
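// The NEON specializations below keep, per 32-bit lane, the running extreme of
// that lane's stride-4 subsequence together with the index at which it first
// occurred, then reduce across lanes and pick the smallest index among lanes
// holding the overall extreme. This matches the tie-breaking rule of the scalar
// templates above (smallest index wins).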
7747 template <>
7748 inline int ArgMinVector(const float* input_data, int size) {
7749   int32_t min_index = 0;
7750   float min_value = input_data[0];
7751   int32_t i = 1;
7752 #ifdef USE_NEON
7753   if (size >= 4) {
7754     float32x4_t min_value_f32x4 = vld1q_f32(input_data);
7755     const int32_t index_init[4] = {0, 1, 2, 3};
7756     int32x4_t min_index_s32x4 = vld1q_s32(index_init);
7757     int32x4_t index_s32x4 = min_index_s32x4;
7758     int32x4_t inc = vdupq_n_s32(4);
7759     for (i = 4; i <= size - 4; i += 4) {
7760       // Increase indices by 4.
7761       index_s32x4 = vaddq_s32(index_s32x4, inc);
7762       float32x4_t v = vld1q_f32(&input_data[i]);
7763       uint32x4_t mask = vcltq_f32(v, min_value_f32x4);
7764       min_value_f32x4 = vminq_f32(min_value_f32x4, v);
7765       min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4);
7766     }
7767     // Find min element within float32x4_t.
7768 #ifdef __aarch64__
7769     min_value = vminvq_f32(min_value_f32x4);
7770 #else
7771     float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4),
7772                                             vget_high_f32(min_value_f32x4));
7773     min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2);
7774     min_value = vget_lane_f32(min_value_f32x2, 0);
7775 #endif  // __aarch64__
7776     // Mask indices of non-min values with max int32_t.
7777     float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value);
7778     uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4);
7779     int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
7780     min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set);
7781     // Find min index of min values.
7782 #ifdef __aarch64__
7783     min_index = vminvq_s32(min_index_s32x4);
7784 #else
7785     int32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4),
7786                                           vget_high_s32(min_index_s32x4));
7787     min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2);
7788     min_index = vget_lane_s32(min_index_s32x2, 0);
7789 #endif  // __aarch64__
7790   }
7791 #endif  // USE_NEON
7792   // Leftover loop.
7793   for (; i < size; ++i) {
7794     const float curr_value = input_data[i];
7795     if (curr_value < min_value) {
7796       min_value = curr_value;
7797       min_index = i;
7798     }
7799   }
7800   return min_index;
7801 }
7802 
7803 template <>
7804 inline int ArgMaxVector(const float* input_data, int size) {
7805   int32_t max_index = 0;
7806   float max_value = input_data[0];
7807   int32_t i = 1;
7808 #ifdef USE_NEON
7809   if (size >= 4) {
7810     float32x4_t max_value_f32x4 = vld1q_f32(input_data);
7811     const int32_t index_init[4] = {0, 1, 2, 3};
7812     int32x4_t max_index_s32x4 = vld1q_s32(index_init);
7813     int32x4_t index_s32x4 = max_index_s32x4;
7814     int32x4_t inc = vdupq_n_s32(4);
7815     for (i = 4; i <= size - 4; i += 4) {
7816       // Increase indices by 4.
7817       index_s32x4 = vaddq_s32(index_s32x4, inc);
7818       float32x4_t v = vld1q_f32(&input_data[i]);
7819       uint32x4_t mask = vcgtq_f32(v, max_value_f32x4);
7820       max_value_f32x4 = vmaxq_f32(max_value_f32x4, v);
7821       max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4);
7822     }
7823     // Find max element within float32x4_t.
7824 #ifdef __aarch64__
7825     max_value = vmaxvq_f32(max_value_f32x4);
7826 #else
7827     float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4),
7828                                             vget_high_f32(max_value_f32x4));
7829     max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2);
7830     max_value = vget_lane_f32(max_value_f32x2, 0);
7831 #endif  // __aarch64__
7832     // Mask indices of non-max values with max int32_t.
7833     float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value);
7834     uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4);
7835     int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
7836     max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set);
7837     // Find min index of max values.
7838 #ifdef __aarch64__
7839     max_index = vminvq_s32(max_index_s32x4);
7840 #else
7841     int32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4),
7842                                           vget_high_s32(max_index_s32x4));
7843     max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2);
7844     max_index = vget_lane_s32(max_index_s32x2, 0);
7845 #endif  // __aarch64__
7846   }
7847 #endif  // USE_NEON
7848   // Leftover loop.
7849   for (; i < size; ++i) {
7850     const float curr_value = input_data[i];
7851     if (curr_value > max_value) {
7852       max_value = curr_value;
7853       max_index = i;
7854     }
7855   }
7856   return max_index;
7857 }
7858 
7859 template <>
7860 inline int ArgMaxVector(const int8_t* input_data, int size) {
7861   int32_t max_index = 0;
7862   int8_t max_value = input_data[0];
7863   int32_t i = 0;
7864 #ifdef USE_NEON
7865   constexpr int VECTOR_SIZE = 16;
7866   if (size >= VECTOR_SIZE) {
7867     int8x16_t max_value_s8x16;
7868     for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
7869       max_value_s8x16 = vld1q_s8(input_data + i);
7870       int8_t max_from_vec;
7871 #ifdef __aarch64__
7872       max_from_vec = vmaxvq_s8(max_value_s8x16);
7873 #else   // 32 bit
7874       int8x8_t max_val_s8x8 =
7875           vpmax_s8(vget_low_s8(max_value_s8x16), vget_high_s8(max_value_s8x16));
7876       max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7877       max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7878       max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7879       max_from_vec = vget_lane_s8(max_val_s8x8, 0);
7880 #endif  // __aarch64__
7881       if (max_from_vec > max_value) {
7882         max_value = max_from_vec;
7883         max_index = i;
7884       }
7885     }
7886   }
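  // At this point max_index is the start of the 16-element block in which the
  // running maximum was found (or 0 if the vector loop never ran, in which case
  // max_value == input_data[0]). The scan below locates the first exact
  // position of max_value inside that block without reading past the data.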
7887   for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
7888        start_idx++) {
7889     if (input_data[start_idx] == max_value) {
7890       max_index = start_idx;
7891       break;
7892     }
7893   }
7894 
7895 #endif  // USE_NEON
7896   // Leftover loop.
7897   for (; i < size; ++i) {
7898     const int8_t curr_value = input_data[i];
7899     if (curr_value > max_value) {
7900       max_value = curr_value;
7901       max_index = i;
7902     }
7903   }
7904 
7905   return max_index;
7906 }
7907 
7908 template <>
7909 inline int ArgMaxVector(const uint8_t* input_data, int size) {
7910   int32_t max_index = 0;
7911   uint8_t max_value = input_data[0];
7912   int32_t i = 0;
7913 #ifdef USE_NEON
7914   constexpr int VECTOR_SIZE = 16;
7915   if (size >= VECTOR_SIZE) {
7916     uint8x16_t max_value_u8x16;
7917     for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
7918       max_value_u8x16 = vld1q_u8(input_data + i);
7919       uint8_t max_from_vec;
7920 #ifdef __aarch64__
7921       max_from_vec = vmaxvq_u8(max_value_u8x16);
7922 #else   // 32 bit
7923       uint8x8_t max_val_u8x8 =
7924           vpmax_u8(vget_low_u8(max_value_u8x16), vget_high_u8(max_value_u8x16));
7925       max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7926       max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7927       max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7928       max_from_vec = vget_lane_u8(max_val_u8x8, 0);
7929 #endif  // __aarch64__
7930       if (max_from_vec > max_value) {
7931         max_value = max_from_vec;
7932         max_index = i;
7933       }
7934     }
7935   }
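  // As in the int8_t specialization above, max_index currently marks the start
  // of the 16-element block holding the running maximum; scan it for the first
  // exact position of max_value.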
7936   for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
7937        start_idx++) {
7938     if (input_data[start_idx] == max_value) {
7939       max_index = start_idx;
7940       break;
7941     }
7942   }
7943 
7944 #endif  // USE_NEON
7945   // Leftover loop.
7946   for (; i < size; ++i) {
7947     const uint8_t curr_value = input_data[i];
7948     if (curr_value > max_value) {
7949       max_value = curr_value;
7950       max_index = i;
7951     }
7952   }
7953 
7954   return max_index;
7955 }
7956 
7957 // Specializes the ArgMinMax function for axis == dims - 1, where the
7958 // reduction is applied over contiguous memory.
7959 template <typename T1, typename T2, bool is_arg_max>
7960 inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape,
7961                               const T1* input_data,
7962                               const RuntimeShape& output_shape,
7963                               T2* output_data) {
7964   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7965   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1);
7966   TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0));
7967 
7968   int outer_size = input_shape.Dims(0);
7969   int axis_size = input_shape.Dims(1);
7970   for (int outer = 0; outer < outer_size; ++outer) {
7971     if (is_arg_max) {
7972       output_data[outer] = static_cast<T2>(
7973           ArgMaxVector<T1>(input_data + outer * axis_size, axis_size));
7974     } else {
7975       output_data[outer] = static_cast<T2>(
7976           ArgMinVector<T1>(input_data + outer * axis_size, axis_size));
7977     }
7978   }
7979 }
7980 
7981 template <typename T1, typename T2, typename T3>
7982 inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
7983                       const T3* input2_data, const RuntimeShape& output_shape,
7984                       T2* output_data, const bool is_arg_max) {
7985   ruy::profiler::ScopeLabel label("ArgMinMax");
7986 
7987   TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
7988   TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
7989                    output_shape.DimensionsCount());
7990   int axis = input2_data[0];
7991   if (axis < 0) {
7992     axis += input1_shape.DimensionsCount();
7993   }
7994   const int axis_size = input1_shape.Dims(axis);
7995 
7996   int outer_size = 1;
7997   for (int i = 0; i < axis; ++i) {
7998     TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
7999     outer_size *= input1_shape.Dims(i);
8000   }
8001 
8002   int inner_size = 1;
8003   const int dims_count = input1_shape.DimensionsCount();
8004   for (int i = axis + 1; i < dims_count; ++i) {
8005     TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
8006     inner_size *= input1_shape.Dims(i);
8007   }
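  // Illustration: with input1_shape {2, 3, 4} and axis == 2 this yields
  // outer_size == 6, axis_size == 4 and inner_size == 1, so the last-axis
  // specialization below treats the input as a contiguous {6, 4} matrix; with
  // axis == 1, inner_size == 4 and the generic reference path is taken.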
8008 
8009   // Call the specialized function when axis == dims - 1. So far only float32,
8010   // int8 and uint8 are routed to the optimized last-axis implementation.
8011   if (inner_size == 1 &&
8012       (std::is_same<T1, float>::value || std::is_same<T1, int8_t>::value ||
8013        std::is_same<T1, uint8_t>::value)) {
8014     if (is_arg_max) {
8015       ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/true>(
8016           {outer_size, axis_size}, input1_data, {outer_size}, output_data);
8017     } else {
8018       ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/false>(
8019           {outer_size, axis_size}, input1_data, {outer_size}, output_data);
8020     }
8021     return;
8022   }
8023 
8024   reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape,
8025                            output_data, is_arg_max);
8026 }
8027 
8028 template <typename T1, typename T2, typename T3>
8029 void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
8030             const T3* input2_data, const RuntimeShape& output_shape,
8031             T2* output_data) {
8032   ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
8033             /*is_arg_max=*/true);
8034 }
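// Illustrative call (hypothetical shapes and buffers): for a float input of
// shape {2, 5} and an axis tensor holding the value 1,
//   ArgMax(RuntimeShape({2, 5}), input, axis, RuntimeShape({2}), output);
// writes, for each of the two rows, the index of its largest element into
// `output`.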
8035 
8036 // Convenience version that allows, for example, generated-code calls to be
8037 // the same as other binary ops.
8038 // For backward compatibility, reference_ops also provides an ArgMax function.
8039 template <typename T1, typename T2, typename T3>
8040 inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
8041                    const RuntimeShape& input2_shape, const T3* input2_data,
8042                    const RuntimeShape& output_shape, T2* output_data) {
8043   // Drop shape of second input: not needed.
8044   ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
8045 }
8046 
8047 inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
8048                    const float* input_data, const RuntimeShape& filter_shape,
8049                    const float* filter_data, const RuntimeShape& bias_shape,
8050                    const float* bias_data, const RuntimeShape& output_shape,
8051                    float* output_data, const RuntimeShape& im2col_shape,
8052                    float* im2col_data,
8053                    const RuntimeShape& transposed_filter_shape,
8054                    float* transposed_filter_data,
8055                    CpuBackendContext* cpu_backend_context) {
8056   const int stride_depth = params.stride_depth;
8057   const int stride_height = params.stride_height;
8058   const int stride_width = params.stride_width;
8059   const int dilation_depth_factor = params.dilation_depth;
8060   const int dilation_height_factor = params.dilation_height;
8061   const int dilation_width_factor = params.dilation_width;
8062   const float output_activation_min = params.float_activation_min;
8063   const float output_activation_max = params.float_activation_max;
8064   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
8065   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
8066   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
8067 
8068   ruy::profiler::ScopeLabel label("Conv3D");
8069 
8070   // NB: the float 0.0f value is represented by all zero bytes.
8071   const uint8 float_zero_byte = 0x00;
8072   const float* gemm_input_data = nullptr;
8073   const RuntimeShape* gemm_input_shape = nullptr;
8074   const int filter_width = filter_shape.Dims(2);
8075   const int filter_height = filter_shape.Dims(1);
8076   const int filter_depth = filter_shape.Dims(0);
8077   const bool need_dilated_im2col = dilation_width_factor != 1 ||
8078                                    dilation_height_factor != 1 ||
8079                                    dilation_depth_factor != 1;
8080   const bool need_im2col = stride_depth != 1 || stride_height != 1 ||
8081                            stride_width != 1 || filter_depth != 1 ||
8082                            filter_height != 1 || filter_width != 1;
8083 
8084   if (need_dilated_im2col) {
8085     DilatedIm2col3D(params, filter_depth, filter_height, filter_width,
8086                     float_zero_byte, input_shape, input_data, im2col_shape,
8087                     im2col_data);
8088     gemm_input_data = im2col_data;
8089     gemm_input_shape = &im2col_shape;
8090   } else if (need_im2col) {
8091     TFLITE_DCHECK(im2col_data);
8092     Im2col3D(params, filter_depth, filter_height, filter_width, float_zero_byte,
8093              input_shape, input_data, im2col_shape, im2col_data);
8094     gemm_input_data = im2col_data;
8095     gemm_input_shape = &im2col_shape;
8096   } else {
8097     TFLITE_DCHECK(!im2col_data);
8098     gemm_input_data = input_data;
8099     gemm_input_shape = &input_shape;
8100   }
8101 
8102   // Transpose the filter tensor.
8103   TransposeParams transpose_params;
8104   transpose_params.perm_count = 5;
8105   transpose_params.perm[0] = 4;
8106   transpose_params.perm[1] = 0;
8107   transpose_params.perm[2] = 1;
8108   transpose_params.perm[3] = 2;
8109   transpose_params.perm[4] = 3;
8110   Transpose<float, 5>(transpose_params, filter_shape, filter_data,
8111                       transposed_filter_shape, transposed_filter_data);
8112 
8113   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
8114   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
8115   int n = output_shape.Dims(4);
8116   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
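  // The GEMM below computes, for each of the m output spatial positions across
  // the batch, a dot product of its k-sized input patch with each of the n
  // output-channel filters: lhs is the transposed filter viewed as an (n x k)
  // matrix, rhs is the (possibly im2col'ed) input viewed as (k x m).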
8117 
8118   cpu_backend_gemm::MatrixParams<float> lhs_params;
8119   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
8120   lhs_params.rows = n;
8121   lhs_params.cols = k;
8122   cpu_backend_gemm::MatrixParams<float> rhs_params;
8123   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
8124   rhs_params.rows = k;
8125   rhs_params.cols = m;
8126   cpu_backend_gemm::MatrixParams<float> dst_params;
8127   dst_params.order = cpu_backend_gemm::Order::kColMajor;
8128   dst_params.rows = n;
8129   dst_params.cols = m;
8130   cpu_backend_gemm::GemmParams<float, float> gemm_params;
8131   gemm_params.bias = bias_data;
8132   gemm_params.clamp_min = output_activation_min;
8133   gemm_params.clamp_max = output_activation_max;
8134   cpu_backend_gemm::Gemm(lhs_params, transposed_filter_data, rhs_params,
8135                          gemm_input_data, dst_params, output_data, gemm_params,
8136                          cpu_backend_context);
8137 }
8138 
8139 // Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
8140 // order (planes, height, width, channel), constructed from patches in
8141 // 'col_data', which is required to be in storage order (out_planes * out_height
8142 // * out_width, filter_planes, filter_height, filter_width, in_channel).
8143 //
8144 // This function is copied from tensorflow/core/kernels/conv_grad_ops_3d.cc
8145 // authored by Eugene Zhulenev (ezhulenev).
8146 template <typename T>
8147 void Col2im(const T* col_data, const int channel, const int planes,
8148             const int height, const int width, const int filter_p,
8149             const int filter_h, const int filter_w, const int pad_pt,
8150             const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
8151             const int pad_r, const int stride_p, const int stride_h,
8152             const int stride_w, T* im_data) {
8153   const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
8154   const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
8155   const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
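  // Each *_col value is the number of filter placements along that dimension.
  // For example, height == 5, pad_t == pad_b == 1, filter_h == 3 and
  // stride_h == 2 give height_col == (5 + 1 + 1 - 3) / 2 + 1 == 3.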
8156   int p_pad = -pad_pt;
8157   for (int p = 0; p < planes_col; ++p) {
8158     int h_pad = -pad_t;
8159     for (int h = 0; h < height_col; ++h) {
8160       int w_pad = -pad_l;
8161       for (int w = 0; w < width_col; ++w) {
8162         T* im_patch_data =
8163             im_data +
8164             (p_pad * height * width + h_pad * width + w_pad) * channel;
8165         for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
8166           for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
8167             for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
8168               if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
8169                   iw < width) {
8170                 for (int i = 0; i < channel; ++i) {
8171                   im_patch_data[i] += col_data[i];
8172                 }
8173               }
8174               im_patch_data += channel;
8175               col_data += channel;
8176             }
8177             // Skip the remaining (width - filter_w) columns to reach the next row.
8178             im_patch_data += channel * (width - filter_w);
8179           }
8180           // Skip the remaining (height - filter_h) rows to reach the next plane.
8181           im_patch_data += (channel * width) * (height - filter_h);
8182         }
8183         w_pad += stride_w;
8184       }
8185       h_pad += stride_h;
8186     }
8187     p_pad += stride_p;
8188   }
8189 }
8190 
8191 template <typename T>
8192 void BiasAdd3D(T* im_data, const T* bias_data, const RuntimeShape& input_shape,
8193                float float_activation_min, float float_activation_max) {
8194   if (bias_data) {
8195     const int outer_size = input_shape.Dims(0) * input_shape.Dims(1) *
8196                            input_shape.Dims(2) * input_shape.Dims(3);
8197     const int num_channels = input_shape.Dims(4);
8198     for (int n = 0; n < outer_size; ++n) {
8199       for (int c = 0; c < num_channels; ++c) {
8200         im_data[c] = ActivationFunctionWithMinMax(im_data[c] + bias_data[c],
8201                                                   float_activation_min,
8202                                                   float_activation_max);
8203       }
8204       im_data += num_channels;
8205     }
8206   } else {
8207     const int flat_size = input_shape.FlatSize();
8208     for (int i = 0; i < flat_size; ++i) {
8209       im_data[i] = ActivationFunctionWithMinMax(
8210           im_data[i], float_activation_min, float_activation_max);
8211     }
8212   }
8213 }
8214 
8215 inline void Conv3DTranspose(
8216     const Conv3DTransposeParams& params, const RuntimeShape& input_shape,
8217     const float* input_data, const RuntimeShape& filter_shape,
8218     const float* filter_data, const RuntimeShape& bias_shape,
8219     const float* bias_data, const RuntimeShape& output_shape,
8220     float* const output_data, const RuntimeShape& col2im_shape,
8221     float* col2im_data, CpuBackendContext* cpu_backend_context) {
8222   ruy::profiler::ScopeLabel label("Conv3DTranspose/float");
8223   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
8224   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
8225   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
8226   TFLITE_DCHECK(col2im_data);
8227 
8228   const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
8229   const int input_channel = MatchingDim(input_shape, 4, filter_shape, 4);
8230   const int output_channel = MatchingDim(output_shape, 4, filter_shape, 3);
8231   const int input_spatial_size =
8232       input_shape.Dims(1) * input_shape.Dims(2) * input_shape.Dims(3);
8233   const int output_spatial_size =
8234       output_shape.Dims(1) * output_shape.Dims(2) * output_shape.Dims(3);
8235 
8236   const int output_spatial_dim_1 = output_shape.Dims(1);
8237   const int output_spatial_dim_2 = output_shape.Dims(2);
8238   const int output_spatial_dim_3 = output_shape.Dims(3);
8239   const int input_offset = input_spatial_size * input_channel;
8240   const int output_offset = output_spatial_size * output_channel;
8241 
8242   const int filter_spatial_dim_1 = filter_shape.Dims(0);
8243   const int filter_spatial_dim_2 = filter_shape.Dims(1);
8244   const int filter_spatial_dim_3 = filter_shape.Dims(2);
8245 
8246   const int spatial_dim_1_padding_before = params.padding_values.depth;
8247   const int spatial_dim_1_padding_after =
8248       params.padding_values.depth + params.padding_values.depth_offset;
8249   const int spatial_dim_2_padding_before = params.padding_values.height;
8250   const int spatial_dim_2_padding_after =
8251       params.padding_values.height + params.padding_values.height_offset;
8252   const int spatial_dim_3_padding_before = params.padding_values.width;
8253   const int spatial_dim_3_padding_after =
8254       params.padding_values.width + params.padding_values.width_offset;
8255   const int spatial_dim_1_stride = params.stride_depth;
8256   const int spatial_dim_2_stride = params.stride_height;
8257   const int spatial_dim_3_stride = params.stride_width;
8258   const int filter_total_size = filter_spatial_dim_1 * filter_spatial_dim_2 *
8259                                 filter_spatial_dim_3 * output_channel;
8260 
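  // For each batch element: (1) a GEMM multiplies the filter, viewed as a
  // (filter_total_size x input_channel) matrix, by the input activations
  // (input_channel x input_spatial_size) into col2im_data; (2) Col2im
  // accumulates each resulting column back into the zero-initialized output
  // volume; (3) BiasAdd3D finally adds the bias and applies the activation
  // clamp over the whole output.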
8261   cpu_backend_gemm::MatrixParams<float> lhs_params;
8262   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
8263   lhs_params.rows = filter_total_size;
8264   lhs_params.cols = input_channel;
8265   float* output_data_p = output_data;
8266   std::fill_n(output_data, output_offset * batch_size, 0.0f);
8267   for (int i = 0; i < batch_size; ++i) {
8268     cpu_backend_gemm::MatrixParams<float> rhs_params;
8269     rhs_params.order = cpu_backend_gemm::Order::kColMajor;
8270     rhs_params.rows = input_channel;
8271     rhs_params.cols = input_spatial_size;
8272     cpu_backend_gemm::MatrixParams<float> dst_params;
8273     dst_params.order = cpu_backend_gemm::Order::kColMajor;
8274     dst_params.rows = filter_total_size;
8275     dst_params.cols = input_spatial_size;
8276     cpu_backend_gemm::GemmParams<float, float> gemm_params;
8277     cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params,
8278                            input_data + input_offset * i, dst_params,
8279                            col2im_data, gemm_params, cpu_backend_context);
8280 
8281     Col2im(col2im_data, output_channel, output_spatial_dim_1,
8282            output_spatial_dim_2, output_spatial_dim_3, filter_spatial_dim_1,
8283            filter_spatial_dim_2, filter_spatial_dim_3,
8284            spatial_dim_1_padding_before, spatial_dim_2_padding_before,
8285            spatial_dim_3_padding_before, spatial_dim_1_padding_after,
8286            spatial_dim_2_padding_after, spatial_dim_3_padding_after,
8287            spatial_dim_1_stride, spatial_dim_2_stride, spatial_dim_3_stride,
8288            output_data_p);
8289     output_data_p += output_offset;
8290   }
8291   output_data_p = output_data;
8292   BiasAdd3D(output_data_p, bias_data, output_shape, params.float_activation_min,
8293             params.float_activation_max);
8294 }
8295 
8296 // Worker that sums the inputs within a single interval, identified by the
8297 // index range [start, end).
8298 template <typename T>
8299 struct AddNWorkerTask : cpu_backend_threadpool::Task {
8300   AddNWorkerTask(const T* const* input_data, T* scratch_buffer, int start,
8301                  int end, int num_elems, int split)
8302       : input_data(input_data),
8303         scratch_buffer(scratch_buffer),
8304         start(start),
8305         end(end),
8306         num_elems(num_elems),
8307         split(split) {}
8308   void Run() override {
8309     RuntimeShape shape(1);
8310     shape.SetDim(0, num_elems);
8311     ArithmeticParams params;
8312     T output_activation_min = std::numeric_limits<T>::lowest(),
8313       output_activation_max = std::numeric_limits<T>::max();
8314     SetActivationParams(output_activation_min, output_activation_max, &params);
8315     T* start_p = scratch_buffer + split * num_elems;
8316     memcpy(start_p, input_data[start], sizeof(T) * num_elems);
8317     for (int i = start + 1; i < end; i++) {
8318       Add(params, shape, start_p, shape, input_data[i], shape, start_p);
8319     }
8320   }
8321 
8322   const T* const* input_data;
8323   T* scratch_buffer;
8324   int start;
8325   int end;
8326   int num_elems;
8327   int split;
8328 };
8329 
8330 // T is expected to be either float or int.
8331 template <typename T>
8332 inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
8333                  const T* const* input_data, T* output_data, T* scratch_buffer,
8334                  CpuBackendContext* cpu_backend_context) {
8335   // All inputs and output should have the same shape, this is checked during
8336   // Prepare stage.
8337   const size_t num_elems = input_shape.FlatSize();
8338   const int thread_count =
8339       std::min(std::max(1, static_cast<int>(num_inputs) / 2),
8340                cpu_backend_context->max_num_threads());
8341   memset(scratch_buffer, 0, sizeof(T) * num_elems * thread_count);
8342 
8343   std::vector<AddNWorkerTask<T>> tasks;
8344   tasks.reserve(thread_count);
8345   int start = 0;
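  // Split the inputs as evenly as possible across the tasks. For example, with
  // num_inputs == 10 and thread_count == 4 the intervals are
  // [0, 2), [2, 4), [4, 7) and [7, 10).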
8346   for (int i = 0; i < thread_count; ++i) {
8347     int end = start + (num_inputs - start) / (thread_count - i);
8348     tasks.emplace_back(AddNWorkerTask<T>(input_data, scratch_buffer, start, end,
8349                                          num_elems, i));
8350     start = end;
8351   }
8352   // Run all tasks on the thread pool.
8353   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
8354                                   cpu_backend_context);
8355   RuntimeShape shape(1);
8356   shape.SetDim(0, num_elems);
8357   ArithmeticParams params;
8358   T output_activation_min = std::numeric_limits<T>::lowest(),
8359     output_activation_max = std::numeric_limits<T>::max();
8360   SetActivationParams(output_activation_min, output_activation_max, &params);
8361   memcpy(output_data, scratch_buffer, sizeof(T) * num_elems);
8362   for (int i = 1; i < tasks.size(); i++) {
8363     Add(params, shape, output_data, shape, scratch_buffer + i * num_elems,
8364         shape, output_data);
8365   }
8366 }
8367 
8368 }  // namespace optimized_ops
8369 }  // namespace tflite
8370 
8371 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8372 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8373 #pragma GCC diagnostic pop
8374 #endif
8375 
8376 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
8377