/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/resize_bilinear.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

// Unoptimized reference ops:
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
using reference_ops::Broadcast4DSlowGreaterWithScaling;
using reference_ops::Broadcast4DSlowLess;
using reference_ops::Broadcast4DSlowLessEqual;
using reference_ops::Broadcast4DSlowLessEqualWithScaling;
using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSubSlow;
using reference_ops::Concatenation;
using reference_ops::ConcatenationWithScaling;
using reference_ops::DepthConcatenation;
using reference_ops::Div;
using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
using reference_ops::GreaterEqualWithScaling;
using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
using reference_ops::LessEqualWithScaling;
using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::TensorFlowSplit;

static constexpr int kDepthwiseReverseShift = -1;
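// Example of the sign flip kDepthwiseReverseShift implements: a legacy caller
// passing output_shift == 3 (legacy convention: shift right by 3) has that
// value stored below as DepthwiseParams::output_shift == -3, matching the new
// convention where a positive shift means a left shift.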

template <typename Scalar, int N>
VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
  const int size = FlatSize(dims);
  return VectorMap<Scalar>(data, size, 1);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
                                                const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
                                               const Dims<N>& dims) {
  const int cols = dims.sizes[N - 1];
  int rows = 1;
  for (int d = 0; d < N - 1; d++) {
    rows *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}
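
// Usage sketch for the mapping helpers above (hypothetical shape): for a
// float buffer with Dims<4> sizes {3, 4, 5, 6},
// MapAsMatrixWithFirstDimAsRows(data, dims) returns a 3 x 120 Eigen view and
// MapAsMatrixWithLastDimAsCols(data, dims) returns a 60 x 6 Eigen view. No
// data is copied; these are Eigen::Map views over the same buffer.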

template <typename Scalar, int N>
ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
                                              const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return ArrayMap<Scalar>(data, rows, cols);
}

// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
                                                   const Dims<N>& dims,
                                                   int rows) {
  const int flatsize = FlatSize(dims);
  TFLITE_DCHECK((flatsize % rows) == 0);
  const int cols = flatsize / rows;
  return MatrixMap<Scalar>(data, rows, cols);
}

inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
  for (int i = 0; i < 4; i++) {
    if (dims1.sizes[i] != dims2.sizes[i]) {
      return false;
    }
  }
  return true;
}

inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
                    DimsToShape(filter_dims), filter_data,
                    DimsToShape(bias_dims), bias_data, output_shape,
                    output_data, CpuFlags(), /*thread_start=*/0,
                    /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, 1, 1, pad_width,
                pad_height, depth_multiplier, output_activation_min,
                output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, float* output_data,
                   const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, pad_width, pad_height,
                depth_multiplier, output_activation_min, output_activation_max,
                output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}

template <DepthwiseConvOutputRounding kOutputRounding>
inline void LegacyDepthwiseConvWithRounding(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  ruy::profiler::ScopeLabel label("DepthwiseConv/8bit");
  const int depth_multiplier = params.depth_multiplier;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  TFLITE_DCHECK_GE(dilation_width_factor, 1);
  TFLITE_DCHECK_GE(dilation_height_factor, 1);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
  const int input_depth = input_shape.Dims(3);
  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);

// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  const int output_shift = params.output_shift;

  // Call kernel optimized for depthwise convolutions using 3x3 filters if
  // parameters are supported.
  if (depthwise_conv::Fast3x3FilterKernelSupported(
          input_shape, filter_shape, stride_width, stride_height,
          dilation_width_factor, dilation_height_factor, pad_width, pad_height,
          depth_multiplier, output_shape, output_shift)) {
    ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/3x3");
    depthwise_conv::DepthwiseConv3x3Filter<kOutputRounding>(
        params, input_shape, input_data, filter_shape, filter_data, bias_shape,
        bias_data, output_shape, output_data, thread_start, thread_end,
        thread_dim);
    return;
  }
#endif

  ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/General");
  depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data,
                                       filter_shape, filter_data, bias_shape,
                                       bias_data, output_shape, output_data,
                                       thread_start, thread_end, thread_dim);
}

inline void LegacyDepthwiseConvImpl(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  return LegacyDepthwiseConvWithRounding<
      DepthwiseConvOutputRounding::kAwayFromZero>(
      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
      bias_data, output_shape, output_data, thread_start, thread_end,
      thread_dim);
}

inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = kDepthwiseReverseShift * output_shift;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  LegacyDepthwiseConvImpl(
      op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
      filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
      output_data, /*thread_start=*/0,
      /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}

template <typename T, typename TS>
struct LegacyDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    LegacyDepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
                            filter_data_, bias_shape_, bias_data_,
                            output_shape_, output_data_, thread_start_,
                            thread_end_, thread_dim_);
  }

 private:
  const DepthwiseParams& params_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

inline int HowManyConvThreads(const RuntimeShape& output_shape,
                              const RuntimeShape& filter_shape,
                              int thread_dim) {
  constexpr int kMinMulPerThread = 8;
  const int output_units = output_shape.Dims(thread_dim);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int num_mul_per_unit =
      FlatSizeSkipDim(output_shape, thread_dim) * filter_height * filter_width;
  const int min_units_per_thread = kMinMulPerThread / num_mul_per_unit + 1;
  int thread_count = output_units / min_units_per_thread;
  return thread_count;
}
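
// Worked example for HowManyConvThreads (hypothetical shapes): with an output
// shape of {1, 64, 64, 32}, a 3x3 filter and thread_dim == 1,
// num_mul_per_unit is 1 * 64 * 32 * 3 * 3 = 18432, min_units_per_thread is
// 8 / 18432 + 1 = 1, and the suggested thread count is 64 / 1 = 64. The
// callers below clamp this to the gemmlowp context's max_num_threads().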

inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConv");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    LegacyDepthwiseConvImpl(params, input_shape, input_data, filter_shape,
                            filter_data, bias_shape, bias_data, output_shape,
                            output_data, /*thread_start=*/0,
                            /*thread_end=*/output_rows, /*thread_dim=*/1);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
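    // Partition [0, thread_dim_size) into thread_count contiguous ranges,
    // with later ranges absorbing the rounding remainder. For example, 10
    // rows over 3 threads yields [0, 3), [3, 6) and [6, 10).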
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyDepthwiseConvWorkerTask<uint8, int32>(
          params, input_shape, input_data, filter_shape, filter_data,
          bias_shape, bias_data, output_shape, output_data, thread_start,
          thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

template <typename T, typename TS>
struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyPerChannelDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const int32* output_multiplier,
      const int32* output_shift, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params_, output_multiplier_, output_shift_, input_shape_, input_data_,
        filter_shape_, filter_data_, bias_shape_, bias_data_, output_shape_,
        output_data_, thread_start_, thread_end_, thread_dim_, backend_context);
  }

 private:
  const DepthwiseParams& params_;
  const int32* output_multiplier_;
  const int32* output_shift_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

inline void DepthwiseConvPerChannel(
    const DepthwiseParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int8* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
    gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConvInt8");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params, output_multiplier, output_shift, input_shape, input_data,
        filter_shape, filter_data, bias_shape, bias_data, output_shape,
        output_data, /*thread_start=*/0,
        /*thread_end=*/output_rows, /*thread_dim=*/1, backend_context);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyPerChannelDepthwiseConvWorkerTask<int8, int32>(
          params, output_multiplier, output_shift, input_shape, input_data,
          filter_shape, filter_data, bias_shape, bias_data, output_shape,
          output_data, thread_start, thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& bias_shape,
    const float* bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
                    bias_shape, bias_data, output_shape, output_data,
                    CpuFlags(),
                    /*thread_start=*/0,
                    /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
}

inline void AddBiasAndEvalActivationFunction(const float* bias_data,
                                             const Dims<4>& bias_dims,
                                             float* array_data,
                                             const Dims<4>& array_dims,
                                             float output_activation_min,
                                             float output_activation_max) {
  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
                                   DimsToShape(bias_dims), bias_data,
                                   DimsToShape(array_dims), array_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AddBiasAndEvalActivationFunction(const float* bias_data,
                                      const Dims<4>& bias_dims,
                                      float* array_data,
                                      const Dims<4>& array_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
                                   output_activation_min,
                                   output_activation_max);
}

template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
          Eigen::MatrixBase<Result>* result) {
  if (rhs.cols() == 1) {
    ruy::profiler::ScopeLabel label("GEMV");
    result->col(0).noalias() = lhs * rhs.col(0);
  } else {
    ruy::profiler::ScopeLabel label("GEMM");
    result->noalias() = lhs * rhs;
  }
}
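
// Usage sketch for Gemm above (hypothetical sizes): multiplying a 16x64
// weight matrix by a 64x1 input column takes the single-column "GEMV" branch.
//
//   Eigen::MatrixXf weights = Eigen::MatrixXf::Random(16, 64);
//   Eigen::MatrixXf input = Eigen::MatrixXf::Random(64, 1);
//   Eigen::MatrixXf result(16, 1);
//   Gemm(weights, input, &result);  // rhs.cols() == 1, so the GEMV path runs.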

inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& bias_shape,
    const float* optional_bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  // TODO(b/62193649): this convoluted shape computation (determining
  // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
  // is because the current --variable_batch hack consists in overwriting the
  // 3rd dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  // When that is fixed, this should become:
  // const auto input_matrix_map =
  //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
  const int dims_count = weights_shape.DimensionsCount();
  const int input_rows = weights_shape.Dims(dims_count - 1);
  const auto input_matrix_map =
      MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
  const auto filter_matrix_map =
      MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
  auto output_matrix_map =
      MapAsMatrixWithLastDimAsRows(output_data, output_shape);

  Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);

  if (optional_bias_data != nullptr) {
    AddBiasAndEvalActivationFunction(
        output_activation_min, output_activation_max, bias_shape,
        optional_bias_data, output_shape, output_data);
  } else {
    const int flat_size = output_shape.FlatSize();
    for (int i = 0; i < flat_size; ++i) {
      output_data[i] = ActivationFunctionWithMinMax(
          output_data[i], output_activation_min, output_activation_max);
    }
  }
}

inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                           const float* weights_data,
                           const Dims<4>& weights_dims, const float* bias_data,
                           const Dims<4>& bias_dims,
                           float output_activation_min,
                           float output_activation_max, float* output_data,
                           const Dims<4>& output_dims) {
  tflite::FullyConnectedParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  FullyConnected(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(weights_dims), weights_data,
                 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
                 output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                    const float* weights_data, const Dims<4>& weights_dims,
                    const float* bias_data, const Dims<4>& bias_dims,
                    float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
                 bias_dims, output_activation_min, output_activation_max,
                 output_data, output_dims);
}

struct GemmlowpOutputPipeline {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToUint8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};
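
// Sketch of how the pipeline above is typically consumed, via gemmlowp's
// GemmWithOutputPipeline (the matrix variable names here are hypothetical):
//
//   const auto output_pipeline = GemmlowpOutputPipeline::MakeExp(
//       bias_data, output_rows, output_offset, output_multiplier,
//       output_shift, output_activation_min, output_activation_max);
//   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
//                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
//       gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
//       filter_offset, input_offset, output_pipeline);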

struct GemmlowpOutputPipelineInt8 {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToInt8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

#ifdef USE_NEON
inline void LegacyFullyConnectedAsGEMVWorkerImpl(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, int row_start, int row_end) {
  ruy::profiler::ScopeLabel label("FullyConnectedAsGEMV/8bit");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kPeel = 4;
  const bool shift_left = (output_shift > 0);
  for (int k = 0; k < input_size; k += 64) {
    optimized_ops_preload_l1_stream(input_data + k);
  }
  for (int k = 0; k < kPeel * input_size; k += 64) {
    optimized_ops_preload_l1_stream(filter_data + k);
  }

  TFLITE_DCHECK_GE(row_end - row_start, kPeel);

  for (int out = row_start; out < row_end; out += kPeel) {
    out = std::min(out, row_end - kPeel);
    int32x4_t acc0 = vdupq_n_s32(0);
    int32x4_t acc1 = acc0;
    int32x4_t acc2 = acc0;
    int32x4_t acc3 = acc0;
    const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
    const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
    int in = 0;
    for (; in <= input_size - 16; in += 16) {
      const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
      const uint8* filter_ptr = filter_data + in + out * input_size;
      uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      int16x8_t input_val_0, input_val_1;
      uint8x8_t low = vget_low_u8(input_val_u8);
      uint8x8_t high = vget_high_u8(input_val_u8);
      input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
      input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
      low = vget_low_u8(filter_val_u8_0);
      high = vget_high_u8(filter_val_u8_0);
      int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
      filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_1);
      high = vget_high_u8(filter_val_u8_1);
      int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
      filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_2);
      high = vget_high_u8(filter_val_u8_2);
      int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
      filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_3);
      high = vget_high_u8(filter_val_u8_3);
      int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
      filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
      acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
                       vget_low_s16(input_val_0));
      acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
                       vget_low_s16(input_val_0));
      acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
                       vget_low_s16(input_val_0));
      acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
                       vget_low_s16(input_val_0));
      acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
                       vget_low_s16(input_val_1));
      acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
                       vget_low_s16(input_val_1));
      acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
                       vget_low_s16(input_val_1));
      acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
                       vget_low_s16(input_val_1));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
                       vget_high_s16(input_val_0));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
                       vget_high_s16(input_val_0));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
                       vget_high_s16(input_val_0));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
                       vget_high_s16(input_val_0));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
                       vget_high_s16(input_val_1));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
                       vget_high_s16(input_val_1));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
                       vget_high_s16(input_val_1));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
                       vget_high_s16(input_val_1));
    }
    for (; in <= input_size - 8; in += 8) {
      const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
      const uint8* filter_ptr = filter_data + in + out * input_size;
      uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr);
      int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
      input_val = vaddq_s16(input_val, input_offset_vec);
      int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0));
      filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
      int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1));
      filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
      int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2));
      filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
      int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3));
      filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
      acc0 =
          vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
      acc1 =
          vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
      acc2 =
          vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
      acc3 =
          vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                       vget_high_s16(input_val));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                       vget_high_s16(input_val));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                       vget_high_s16(input_val));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                       vget_high_s16(input_val));
    }
    if (in < input_size) {
      int32 buf[16];
      vst1q_s32(buf + 0, acc0);
      vst1q_s32(buf + 4, acc1);
      vst1q_s32(buf + 8, acc2);
      vst1q_s32(buf + 12, acc3);
      for (; in < input_size; in++) {
        int lane = (in + 8 - input_size) % 4;
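        // Which lane each leftover product lands in does not affect the
        // result: the horizontal reductions below sum all four lanes of each
        // accumulator. The formula merely spreads the tail work over lanes.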
930         const int32 input_val = input_data[in] + input_offset;
931         for (int k = 0; k < kPeel; k++) {
932           int32 filter_val =
933               filter_data[in + (out + k) * input_size] + filter_offset;
934           buf[lane + 4 * k] += filter_val * input_val;
935         }
936       }
937       acc0 = vld1q_s32(buf + 0);
938       acc1 = vld1q_s32(buf + 4);
939       acc2 = vld1q_s32(buf + 8);
940       acc3 = vld1q_s32(buf + 12);
941     }
942 
943     // Horizontally reduce accumulators
944     int32x2_t pairwise_reduced_acc_0 =
945         vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
946     int32x2_t pairwise_reduced_acc_1 =
947         vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
948     int32x2_t pairwise_reduced_acc_2 =
949         vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
950     int32x2_t pairwise_reduced_acc_3 =
951         vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
952     const int32x2_t reduced_lo =
953         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
954     const int32x2_t reduced_hi =
955         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
956     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
957     // Add bias values.
958     int32x4_t bias_vec = vld1q_s32(bias_data + out);
959     reduced = vaddq_s32(reduced, bias_vec);
960     if (shift_left) {
961       const int32 multiplier_power_of_two = 1 << output_shift;
962       reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
963       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
964     } else {
965       // Multiply by the fixed-point multiplier.
966       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
967       // Rounding-shift-right.
968       using gemmlowp::RoundingDivideByPOT;
969       reduced = RoundingDivideByPOT(reduced, -output_shift);
970     }
971     // Add the output offset.
972     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
973     reduced = vaddq_s32(reduced, output_offset_vec);
974     // Narrow values down to 16 bit signed.
975     const int16x4_t res16 = vqmovn_s32(reduced);
976     // Narrow values down to 8 bit unsigned, saturating.
977     uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
978     // Apply the clamping from the activation function
979     res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
980     res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
981     // Store results to destination.
982     vst1_lane_u8(output_data + out + 0, res8, 0);
983     vst1_lane_u8(output_data + out + 1, res8, 1);
984     vst1_lane_u8(output_data + out + 2, res8, 2);
985     vst1_lane_u8(output_data + out + 3, res8, 3);
986   }
987 }
988 
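// A minimal scalar sketch of what the NEON kernel above computes for one
// output row before requantization. Exposition only: the function name and
// signature are illustrative, not part of this header's API.
inline int32 LegacyGemvRowAccumSketch(const uint8* input_data, int input_size,
                                      int32 input_offset,
                                      const uint8* filter_row,
                                      int32 filter_offset, int32 bias) {
  int32 acc = bias;
  for (int i = 0; i < input_size; ++i) {
    // The offsets recenter the quantized uint8 values around their zero
    // points before the multiply-accumulate.
    acc += (filter_row[i] + filter_offset) * (input_data[i] + input_offset);
  }
  // The caller would then scale by output_multiplier / output_shift, add
  // output_offset, clamp, and narrow to uint8, as done vectorized above.
  return acc;
}
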
989 struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
990   LegacyFullyConnectedAsGEMVWorkerTask(
991       const RuntimeShape& input_shape, const uint8* input_data,
992       int32 input_offset, const RuntimeShape& filter_shape,
993       const uint8* filter_data, int32 filter_offset,
994       const RuntimeShape& bias_shape, const int32* bias_data,
995       int32 output_offset, int32 output_multiplier, int output_shift,
996       int32 output_activation_min, int32 output_activation_max,
997       const RuntimeShape& output_shape, uint8* output_data, int row_start,
998       int row_end)
999       : input_shape_(input_shape),
1000         input_data_(input_data),
1001         input_offset_(input_offset),
1002         filter_shape_(filter_shape),
1003         filter_data_(filter_data),
1004         filter_offset_(filter_offset),
1005         bias_shape_(bias_shape),
1006         bias_data_(bias_data),
1007         output_offset_(output_offset),
1008         output_multiplier_(output_multiplier),
1009         output_shift_(output_shift),
1010         output_activation_min_(output_activation_min),
1011         output_activation_max_(output_activation_max),
1012         output_shape_(output_shape),
1013         output_data_(output_data),
1014         row_start_(row_start),
1015         row_end_(row_end) {}
1016 
1017   void Run() override {
1018     LegacyFullyConnectedAsGEMVWorkerImpl(
1019         input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
1020         filter_offset_, bias_shape_, bias_data_, output_offset_,
1021         output_multiplier_, output_shift_, output_activation_min_,
1022         output_activation_max_, output_shape_, output_data_, row_start_,
1023         row_end_);
1024   }
1025 
1026   const RuntimeShape& input_shape_;
1027   const uint8* input_data_;
1028   int32 input_offset_;
1029   const RuntimeShape& filter_shape_;
1030   const uint8* filter_data_;
1031   int32 filter_offset_;
1032   const RuntimeShape& bias_shape_;
1033   const int32* bias_data_;
1034   int32 output_offset_;
1035   int32 output_multiplier_;
1036   int output_shift_;
1037   int32 output_activation_min_;
1038   int32 output_activation_max_;
1039   const RuntimeShape& output_shape_;
1040   uint8* output_data_;
1041   int row_start_;
1042   int row_end_;
1043 };
1044 
1045 inline void FullyConnectedAsGEMV(
1046     const RuntimeShape& input_shape, const uint8* input_data,
1047     int32 input_offset, const RuntimeShape& filter_shape,
1048     const uint8* filter_data, int32 filter_offset,
1049     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
1050     int32 output_multiplier, int output_shift, int32 output_activation_min,
1051     int32 output_activation_max, const RuntimeShape& output_shape,
1052     uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1053   const int output_dim_count = output_shape.DimensionsCount();
1054   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1055   const int output_rows = output_shape.Dims(output_dim_count - 1);
1056   const int input_size = FlatSizeSkipDim(input_shape, 0);
1057   static constexpr int kKernelRows = 4;
1058   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
1059       gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
1060   if (thread_count == 1) {
1061     // Single-thread case: do the computation on the current thread, don't
1062     // use a threadpool
1063     LegacyFullyConnectedAsGEMVWorkerImpl(
1064         input_shape, input_data, input_offset, filter_shape, filter_data,
1065         filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
1066         output_shift, output_activation_min, output_activation_max,
1067         output_shape, output_data, 0, output_rows);
1068     return;
1069   }
1070 
1071   // Multi-threaded case: use the gemmlowp context's threadpool.
1072   TFLITE_DCHECK_GT(thread_count, 1);
1073   std::vector<gemmlowp::Task*> tasks(thread_count);
1074   const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
1075       gemmlowp::CeilQuotient(output_rows, thread_count));
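  // Illustrative numbers (an assumption, not from this code): output_rows =
  // 1000 with thread_count = 3 gives CeilQuotient = 334, rounded up to
  // kRowsPerWorker = 336, so the workers cover [0, 336), [336, 672) and
  // [672, 1000), the last range clamped by the std::min below.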
1076   int row_start = 0;
1077   for (int i = 0; i < thread_count; ++i) {
1078     int row_end = std::min(output_rows, row_start + kRowsPerWorker);
1079     tasks[i] = new LegacyFullyConnectedAsGEMVWorkerTask(
1080         input_shape, input_data, input_offset, filter_shape, filter_data,
1081         filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
1082         output_shift, output_activation_min, output_activation_max,
1083         output_shape, output_data, row_start, row_end);
1084     row_start = row_end;
1085   }
1086   TFLITE_DCHECK_EQ(row_start, output_rows);
1087   gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
1088 }
1089 #endif  // USE_NEON
1090 
1091 inline void FullyConnected(
1092     const FullyConnectedParams& params, const RuntimeShape& input_shape,
1093     const uint8* input_data, const RuntimeShape& filter_shape,
1094     const uint8* filter_data, const RuntimeShape& bias_shape,
1095     const int32* bias_data, const RuntimeShape& output_shape,
1096     uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1097   ruy::profiler::ScopeLabel label("FullyConnected/8bit");
1098   const int32 input_offset = params.input_offset;
1099   const int32 filter_offset = params.weights_offset;
1100   const int32 output_offset = params.output_offset;
1101   const int32 output_multiplier = params.output_multiplier;
1102   const int output_shift = params.output_shift;
1103   const int32 output_activation_min = params.quantized_activation_min;
1104   const int32 output_activation_max = params.quantized_activation_max;
1105   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1106   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1107   // TODO(b/62193649): This really should be:
1108   //     const int batches = ArraySize(output_dims, 1);
1109   // but the current --variable_batch hack consists in overwriting the 3rd
1110   // dimension with the runtime batch size, as we don't keep track, for each
1111   // array, of which of its dimensions is the batch dimension.
1112   const int output_dim_count = output_shape.DimensionsCount();
1113   const int filter_dim_count = filter_shape.DimensionsCount();
1114   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1115 #ifdef USE_NEON
1116   if (batches == 1) {
1117     const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
1118                                         output_shape, output_dim_count - 1);
1119     if (output_size >= 4) {
1120       return FullyConnectedAsGEMV(
1121           input_shape, input_data, input_offset, filter_shape, filter_data,
1122           filter_offset, bias_shape, bias_data, output_offset,
1123           output_multiplier, output_shift, output_activation_min,
1124           output_activation_max, output_shape, output_data, gemmlowp_context);
1125     }
1126   }
1127 #endif  // USE_NEON
1128   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
1129   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
1130   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
1131   const int output_rows = output_shape.Dims(output_dim_count - 1);
1132   TFLITE_DCHECK_EQ(output_rows, filter_rows);
1133   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1134 
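  // gemmlowp mapping: the row-major filter (output_rows x filter_cols) is the
  // LHS and the column-major input (filter_cols x batches) is the RHS, so the
  // GEMM below computes output = filter * input with shape
  // (output_rows x batches).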
1135   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
1136       filter_data, output_rows, filter_cols, filter_cols);
1137   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1138       input_data, filter_cols, batches, filter_cols);
1139   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
1140       output_data, output_rows, batches, output_rows);
1141   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
1142       bias_data, output_rows, output_offset, output_multiplier, output_shift,
1143       output_activation_min, output_activation_max);
1144   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
1145                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1146       gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
1147       filter_offset, input_offset, output_pipeline);
1148 }
1149 
1150 #ifdef GEMMLOWP_NEON
1151 // In the common case of batch size 1, a fully-connected node degenerates
1152 // to a matrix*vector product. LSTM cells contain a fully-connected node;
1153 // when quantized, this becomes a special type of GEMV operation in which
1154 // the output is 16-bit-quantized and thus needs its own special path.
1155 inline void GEMVForLstmCell(const RuntimeShape& input_shape,
1156                             const uint8* input_data,
1157                             const RuntimeShape& weights_shape,
1158                             const uint8* weights_data, uint8 weights_zero_point,
1159                             const RuntimeShape& bias_shape,
1160                             const int32* bias_data, int32 accum_multiplier,
1161                             int accum_shift, const RuntimeShape& output_shape,
1162                             int16* output_data) {
1163   ruy::profiler::ScopeLabel label("GEMVForLstmCell");
1164   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1165   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1166   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1167   const int output_dim_count = output_shape.DimensionsCount();
1168   const int weights_dim_count = weights_shape.DimensionsCount();
1169   TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1170   const int input_size = FlatSizeSkipDim(input_shape, 0);
1171   const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
1172                                       output_shape, output_dim_count - 1);
1173   // This special fast path for quantized LSTM cells does not try to support
1174   // odd sizes that we haven't encountered in any LSTM cell; those would
1175   // require special code (which would go untested until some LSTM cell
1176   // exercises it). We just guard our assumptions about size evenness with
1177   // the following assertions.
1178   TFLITE_DCHECK(!(output_size % 4));
1179   TFLITE_DCHECK(!(input_size % 8));
1180   const int32* bias_ptr = bias_data;
1181   int16* output_ptr = output_data;
1182   for (int out = 0; out < output_size; out += 4) {
1183     int32x4_t acc_0 = vdupq_n_s32(0);
1184     int32x4_t acc_1 = vdupq_n_s32(0);
1185     int32x4_t acc_2 = vdupq_n_s32(0);
1186     int32x4_t acc_3 = vdupq_n_s32(0);
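    // This path hardcodes the LSTM-cell input quantization: inputs are uint8
    // with zero point 128, hence the constant -128 input offset below.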
1187     const int16x8_t input_offset_vec = vdupq_n_s16(-128);
1188     const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
1189     int in = 0;
1190     // Handle 16 levels of depth at a time.
1191     for (; in <= input_size - 16; in += 16) {
1192       const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
1193       const uint8* weights_ptr = weights_data + in + out * input_size;
1194       uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
1195       uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
1196       uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
1197       uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
1198       int16x8_t input_val_0, input_val_1;
1199       const uint8x8_t low = vget_low_u8(input_val_u8);
1200       const uint8x8_t high = vget_high_u8(input_val_u8);
1201       input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
1202       input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
1203       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1204       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1205       int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
1206           weights_val_3_0;
1207       int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
1208           weights_val_3_1;
1209       weights_val_0_0 = vaddq_s16(
1210           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
1211           weights_offset_vec);
1212       weights_val_0_1 = vaddq_s16(
1213           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
1214           weights_offset_vec);
1215       weights_val_1_0 = vaddq_s16(
1216           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
1217           weights_offset_vec);
1218       weights_val_1_1 = vaddq_s16(
1219           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
1220           weights_offset_vec);
1221       weights_val_2_0 = vaddq_s16(
1222           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
1223           weights_offset_vec);
1224       weights_val_2_1 = vaddq_s16(
1225           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
1226           weights_offset_vec);
1227       weights_val_3_0 = vaddq_s16(
1228           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
1229           weights_offset_vec);
1230       weights_val_3_1 = vaddq_s16(
1231           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
1232           weights_offset_vec);
1233       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
1234                         vget_low_s16(input_val_0));
1235       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
1236                         vget_low_s16(input_val_0));
1237       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
1238                         vget_low_s16(input_val_0));
1239       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
1240                         vget_low_s16(input_val_0));
1241       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
1242                         vget_high_s16(input_val_0));
1243       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
1244                         vget_high_s16(input_val_0));
1245       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
1246                         vget_high_s16(input_val_0));
1247       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
1248                         vget_high_s16(input_val_0));
1249       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
1250                         vget_low_s16(input_val_1));
1251       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
1252                         vget_low_s16(input_val_1));
1253       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
1254                         vget_low_s16(input_val_1));
1255       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
1256                         vget_low_s16(input_val_1));
1257       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
1258                         vget_high_s16(input_val_1));
1259       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
1260                         vget_high_s16(input_val_1));
1261       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
1262                         vget_high_s16(input_val_1));
1263       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
1264                         vget_high_s16(input_val_1));
1265     }
1266     // Handle 8 levels of depth at a time.
1267     for (; in < input_size; in += 8) {
1268       const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
1269       const uint8* weights_ptr = weights_data + in + out * input_size;
1270       uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
1271       uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
1272       uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
1273       uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
1274       int16x8_t input_val;
1275       input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
1276       input_val = vaddq_s16(input_val, input_offset_vec);
1277       int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
1278       weights_val_0 =
1279           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
1280                     weights_offset_vec);
1281       weights_val_1 =
1282           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
1283                     weights_offset_vec);
1284       weights_val_2 =
1285           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
1286                     weights_offset_vec);
1287       weights_val_3 =
1288           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
1289                     weights_offset_vec);
1290       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
1291                         vget_low_s16(input_val));
1292       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
1293                         vget_low_s16(input_val));
1294       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
1295                         vget_low_s16(input_val));
1296       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
1297                         vget_low_s16(input_val));
1298       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
1299                         vget_high_s16(input_val));
1300       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
1301                         vget_high_s16(input_val));
1302       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
1303                         vget_high_s16(input_val));
1304       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
1305                         vget_high_s16(input_val));
1306     }
1307     // Horizontally reduce accumulators
1308     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1309         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1310     pairwise_reduced_acc_0 =
1311         vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
1312     pairwise_reduced_acc_1 =
1313         vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
1314     pairwise_reduced_acc_2 =
1315         vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
1316     pairwise_reduced_acc_3 =
1317         vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
1318     const int32x2_t reduced_lo =
1319         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1320     const int32x2_t reduced_hi =
1321         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1322     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1323     // Add bias values.
1324     int32x4_t bias_vec = vld1q_s32(bias_ptr);
1325     bias_ptr += 4;
1326     reduced = vaddq_s32(reduced, bias_vec);
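    // accum_shift uses the positive-means-left convention: any left shift is
    // applied before the fixed-point multiply and any right shift (with
    // rounding) after it, i.e. roughly
    //   acc = RoundingDivideByPOT(
    //       SaturatingRoundingDoublingHighMul(acc << left_shift,
    //                                         accum_multiplier),
    //       right_shift);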
1327     int left_shift = accum_shift > 0 ? accum_shift : 0;
1328     int right_shift = accum_shift > 0 ? 0 : -accum_shift;
1329     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1330     // Multiply by the fixed-point multiplier.
1331     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
1332     // Rounding-shift-right.
1333     using gemmlowp::RoundingDivideByPOT;
1334     reduced = RoundingDivideByPOT(reduced, right_shift);
1335     // Narrow values down to 16 bit signed.
1336     const int16x4_t res16 = vqmovn_s32(reduced);
1337     vst1_s16(output_ptr, res16);
1338     output_ptr += 4;
1339   }
1340 }
1341 #endif  // GEMMLOWP_NEON
1342 
1343 #ifdef GEMMLOWP_NEON
1344 inline void GEMVForLstmCellWithSymmetricRange(
1345     const RuntimeShape& input_shape, const uint8* input_data,
1346     const RuntimeShape& weights_shape, const uint8* weights_data,
1347     const RuntimeShape& bias_shape, const int32* bias_data,
1348     int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
1349     int16* output_data) {
1350   ruy::profiler::ScopeLabel label("GEMVForLstmCellWithSymmetricRange");
1351   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1352   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1353   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1354   const int output_dim_count = output_shape.DimensionsCount();
1355   const int weights_dim_count = weights_shape.DimensionsCount();
1356   TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1357   const int input_size = FlatSizeSkipDim(input_shape, 0);
1358   const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
1359                                       output_shape, output_dim_count - 1);
1360   // This special fast path for quantized LSTM cells does not try to support
1361   // odd sizes that we haven't encountered in any LSTM cell; those would
1362   // require special code (which would go untested until some LSTM cell
1363   // exercises it). We just guard our assumptions about size evenness with
1364   // the following assertions.
1365   TFLITE_DCHECK(!(output_size % 4));
1366   TFLITE_DCHECK(!(input_size % 64));
1367   const int32* bias_ptr = bias_data;
1368   int16* output_ptr = output_data;
1369   const uint8x16_t signbit = vdupq_n_u8(0x80);
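  // Pull the whole input vector into L1 with a "keep" hint: it is reused for
  // every group of 4 output rows, whereas the weights, streamed once, use the
  // "stream" hint inside the loop below. The signbit constant is XORed with
  // uint8 values to flip them into int8 (v ^ 0x80 == v - 128).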
1370   for (int in = 0; in < input_size; in += 32) {
1371     optimized_ops_preload_l1_keep(input_data + in);
1372   }
1373   const int left_shift = accum_shift > 0 ? accum_shift : 0;
1374   const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
1375   for (int out = 0; out < output_size; out += 4) {
1376     // Load the bias values
1377     int32x4_t bias_vec = vld1q_s32(bias_ptr);
1378     bias_ptr += 4;
1379 
1380     // Clear accumulators. We use 2 accumulator registers per row,
1381     // for 4 rows. row_accumRN is the N-th accumulator for row R.
1382     int32x4_t row_accum00 = vdupq_n_s32(0);
1383     int32x4_t row_accum01 = vdupq_n_s32(0);
1384     int32x4_t row_accum10 = vdupq_n_s32(0);
1385     int32x4_t row_accum11 = vdupq_n_s32(0);
1386     int32x4_t row_accum20 = vdupq_n_s32(0);
1387     int32x4_t row_accum21 = vdupq_n_s32(0);
1388     int32x4_t row_accum30 = vdupq_n_s32(0);
1389     int32x4_t row_accum31 = vdupq_n_s32(0);
1390 
1391     // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
1392     const int kReadAhead = 512;
1393     // Prefetch the first weights values.
1394     for (int k = 0; k < kReadAhead; k += 64) {
1395       optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
1396                                       k);
1397       optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
1398                                       k);
1399       optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
1400                                       k);
1401       optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
1402                                       k);
1403     }
1404     // Loop along the rows, handling 64 bytes per iteration because that's
1405     // the cache line size on most current ARM-architecture CPUs.
1406     for (int in = 0; in < input_size; in += 64) {
1407       // Prefetch some future weights values.
1408       optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
1409                                       in + kReadAhead);
1410       optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
1411                                       in + kReadAhead);
1412       optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
1413                                       in + kReadAhead);
1414       optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
1415                                       in + kReadAhead);
1416 
1417       // We will use 2 local 16-bit accumulators per row, for 2 rows.
1418       // See below (*) for the rationale of processing only 2 rows at a time.
1419       // local_accumRN is the N-th local accumulator for row R.
1420       int16x8_t local_accum00;
1421       int16x8_t local_accum01;
1422       int16x8_t local_accum10;
1423       int16x8_t local_accum11;
1424 
1425       // Load 64 bytes of input activations values. Convert to signed int8
1426       // by flipping the sign bit (i.e. subtracting 128, the required
1427       // zero_point value).
1428       int8x16_t input0 = vreinterpretq_s8_u8(
1429           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
1430       int8x16_t input1 = vreinterpretq_s8_u8(
1431           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
1432       int8x16_t input2 = vreinterpretq_s8_u8(
1433           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
1434       int8x16_t input3 = vreinterpretq_s8_u8(
1435           veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
1436 
1437       // Beginning of the core accumulation. Notice how while we have 4
1438       // rows to process, this code is taking care of only 2 rows at a time,
1439       // thus being divided into two parts looking similar ("Rows 0 and 1" and
1440       // "Rows 2 and 3").
1441       //
1442       // (*) The rationale for handling only 2 rows at a time is to avoid
1443       // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
1444       // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
1445       // we may find ourselves in a situation where rows alias each other in
1446       // the L1 cache, and moreover may also mutually alias with the input
1447       // activations. If we try to load 4 rows at a time, together with the
1448       // input activations, that may be 5 mutually-aliasing vectors, resulting
1449       // in constant mutual eviction from L1 cache. Handling 2 rows at a time
1450       // here largely mitigates these issues, and seems at least to be very
1451       // effective on Cortex-A53:
1452       //                          Before       After
1453       // big (Cortex-A73)         2.85 ms      2.85 ms
1454       // little (Cortex-A53)      11.0 ms      5.16 ms
1455 
1456       // Rows 0 and 1:
1457       // Load 64 bytes of weights values from each row. Convert to signed int8
1458       // by flipping the sign bit (i.e. subtracting 128, the required
1459       // zero_point value).
1460       int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
1461           signbit,
1462           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
1463       int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
1464           signbit,
1465           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
1466       int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
1467           signbit,
1468           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
1469       int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
1470           signbit,
1471           vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
1472       int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
1473           signbit,
1474           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
1475       int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
1476           signbit,
1477           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
1478       int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
1479           signbit,
1480           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
1481       int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
1482           signbit,
1483           vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
1484       // Multiply-accumulate into local 16-bit accumulators.
1485       // We can accumulate two products without overflow because weights are
1486       // required to never be -128, so each product is at most 127 * 128
1487       // = 16256 in absolute value, and two such products fit in an int16.
1488       local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
1489       local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
1490       local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
1491       local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
1492       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
1493                                vget_high_s8(input0));
1494       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
1495                                vget_high_s8(input1));
1496       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
1497                                vget_high_s8(input0));
1498       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
1499                                vget_high_s8(input1));
1500       // Pairwise add and accumulate into 32-bit accumulators
1501       row_accum00 = vpadalq_s16(row_accum00, local_accum00);
1502       row_accum01 = vpadalq_s16(row_accum01, local_accum01);
1503       row_accum10 = vpadalq_s16(row_accum10, local_accum10);
1504       row_accum11 = vpadalq_s16(row_accum11, local_accum11);
1505       // Multiply-accumulate into local 16-bit accumulators.
1506       // We can accumulate two products without overflow because weights are
1507       // required to never be -128, so each product is at most 127 * 128
1508       // = 16256 in absolute value, and two such products fit in an int16.
1509       local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
1510       local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
1511       local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
1512       local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
1513       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
1514                                vget_high_s8(input2));
1515       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
1516                                vget_high_s8(input3));
1517       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
1518                                vget_high_s8(input2));
1519       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
1520                                vget_high_s8(input3));
1521       // Pairwise add and accumulate into 32-bit accumulators
1522       row_accum00 = vpadalq_s16(row_accum00, local_accum00);
1523       row_accum01 = vpadalq_s16(row_accum01, local_accum01);
1524       row_accum10 = vpadalq_s16(row_accum10, local_accum10);
1525       row_accum11 = vpadalq_s16(row_accum11, local_accum11);
1526 
1527       // Rows 2 and 3:
1528       // Load 64 bytes of weights values from each row. Convert to signed int8
1529       // by flipping the sign bit (i.e. subtracting 128, the required
1530       // zero_point value).
1531       weights00 = vreinterpretq_s8_u8(veorq_u8(
1532           signbit,
1533           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
1534       weights01 = vreinterpretq_s8_u8(veorq_u8(
1535           signbit,
1536           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
1537       weights02 = vreinterpretq_s8_u8(veorq_u8(
1538           signbit,
1539           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
1540       weights03 = vreinterpretq_s8_u8(veorq_u8(
1541           signbit,
1542           vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
1543       weights10 = vreinterpretq_s8_u8(veorq_u8(
1544           signbit,
1545           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
1546       weights11 = vreinterpretq_s8_u8(veorq_u8(
1547           signbit,
1548           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
1549       weights12 = vreinterpretq_s8_u8(veorq_u8(
1550           signbit,
1551           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
1552       weights13 = vreinterpretq_s8_u8(veorq_u8(
1553           signbit,
1554           vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
1555       // Multiply-accumulate into local 16-bit accumulators.
1556       // We can accumulate two products without overflow because weights are
1557       // required to never be -128, so each product is at most 127 * 128
1558       // = 16256 in absolute value, and two such products fit in an int16.
1559       local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
1560       local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
1561       local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
1562       local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
1563       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
1564                                vget_high_s8(input0));
1565       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
1566                                vget_high_s8(input1));
1567       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
1568                                vget_high_s8(input0));
1569       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
1570                                vget_high_s8(input1));
1571       // Pairwise add and accumulate into 32-bit accumulators
1572       row_accum20 = vpadalq_s16(row_accum20, local_accum00);
1573       row_accum21 = vpadalq_s16(row_accum21, local_accum01);
1574       row_accum30 = vpadalq_s16(row_accum30, local_accum10);
1575       row_accum31 = vpadalq_s16(row_accum31, local_accum11);
1576       // Multiply-accumulate into local 16-bit accumulators.
1577       // We can accumulate two products without overflow because weights are
1578       // required to never be -128, so each product is at most 127 * 128
1579       // = 16256 in absolute value, and two such products fit in an int16.
1580       local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
1581       local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
1582       local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
1583       local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
1584       local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
1585                                vget_high_s8(input2));
1586       local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
1587                                vget_high_s8(input3));
1588       local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
1589                                vget_high_s8(input2));
1590       local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
1591                                vget_high_s8(input3));
1592       // Pairwise add and accumulate into 32-bit accumulators
1593       row_accum20 = vpadalq_s16(row_accum20, local_accum00);
1594       row_accum21 = vpadalq_s16(row_accum21, local_accum01);
1595       row_accum30 = vpadalq_s16(row_accum30, local_accum10);
1596       row_accum31 = vpadalq_s16(row_accum31, local_accum11);
1597     }
1598 
1599     row_accum00 = vaddq_s32(row_accum00, row_accum01);
1600     row_accum10 = vaddq_s32(row_accum10, row_accum11);
1601     row_accum20 = vaddq_s32(row_accum20, row_accum21);
1602     row_accum30 = vaddq_s32(row_accum30, row_accum31);
1603     // Horizontally reduce accumulators
1604     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1605         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1606     pairwise_reduced_acc_0 =
1607         vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
1608     pairwise_reduced_acc_1 =
1609         vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
1610     pairwise_reduced_acc_2 =
1611         vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
1612     pairwise_reduced_acc_3 =
1613         vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
1614     const int32x2_t reduced_lo =
1615         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1616     const int32x2_t reduced_hi =
1617         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1618     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1619     // Add bias values.
1620     reduced = vaddq_s32(reduced, bias_vec);
1621     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1622     // Multiply by the fixed-point multiplier.
1623     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
1624     // Rounding-shift-right.
1625     using gemmlowp::RoundingDivideByPOT;
1626     reduced = RoundingDivideByPOT(reduced, right_shift);
1627     // Narrow values down to 16 bit signed.
1628     const int16x4_t res16 = vqmovn_s32(reduced);
1629     vst1_s16(output_ptr, res16);
1630     output_ptr += 4;
1631   }
1632 }
1633 #endif  // GEMMLOWP_NEON
1634 
1635 inline void FullyConnected(
1636     const FullyConnectedParams& params, const RuntimeShape& input_shape,
1637     const uint8* input_data, const RuntimeShape& filter_shape,
1638     const uint8* filter_data, const RuntimeShape& bias_shape,
1639     const int32* bias_data_int32, const RuntimeShape& output_shape,
1640     int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1641   ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
1642   const int32 input_offset = params.input_offset;
1643   const int32 filter_offset = params.weights_offset;
1644   const int32 output_offset = params.output_offset;
1645   const int32 output_multiplier = params.output_multiplier;
1646   const int output_shift = params.output_shift;
1647   const int32 output_activation_min = params.quantized_activation_min;
1648   const int32 output_activation_max = params.quantized_activation_max;
1649   // This is a copy of the reference implementation. We do not currently have a
1650   // properly optimized version.
1651   (void)gemmlowp_context;  // only used in properly optimized code.
1652   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1653   TFLITE_DCHECK_EQ(output_offset, 0);
1654   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1655   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1656 
1657   // TODO(b/62193649): This really should be:
1658   //     const int batches = ArraySize(output_dims, 1);
1659   // but the current --variable_batch hack consists in overwriting the 3rd
1660   // dimension with the runtime batch size, as we don't keep track, for each
1661   // array, of which of its dimensions is the batch dimension.
1662   const int output_dim_count = output_shape.DimensionsCount();
1663   const int filter_dim_count = filter_shape.DimensionsCount();
1664   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1665   const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
1666                                        output_shape, output_dim_count - 1);
1667   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
1668 
1669   // Implementation of the fully connected node suited to the inside of an LSTM
1670   // cell. The operands are 8-bit integers, the accumulators are internally
1671   // 32bit integers, and the output is 16-bit fixed-point with 3 integer bits so
1672   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
1673   // is explained in the function comment above.
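  // For illustration: with 3 integer bits the int16 output is Q3.12, so a raw
  // value of 4096 (== 2^12) represents 1.0 and the representable range is
  // [-8.0, 8.0) with a resolution of 2^-12.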
1674 #ifdef GEMMLOWP_NEON
1675   if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
1676       output_activation_max == 32767) {
1677     if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
1678       GEMVForLstmCellWithSymmetricRange(
1679           input_shape, input_data, filter_shape, filter_data, bias_shape,
1680           bias_data_int32, output_multiplier, output_shift, output_shape,
1681           output_data);
1682       return;
1683     }
1684     if (!(output_depth % 4) && !(accum_depth % 8)) {
1685       GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
1686                       filter_offset, bias_shape, bias_data_int32,
1687                       output_multiplier, output_shift, output_shape,
1688                       output_data);
1689       return;
1690     }
1691   }
1692 #endif  // GEMMLOWP_NEON
1693   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
1694       filter_data, output_depth, accum_depth);
1695   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1696       input_data, accum_depth, batches);
1697   gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
1698       output_data, output_depth, batches);
1699   typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1700       ColVectorMap;
1701   ColVectorMap bias_vector(bias_data_int32, output_depth);
1702   gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1703   bias_addition_stage.bias_vector = bias_vector;
1704   gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
1705   scale_stage.result_offset_after_shift = 0;
1706   scale_stage.result_fixedpoint_multiplier = output_multiplier;
1707   // Note that this shift is negated wrt ordinary FC.
1708   // Note that this shift is negated with respect to the ordinary FC op.
1709   gemmlowp::OutputStageClamp clamp_stage;
1710   clamp_stage.min = output_activation_min;
1711   clamp_stage.max = output_activation_max;
1712   gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1713   auto output_pipeline =
1714       std::make_tuple(bias_addition_stage, scale_stage, clamp_stage,
1715                       saturating_cast_int16_stage);
1716   gemmlowp::GemmWithOutputPipeline<uint8, int16,
1717                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1718       gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
1719       filter_offset, input_offset, output_pipeline);
1720 }
1721 
1722 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
1723                            int32 input_offset, const uint8* filter_data,
1724                            const Dims<4>& filter_dims, int32 filter_offset,
1725                            const int32* bias_data, const Dims<4>& bias_dims,
1726                            int32 output_offset, int32 output_multiplier,
1727                            int output_shift, int32 output_activation_min,
1728                            int32 output_activation_max, uint8* output_data,
1729                            const Dims<4>& output_dims,
1730                            gemmlowp::GemmContext* gemmlowp_context) {
1731   tflite::FullyConnectedParams op_params;
1732   op_params.input_offset = input_offset;
1733   op_params.weights_offset = filter_offset;
1734   op_params.output_offset = output_offset;
1735   op_params.output_multiplier = output_multiplier;
1736   // Legacy ops used mixed left and right shifts; now positive means left-shift.
1737   op_params.output_shift = kReverseShift * output_shift;
1738   op_params.quantized_activation_min = output_activation_min;
1739   op_params.quantized_activation_max = output_activation_max;
1740 
1741   FullyConnected(op_params, DimsToShape(input_dims), input_data,
1742                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
1743                  bias_data, DimsToShape(output_dims), output_data,
1744                  gemmlowp_context);
1745 }
1746 
1747 inline void FullyConnected(
1748     const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
1749     const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
1750     const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
1751     int32 output_multiplier, int output_shift, int32 output_activation_min,
1752     int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
1753     gemmlowp::GemmContext* gemmlowp_context) {
1754   tflite::FullyConnectedParams op_params;
1755   op_params.input_offset = input_offset;
1756   op_params.weights_offset = filter_offset;
1757   op_params.output_offset = output_offset;
1758   op_params.output_multiplier = output_multiplier;
1759   // Legacy ops used mixed left and right shifts; now positive means left-shift.
1760   op_params.output_shift = kReverseShift * output_shift;
1761   op_params.quantized_activation_min = output_activation_min;
1762   op_params.quantized_activation_max = output_activation_max;
1763 
1764   FullyConnected(op_params, DimsToShape(input_dims), input_data,
1765                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
1766                  bias_data_int32, DimsToShape(output_dims), output_data,
1767                  gemmlowp_context);
1768 }
1769 
1770 // legacy, for compatibility with old checked-in code
1771 template <FusedActivationFunctionType Ac>
1772 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
1773                     int32 input_offset, const uint8* filter_data,
1774                     const Dims<4>& filter_dims, int32 filter_offset,
1775                     const int32* bias_data, const Dims<4>& bias_dims,
1776                     int32 output_offset, int32 output_multiplier,
1777                     int output_shift, int32 output_activation_min,
1778                     int32 output_activation_max, uint8* output_data,
1779                     const Dims<4>& output_dims,
1780                     gemmlowp::GemmContext* gemmlowp_context) {
1781   static_assert(Ac == FusedActivationFunctionType::kNone ||
1782                     Ac == FusedActivationFunctionType::kRelu ||
1783                     Ac == FusedActivationFunctionType::kRelu6 ||
1784                     Ac == FusedActivationFunctionType::kRelu1,
1785                 "");
1786   FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
1787                  filter_offset, bias_data, bias_dims, output_offset,
1788                  output_multiplier, output_shift, output_activation_min,
1789                  output_activation_max, output_data, output_dims,
1790                  gemmlowp_context);
1791 }
1792 
1793 #ifdef USE_NEON
1794 inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl(
1795     const RuntimeShape& input_shape, const int8_t* input_data,
1796     int32 input_offset, const RuntimeShape& filter_shape,
1797     const int8_t* filter_data, int32 filter_offset,
1798     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
1799     int32 output_multiplier, int output_shift, int32 output_activation_min,
1800     int32 output_activation_max, const RuntimeShape& output_shape,
1801     int8_t* output_data, int row_start, int row_end) {
1802   ruy::profiler::ScopeLabel label("FullyConnectedAsGEMVInt8/8bit");
1803   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1804   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1805   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1806   const int output_dim_count = output_shape.DimensionsCount();
1807   TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1808   const int input_size = FlatSizeSkipDim(input_shape, 0);
1809   static constexpr int kPeel = 4;
1810   const bool shift_left = (output_shift > 0);
1811   TFLITE_DCHECK_GE(row_end - row_start, kPeel);
1812 
1813   for (int out = row_start; out < row_end; out += kPeel) {
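    // Rows are processed in groups of kPeel = 4; when fewer than 4 remain,
    // the last group is shifted back to overlap the previous one. Recomputing
    // a few rows is harmless because each output value is written
    // independently.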
1814     out = std::min(out, row_end - kPeel);
1815     int32x4_t acc0 = vdupq_n_s32(0);
1816     int32x4_t acc1 = acc0;
1817     int32x4_t acc2 = acc0;
1818     int32x4_t acc3 = acc0;
1819     const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
1820     const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
1821     int in = 0;
1822     for (; in <= input_size - 16; in += 16) {
1823       const int8x16_t input_val_s8 = vld1q_s8(input_data + in);
1824       const int8_t* filter_ptr = filter_data + in + out * input_size;
1825       int8x16_t filter_val_s8_0 = vld1q_s8(filter_ptr);
1826       filter_ptr += input_size;
1827       int8x16_t filter_val_s8_1 = vld1q_s8(filter_ptr);
1828       filter_ptr += input_size;
1829       int8x16_t filter_val_s8_2 = vld1q_s8(filter_ptr);
1830       filter_ptr += input_size;
1831       int8x16_t filter_val_s8_3 = vld1q_s8(filter_ptr);
1832       int16x8_t input_val_0, input_val_1;
1833       int8x8_t low = vget_low_s8(input_val_s8);
1834       int8x8_t high = vget_high_s8(input_val_s8);
1835       input_val_0 = vmovl_s8(low);
1836       input_val_1 = vmovl_s8(high);
1837       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1838       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1839       low = vget_low_s8(filter_val_s8_0);
1840       high = vget_high_s8(filter_val_s8_0);
1841       int16x8_t filter_val_0_0 = vmovl_s8(low);
1842       int16x8_t filter_val_0_1 = vmovl_s8(high);
1843       filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
1844       filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
1845       low = vget_low_s8(filter_val_s8_1);
1846       high = vget_high_s8(filter_val_s8_1);
1847       int16x8_t filter_val_1_0 = vmovl_s8(low);
1848       int16x8_t filter_val_1_1 = vmovl_s8(high);
1849       filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
1850       filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
1851       low = vget_low_s8(filter_val_s8_2);
1852       high = vget_high_s8(filter_val_s8_2);
1853       int16x8_t filter_val_2_0 = vmovl_s8(low);
1854       int16x8_t filter_val_2_1 = vmovl_s8(high);
1855       filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
1856       filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
1857       low = vget_low_s8(filter_val_s8_3);
1858       high = vget_high_s8(filter_val_s8_3);
1859       int16x8_t filter_val_3_0 = vmovl_s8(low);
1860       int16x8_t filter_val_3_1 = vmovl_s8(high);
1861       filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
1862       filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
1863       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
1864                        vget_low_s16(input_val_0));
1865       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
1866                        vget_low_s16(input_val_0));
1867       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
1868                        vget_low_s16(input_val_0));
1869       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
1870                        vget_low_s16(input_val_0));
1871       acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
1872                        vget_low_s16(input_val_1));
1873       acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
1874                        vget_low_s16(input_val_1));
1875       acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
1876                        vget_low_s16(input_val_1));
1877       acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
1878                        vget_low_s16(input_val_1));
1879       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
1880                        vget_high_s16(input_val_0));
1881       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
1882                        vget_high_s16(input_val_0));
1883       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
1884                        vget_high_s16(input_val_0));
1885       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
1886                        vget_high_s16(input_val_0));
1887       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
1888                        vget_high_s16(input_val_1));
1889       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
1890                        vget_high_s16(input_val_1));
1891       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
1892                        vget_high_s16(input_val_1));
1893       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
1894                        vget_high_s16(input_val_1));
1895     }
1896     for (; in <= input_size - 8; in += 8) {
1897       const int8x8_t input_val_s8 = vld1_s8(input_data + in);
1898       const int8_t* filter_ptr = filter_data + in + out * input_size;
1899       int8x8_t filter_val_s8_0 = vld1_s8(filter_ptr);
1900       filter_ptr += input_size;
1901       int8x8_t filter_val_s8_1 = vld1_s8(filter_ptr);
1902       filter_ptr += input_size;
1903       int8x8_t filter_val_s8_2 = vld1_s8(filter_ptr);
1904       filter_ptr += input_size;
1905       int8x8_t filter_val_s8_3 = vld1_s8(filter_ptr);
1906       int16x8_t input_val = vmovl_s8(input_val_s8);
1907       input_val = vaddq_s16(input_val, input_offset_vec);
1908       int16x8_t filter_val_0 = vmovl_s8(filter_val_s8_0);
1909       filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
1910       int16x8_t filter_val_1 = vmovl_s8(filter_val_s8_1);
1911       filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
1912       int16x8_t filter_val_2 = vmovl_s8(filter_val_s8_2);
1913       filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
1914       int16x8_t filter_val_3 = vmovl_s8(filter_val_s8_3);
1915       filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
1916       acc0 =
1917           vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
1918       acc1 =
1919           vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
1920       acc2 =
1921           vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
1922       acc3 =
1923           vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
1924       acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
1925                        vget_high_s16(input_val));
1926       acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
1927                        vget_high_s16(input_val));
1928       acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
1929                        vget_high_s16(input_val));
1930       acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
1931                        vget_high_s16(input_val));
1932     }
1933     if (in < input_size) {
1934       int32 buf[16];
1935       vst1q_s32(buf + 0, acc0);
1936       vst1q_s32(buf + 4, acc1);
1937       vst1q_s32(buf + 8, acc2);
1938       vst1q_s32(buf + 12, acc3);
1939       for (; in < input_size; in++) {
1940         int lane = (in + 8 - input_size) % 4;
1941         const int32 input_val = input_data[in] + input_offset;
1942         for (int k = 0; k < kPeel; k++) {
1943           int32 filter_val =
1944               filter_data[in + (out + k) * input_size] + filter_offset;
1945           buf[lane + 4 * k] += filter_val * input_val;
1946         }
1947       }
1948       acc0 = vld1q_s32(buf + 0);
1949       acc1 = vld1q_s32(buf + 4);
1950       acc2 = vld1q_s32(buf + 8);
1951       acc3 = vld1q_s32(buf + 12);
1952     }
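    // Note (explanatory, not in the original source): the leftover loop
    // above spills the four vector accumulators to a scalar buffer, adds the
    // final (input_size % 8) products into individual lanes, and reloads.
    // Which lane receives a given product is irrelevant, because every lane
    // of each accumulator is summed by the horizontal reduction below.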
1953 
1954     // Horizontally reduce accumulators
1955     int32x2_t pairwise_reduced_acc_0 =
1956         vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
1957     int32x2_t pairwise_reduced_acc_1 =
1958         vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
1959     int32x2_t pairwise_reduced_acc_2 =
1960         vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
1961     int32x2_t pairwise_reduced_acc_3 =
1962         vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
1963     const int32x2_t reduced_lo =
1964         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1965     const int32x2_t reduced_hi =
1966         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1967     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
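    // Sketch of the vpadd tree above (illustrative): each acc_i holds four
    // partial sums for output row (out + i), and the tree computes
    //   reduced[i] = acc_i[0] + acc_i[1] + acc_i[2] + acc_i[3]
    // so lane i of 'reduced' is the complete int32 dot product for that row.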
1968     // Add bias values.
1969     int32x4_t bias_vec = vld1q_s32(bias_data + out);
1970     reduced = vaddq_s32(reduced, bias_vec);
1971     if (shift_left) {
1972       const int32 multiplier_power_of_two = 1 << output_shift;
1973       reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
1974       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1975     } else {
1976       // Multiply by the fixed-point multiplier.
1977       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1978       // Rounding-shift-right.
1979       using gemmlowp::RoundingDivideByPOT;
1980       reduced = RoundingDivideByPOT(reduced, -output_shift);
1981     }
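    // Scalar sketch of the rescaling above (illustrative, assuming
    // shift_left == (output_shift > 0) as in the surrounding code). The pair
    // (output_multiplier, output_shift) encodes the real output scale as a
    // Q31 fixed-point multiplier times a power of two:
    //   int32 Rescale(int32 acc) {
    //     using gemmlowp::RoundingDivideByPOT;
    //     using gemmlowp::SaturatingRoundingDoublingHighMul;
    //     if (output_shift > 0) acc *= (1 << output_shift);
    //     acc = SaturatingRoundingDoublingHighMul(acc, output_multiplier);
    //     if (output_shift <= 0) acc = RoundingDivideByPOT(acc, -output_shift);
    //     return acc;
    //   }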
1982     // Add the output offset.
1983     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1984     reduced = vaddq_s32(reduced, output_offset_vec);
1985     // Narrow values down to 16 bit signed.
1986     const int16x4_t res16 = vqmovn_s32(reduced);
1987     // Narrow values down to 8 bit signed, saturating.
1988     int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
1989     // Apply the clamping from the activation function
1990     res8 = vmax_s8(res8, vdup_n_s8(output_activation_min));
1991     res8 = vmin_s8(res8, vdup_n_s8(output_activation_max));
1992     // Store results to destination.
1993     vst1_lane_s8(output_data + out + 0, res8, 0);
1994     vst1_lane_s8(output_data + out + 1, res8, 1);
1995     vst1_lane_s8(output_data + out + 2, res8, 2);
1996     vst1_lane_s8(output_data + out + 3, res8, 3);
1997   }
1998 }
1999 
2000 struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
2001   LegacyInt8FullyConnectedAsGEMVWorkerTask(
2002       const RuntimeShape& input_shape, const int8_t* input_data,
2003       int32 input_offset, const RuntimeShape& filter_shape,
2004       const int8_t* filter_data, int32 filter_offset,
2005       const RuntimeShape& bias_shape, const int32* bias_data,
2006       int32 output_offset, int32 output_multiplier, int output_shift,
2007       int32 output_activation_min, int32 output_activation_max,
2008       const RuntimeShape& output_shape, int8_t* output_data, int row_start,
2009       int row_end)
2010       : input_shape_(input_shape),
2011         input_data_(input_data),
2012         input_offset_(input_offset),
2013         filter_shape_(filter_shape),
2014         filter_data_(filter_data),
2015         filter_offset_(filter_offset),
2016         bias_shape_(bias_shape),
2017         bias_data_(bias_data),
2018         output_offset_(output_offset),
2019         output_multiplier_(output_multiplier),
2020         output_shift_(output_shift),
2021         output_activation_min_(output_activation_min),
2022         output_activation_max_(output_activation_max),
2023         output_shape_(output_shape),
2024         output_data_(output_data),
2025         row_start_(row_start),
2026         row_end_(row_end) {}
2027 
2028   void Run() override {
2029     LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2030         input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
2031         filter_offset_, bias_shape_, bias_data_, output_offset_,
2032         output_multiplier_, output_shift_, output_activation_min_,
2033         output_activation_max_, output_shape_, output_data_, row_start_,
2034         row_end_);
2035   }
2036 
2037   const RuntimeShape& input_shape_;
2038   const int8_t* input_data_;
2039   int32 input_offset_;
2040   const RuntimeShape& filter_shape_;
2041   const int8_t* filter_data_;
2042   int32 filter_offset_;
2043   const RuntimeShape& bias_shape_;
2044   const int32* bias_data_;
2045   int32 output_offset_;
2046   int32 output_multiplier_;
2047   int output_shift_;
2048   int32 output_activation_min_;
2049   int32 output_activation_max_;
2050   const RuntimeShape& output_shape_;
2051   int8_t* output_data_;
2052   int row_start_;
2053   int row_end_;
2054 };
2055 
2056 inline void LegacyInt8FullyConnectedAsGEMV(
2057     const RuntimeShape& input_shape, const int8_t* input_data,
2058     int32 input_offset, const RuntimeShape& filter_shape,
2059     const int8_t* filter_data, int32 filter_offset,
2060     const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
2061     int32 output_multiplier, int output_shift, int32 output_activation_min,
2062     int32 output_activation_max, const RuntimeShape& output_shape,
2063     int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) {
2064   const int output_dim_count = output_shape.DimensionsCount();
2065   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2066   const int output_rows = output_shape.Dims(output_dim_count - 1);
2067   const int input_size = FlatSizeSkipDim(input_shape, 0);
2068   static constexpr int kKernelRows = 4;
2069   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2070       gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
2071   if (thread_count == 1) {
2072     // Single-thread case: do the computation on the current thread; don't
2073     // use a threadpool.
2074     LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2075         input_shape, input_data, input_offset, filter_shape, filter_data,
2076         filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
2077         output_shift, output_activation_min, output_activation_max,
2078         output_shape, output_data, 0, output_rows);
2079     return;
2080   }
2081 
2082   // Multi-threaded case: use the gemmlowp context's threadpool.
2083   TFLITE_DCHECK_GT(thread_count, 1);
2084   std::vector<LegacyInt8FullyConnectedAsGEMVWorkerTask> tasks;
2085   // TODO(b/131746020) don't create new heap allocations every time.
2086   // At least we make it a single heap allocation by using reserve().
2087   tasks.reserve(thread_count);
2088   const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2089       gemmlowp::CeilQuotient(output_rows, thread_count));
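  // Worked example (illustrative): with output_rows = 100 and
  // thread_count = 3, CeilQuotient(100, 3) = 34 and RoundUp<4>(34) = 36, so
  // the tasks cover rows [0, 36), [36, 72) and [72, 100). Rounding the
  // per-worker row count up to a multiple of kKernelRows keeps every task
  // but possibly the last aligned with the 4-row GEMV kernel.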
2090   int row_start = 0;
2091   for (int i = 0; i < thread_count; ++i) {
2092     int row_end = std::min(output_rows, row_start + kRowsPerWorker);
2093     tasks.emplace_back(input_shape, input_data, input_offset, filter_shape,
2094                        filter_data, filter_offset, bias_shape, bias_data,
2095                        output_offset, output_multiplier, output_shift,
2096                        output_activation_min, output_activation_max,
2097                        output_shape, output_data, row_start, row_end);
2098     row_start = row_end;
2099   }
2100   TFLITE_DCHECK_EQ(row_start, output_rows);
2101   gemmlowp_context->workers_pool()->Execute(tasks.size(), tasks.data());
2102 }
2103 #endif  // USE_NEON
2104 
2105 inline void FullyConnected(
2106     const FullyConnectedParams& params, const RuntimeShape& input_shape,
2107     const int8* input_data, const RuntimeShape& filter_shape,
2108     const int8* filter_data, const RuntimeShape& bias_shape,
2109     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
2110     gemmlowp::GemmContext* gemmlowp_context) {
2111   ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");
2112 
2113 #ifdef USE_NEON
2114   const int32 input_offset = params.input_offset;
2115   const int32 filter_offset = params.weights_offset;
2116   const int32 output_offset = params.output_offset;
2117   const int32 output_multiplier = params.output_multiplier;
2118   const int output_shift = params.output_shift;
2119   const int32 output_activation_min = params.quantized_activation_min;
2120   const int32 output_activation_max = params.quantized_activation_max;
2121   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
2122   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2123   // TODO(b/62193649): This really should be:
2124   //     const int batches = ArraySize(output_dims, 1);
2125   // but the current --variable_batch hack consists in overwriting the 3rd
2126   // dimension with the runtime batch size, as we don't keep track for each
2127   // array of which dimension is the batch dimension in it.
2128   const int output_dim_count = output_shape.DimensionsCount();
2129   const int filter_dim_count = filter_shape.DimensionsCount();
2130   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2131   if (batches == 1) {
2132     const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
2133                                         output_shape, output_dim_count - 1);
2134     if (output_size >= 4) {
2135       return LegacyInt8FullyConnectedAsGEMV(
2136           input_shape, input_data, input_offset, filter_shape, filter_data,
2137           filter_offset, bias_shape, bias_data, output_offset,
2138           output_multiplier, output_shift, output_activation_min,
2139           output_activation_max, output_shape, output_data, gemmlowp_context);
2140     }
2141   }
2142 #endif  // USE_NEON
2143 
2144 #ifdef GEMMLOWP_NEON
2145   const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
2146   const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
2147   TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
2148   const int output_rows = output_shape.Dims(output_dim_count - 1);
2149   TFLITE_DCHECK_EQ(output_rows, filter_rows);
2150   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2151 
2152   gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2153       filter_data, output_rows, filter_cols, filter_cols);
2154   gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::ColMajor> input_matrix(
2155       input_data, filter_cols, batches, filter_cols);
2156   gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
2157       output_data, output_rows, batches, output_rows);
2158   const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp(
2159       bias_data, output_rows, output_offset, output_multiplier, output_shift,
2160       output_activation_min, output_activation_max);
2161 
2162   gemmlowp::GemmWithOutputPipeline<
2163       int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
2164       gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2165       filter_offset, input_offset, output_pipeline);
2166   return;
2167 #endif  // GEMMLOWP_NEON
2168 
2169   // If both the USE_NEON and GEMMLOWP_NEON paths are skipped, fall back to
2170   // the reference implementation.
2171   reference_integer_ops::FullyConnected(params, input_shape, input_data,
2172                                         filter_shape, filter_data, bias_shape,
2173                                         bias_data, output_shape, output_data);
2174 }
2175 
2176 struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task {
2177   LegacyShuffledFullyConnectedWorkerTask(const uint8* input_data,
2178                                          const int8* shuffled_weights_data,
2179                                          int batches, int output_depth,
2180                                          int output_stride, int accum_depth,
2181                                          const int32* bias_data,
2182                                          int32 output_multiplier,
2183                                          int output_shift, int16* output_data)
2184       : input_data_(input_data),
2185         shuffled_weights_data_(shuffled_weights_data),
2186         batches_(batches),
2187         output_depth_(output_depth),
2188         output_stride_(output_stride),
2189         accum_depth_(accum_depth),
2190         bias_data_(bias_data),
2191         output_multiplier_(output_multiplier),
2192         output_shift_(output_shift),
2193         output_data_(output_data) {}
2194 
2195   void Run() override {
2196     ShuffledFullyConnectedWorkerImpl(
2197         input_data_, shuffled_weights_data_, batches_, output_depth_,
2198         output_stride_, accum_depth_, bias_data_, output_multiplier_,
2199         output_shift_, output_data_);
2200   }
2201 
2202   const uint8* input_data_;
2203   const int8* shuffled_weights_data_;
2204   int batches_;
2205   int output_depth_;
2206   int output_stride_;
2207   int accum_depth_;
2208   const int32* bias_data_;
2209   int32 output_multiplier_;
2210   int output_shift_;
2211   int16* output_data_;
2212 };
2213 
2214 inline void ShuffledFullyConnected(
2215     const FullyConnectedParams& params, const RuntimeShape& input_shape,
2216     const uint8* input_data, const RuntimeShape& weights_shape,
2217     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
2218     const int32* bias_data, const RuntimeShape& output_shape,
2219     int16* output_data, uint8* shuffled_input_workspace_data,
2220     gemmlowp::GemmContext* gemmlowp_context) {
2221   ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
2222   const int32 output_multiplier = params.output_multiplier;
2223   const int output_shift = params.output_shift;
2224   const int32 output_activation_min = params.quantized_activation_min;
2225   const int32 output_activation_max = params.quantized_activation_max;
2226   (void)gemmlowp_context;  // Unused-variable cast kept from the reference version.
2227   TFLITE_DCHECK_EQ(output_activation_min, -32768);
2228   TFLITE_DCHECK_EQ(output_activation_max, 32767);
2229   TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
2230   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2231   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2232   // TODO(b/62193649): This really should be:
2233   //     const int batches = ArraySize(output_dims, 1);
2234   // but the current --variable_batch hack consists in overwriting the 3rd
2235   // dimension with the runtime batch size, as we don't keep track for each
2236   // array of which dimension is the batch dimension in it.
2237   const int output_dim_count = output_shape.DimensionsCount();
2238   const int weights_dim_count = weights_shape.DimensionsCount();
2239   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2240   const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
2241                                        output_shape, output_dim_count - 1);
2242   const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
2243   TFLITE_DCHECK((accum_depth % 16) == 0);
2244   TFLITE_DCHECK((output_depth % 4) == 0);
2245   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
2246   // so that just reinterpreting them as int8 values is equivalent to
2247   // subtracting 128 from them, thus implementing for free the subtraction of
2248   // the zero_point value 128.
2249   const int8* int8_shuffled_weights_data =
2250       reinterpret_cast<const int8*>(shuffled_weights_data);
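  // Illustrative identity (not part of the original code): for a uint8
  // value v quantized with zero point 128,
  //   static_cast<int8>(v ^ 0x80) == static_cast<int32>(v) - 128
  // e.g. v = 200 gives 200 ^ 0x80 == 72 and 200 - 128 == 72, so the XOR
  // applied at weight-shuffling time bakes in the zero-point subtraction.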
2251 
2252   // Shuffling and xoring of input activations into the workspace buffer
2253   if (batches == 1) {
2254 #ifdef USE_NEON
2255     const uint8x16_t signbit = vdupq_n_u8(0x80);
2256     for (int i = 0; i < accum_depth; i += 16) {
2257       uint8x16_t val = vld1q_u8(input_data + i);
2258       val = veorq_u8(val, signbit);
2259       vst1q_u8(shuffled_input_workspace_data + i, val);
2260     }
2261 #else
2262     for (int i = 0; i < accum_depth; i++) {
2263       shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
2264     }
2265 #endif
2266   } else if (batches == 4) {
2267     uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
2268     int c = 0;
2269 #ifdef USE_NEON
2270     const uint8x16_t signbit = vdupq_n_u8(0x80);
2271     for (c = 0; c < accum_depth; c += 16) {
2272       const uint8* src_data_ptr = input_data + c;
2273       uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
2274       uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
2275       uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
2276       uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
2277       val0 = veorq_u8(val0, signbit);
2278       val1 = veorq_u8(val1, signbit);
2279       val2 = veorq_u8(val2, signbit);
2280       val3 = veorq_u8(val3, signbit);
2281       vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
2282       vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
2283       vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
2284       vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
2285       shuffled_input_workspace_ptr += 64;
2286     }
2287 #else
2288     for (c = 0; c < accum_depth; c += 16) {
2289       for (int b = 0; b < 4; b++) {
2290         const uint8* src_data_ptr = input_data + b * accum_depth + c;
2291         for (int j = 0; j < 16; j++) {
2292           uint8 src_val = *src_data_ptr++;
2293           // Flip the sign bit, so that the kernel will only need to
2294           // reinterpret these uint8 values as int8, getting for free the
2295           // subtraction of the zero_point value 128.
2296           uint8 dst_val = src_val ^ 0x80;
2297           *shuffled_input_workspace_ptr++ = dst_val;
2298         }
2299       }
2300     }
2301 #endif
2302   } else {
2303     TFLITE_DCHECK(false);
2304     return;
2305   }
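  // Layout note (explanatory sketch): for batches == 4, each iteration of
  // the loops above emits one 64-byte group holding 16 consecutive depth
  // values from each of the four batches, i.e.
  //   workspace[64 * (c / 16) + 16 * b + j] =
  //       input_data[b * accum_depth + c + j] ^ 0x80
  // for b in [0, 4) and j in [0, 16), which is the block order consumed by
  // ShuffledFullyConnectedWorkerImpl.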
2306 
2307   static constexpr int kKernelRows = 4;
2308   const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2309       gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth);
2310   if (thread_count == 1) {
2311     // Single-thread case: do the computation on the current thread; don't
2312     // use a threadpool.
2313     ShuffledFullyConnectedWorkerImpl(
2314         shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
2315         output_depth, output_depth, accum_depth, bias_data, output_multiplier,
2316         output_shift, output_data);
2317     return;
2318   }
2319 
2320   // Multi-threaded case: use the gemmlowp context's threadpool.
2321   TFLITE_DCHECK_GT(thread_count, 1);
2322   std::vector<gemmlowp::Task*> tasks(thread_count);
2323   const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2324       gemmlowp::CeilQuotient(output_depth, thread_count));
2325   int row_start = 0;
2326   for (int i = 0; i < thread_count; i++) {
2327     int row_end = std::min(output_depth, row_start + kRowsPerWorker);
2328     tasks[i] = new LegacyShuffledFullyConnectedWorkerTask(
2329         shuffled_input_workspace_data,
2330         int8_shuffled_weights_data + row_start * accum_depth, batches,
2331         row_end - row_start, output_depth, accum_depth, bias_data + row_start,
2332         output_multiplier, output_shift, output_data + row_start);
2333     row_start = row_end;
2334   }
2335   TFLITE_DCHECK_EQ(row_start, output_depth);
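  // Note (explanatory assumption): unlike the std::vector<Task> used by the
  // GEMV path earlier, these tasks are raw heap allocations;
  // LegacyExecuteAndDestroyTasks is expected to take ownership and delete
  // each task after running it.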
2336   gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
2337 }
2338 
2339 inline void ShuffledFullyConnected(
2340     const uint8* input_data, const Dims<4>& input_dims,
2341     const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
2342     const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
2343     int output_shift, int32 output_activation_min, int32 output_activation_max,
2344     int16* output_data, const Dims<4>& output_dims,
2345     uint8* shuffled_input_workspace_data,
2346     gemmlowp::GemmContext* gemmlowp_context) {
2347   tflite::FullyConnectedParams op_params;
2348   op_params.output_multiplier = output_multiplier;
2349   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
2350   op_params.output_shift = kReverseShift * output_shift;
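  // Example (illustrative): a legacy right shift of 3 becomes
  // op_params.output_shift = kReverseShift * 3 = -3, kReverseShift being -1
  // and the new convention being positive-means-shift-left.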
2351   op_params.quantized_activation_min = output_activation_min;
2352   op_params.quantized_activation_max = output_activation_max;
2353 
2354   ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
2355                          DimsToShape(weights_dims), shuffled_weights_data,
2356                          DimsToShape(bias_dims), bias_data,
2357                          DimsToShape(output_dims), output_data,
2358                          shuffled_input_workspace_data, gemmlowp_context);
2359 }
2360 
2361 template <typename T>
2362 inline void ExtractPatchIntoBufferColumn(
2363     const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
2364     int stride_width, int stride_height, int pad_width, int pad_height,
2365     int in_width, int in_height, int in_depth, int single_buffer_length,
2366     int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
2367   ExtractPatchIntoBufferColumn(
2368       DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
2369       stride_height, pad_width, pad_height, in_width, in_height, in_depth,
2370       single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
2371 }
2372 
2373 template <typename T>
2374 void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
2375                    const Dims<4>& filter_dims, int stride_width,
2376                    int stride_height, int dilation_width_factor,
2377                    int dilation_height_factor, int pad_width, int pad_height,
2378                    const Dims<4>& output_dims, uint8 zero_byte,
2379                    T* im2col_data) {
2380   tflite::ConvParams op_params;
2381   // Padding type is ignored, but still set.
2382   op_params.padding_type = PaddingType::kSame;
2383   op_params.padding_values.width = pad_width;
2384   op_params.padding_values.height = pad_height;
2385   op_params.stride_width = stride_width;
2386   op_params.stride_height = stride_height;
2387   op_params.dilation_width_factor = dilation_width_factor;
2388   op_params.dilation_height_factor = dilation_height_factor;
2389 
2390   DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2391                 DimsToShape(filter_dims), DimsToShape(output_dims),
2392                 im2col_data);
2393 }
2394 
2395 template <typename T>
2396 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
2397             int stride_height, int pad_width, int pad_height, int kheight,
2398             int kwidth, uint8 zero_byte, T* output_data,
2399             const Dims<4>& output_dims) {
2400   tflite::ConvParams op_params;
2401   // Padding type is ignored, but still set.
2402   op_params.padding_type = PaddingType::kSame;
2403   op_params.padding_values.width = pad_width;
2404   op_params.padding_values.height = pad_height;
2405   op_params.stride_width = stride_width;
2406   op_params.stride_height = stride_height;
2407   op_params.dilation_width_factor = 1;
2408   op_params.dilation_height_factor = 1;
2409 
2410   Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
2411          input_data, DimsToShape(output_dims), output_data);
2412 }
2413 
2414 // legacy, for compatibility with old checked-in code
2415 template <typename T>
2416 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2417             int pad_width, int pad_height, int kheight, int kwidth,
2418             uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2419   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2420          kwidth, zero_byte, output_data, output_dims);
2421 }
2422 
2423 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2424                  const float* input_data, const RuntimeShape& filter_shape,
2425                  const float* filter_data, const RuntimeShape& bias_shape,
2426                  const float* bias_data, const RuntimeShape& output_shape,
2427                  float* output_data, const RuntimeShape& im2col_shape,
2428                  float* im2col_data) {
2429   const int stride_width = params.stride_width;
2430   const int stride_height = params.stride_height;
2431   const int dilation_width_factor = params.dilation_width_factor;
2432   const int dilation_height_factor = params.dilation_height_factor;
2433   const float output_activation_min = params.float_activation_min;
2434   const float output_activation_max = params.float_activation_max;
2435   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2436   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2437   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2438 
2439   (void)im2col_data;
2440   (void)im2col_shape;
2441   ruy::profiler::ScopeLabel label("Conv");
2442 
2443   // NB: the float 0.0f value is represented by all zero bytes.
2444   const uint8 float_zero_byte = 0x00;
2445   const float* gemm_input_data = nullptr;
2446   const RuntimeShape* gemm_input_shape = nullptr;
2447   const int filter_width = filter_shape.Dims(2);
2448   const int filter_height = filter_shape.Dims(1);
2449   const bool need_dilated_im2col =
2450       dilation_width_factor != 1 || dilation_height_factor != 1;
2451   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2452                            filter_width != 1 || filter_height != 1;
2453   if (need_dilated_im2col) {
2454     DilatedIm2col(params, float_zero_byte, input_shape, input_data,
2455                   filter_shape, output_shape, im2col_data);
2456     gemm_input_data = im2col_data;
2457     gemm_input_shape = &im2col_shape;
2458   } else if (need_im2col) {
2459     TFLITE_DCHECK(im2col_data);
2460     Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
2461            input_data, im2col_shape, im2col_data);
2462     gemm_input_data = im2col_data;
2463     gemm_input_shape = &im2col_shape;
2464   } else {
2465     // TODO(aselle): We need to make sure to not send im2col if it is not
2466     // needed.
2467     TFLITE_DCHECK(!im2col_data);
2468     gemm_input_data = input_data;
2469     gemm_input_shape = &input_shape;
2470   }
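  // Explanatory note (not in the original): a 1x1 filter applied with
  // stride 1 and no dilation reads each input pixel exactly once, in memory
  // order, so the input tensor already has the im2col layout and can be fed
  // to the GEMM directly; that is the case the final else branch handles.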
2471 
2472   // The following code computes matrix multiplication c = a * transpose(b)
2473   // with CBLAS, where:
2474   // * `a` is a matrix with dimensions (m, k).
2475   // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
2476   // * `c` is a matrix with dimensions (m, n).
2477   // The variable naming is aligned with the CBLAS specification here.
2478   const float* a = gemm_input_data;
2479   const float* b = filter_data;
2480   float* c = output_data;
2481   const int gemm_input_dims = gemm_input_shape->DimensionsCount();
2482   int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
2483   int n = output_shape.Dims(3);
2484   int k = gemm_input_shape->Dims(gemm_input_dims - 1);
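  // Shape sketch (illustrative): for an im2col'd convolution,
  //   m = batches * output_height * output_width  (rows of a and c)
  //   n = output_depth                            (rows of b, columns of c)
  //   k = filter_height * filter_width * input_depth
  // so c = a * transpose(b) evaluates every output pixel against every
  // filter in a single GEMM.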
2485 
2486 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2487   // The stride of matrix a, b and c respectively.
2488   int stride_a = k;
2489   int stride_b = k;
2490   int stride_c = n;
2491 
2492   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
2493               stride_a, b, stride_b, 0.0f, c, stride_c);
2494 #else
2495   // When an optimized CBLAS implementation is not available, fall back
2496   // to using Eigen.
2497   typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
2498       Matrix;
2499   typedef Eigen::Map<Matrix> MatrixRef;
2500   typedef Eigen::Map<const Matrix> ConstMatrixRef;
2501 
2502   MatrixRef matrix_c(c, m, n);
2503   ConstMatrixRef matrix_a(a, m, k);
2504   ConstMatrixRef matrix_b(b, n, k);
2505 
2506   // The following special casing for when a or b is a vector is required
2507   // as Eigen seems to fail to make this optimization on its own.
2508   if (n == 1) {
2509     ruy::profiler::ScopeLabel label("GEMV");
2510     matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
2511   } else if (m == 1) {
2512     ruy::profiler::ScopeLabel label("GEMV");
2513     matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
2514   } else {
2515     ruy::profiler::ScopeLabel label("GEMM");
2516     matrix_c.noalias() = matrix_a * matrix_b.transpose();
2517   }
2518 
2519 #endif  //  defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2520 
2521   optimized_ops::AddBiasAndEvalActivationFunction(
2522       output_activation_min, output_activation_max, bias_shape, bias_data,
2523       output_shape, output_data);
2524 }
2525 
2526 inline void Conv(const float* input_data, const Dims<4>& input_dims,
2527                  const float* filter_data, const Dims<4>& filter_dims,
2528                  const float* bias_data, const Dims<4>& bias_dims,
2529                  int stride_width, int stride_height, int dilation_width_factor,
2530                  int dilation_height_factor, int pad_width, int pad_height,
2531                  float output_activation_min, float output_activation_max,
2532                  float* output_data, const Dims<4>& output_dims,
2533                  float* im2col_data, const Dims<4>& im2col_dims) {
2534   tflite::ConvParams op_params;
2535   // Padding type is ignored, but still set.
2536   op_params.padding_type = PaddingType::kSame;
2537   op_params.padding_values.width = pad_width;
2538   op_params.padding_values.height = pad_height;
2539   op_params.stride_width = stride_width;
2540   op_params.stride_height = stride_height;
2541   op_params.dilation_width_factor = dilation_width_factor;
2542   op_params.dilation_height_factor = dilation_height_factor;
2543   op_params.float_activation_min = output_activation_min;
2544   op_params.float_activation_max = output_activation_max;
2545 
2546   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2547        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2548        output_data, DimsToShape(im2col_dims), im2col_data);
2549 }
2550 
2551 inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
2552                        const int8_t* filter_data, const Dims<4>& filter_dims,
2553                        const float* bias_data, const Dims<4>& bias_dims,
2554                        int stride_width, int stride_height, int pad_width,
2555                        int pad_height, float* scaling_factors_ptr,
2556                        float output_activation_min, float output_activation_max,
2557                        int32_t* scratch_data, const Dims<4>& scratch_dims,
2558                        float* output_data, const Dims<4>& output_dims,
2559                        int8_t* im2col_data, const Dims<4>& im2col_dims,
2560                        CpuBackendContext* context) {
2561   tflite::ConvParams op_params;
2562   // Padding type is ignored, but still set.
2563   op_params.padding_type = PaddingType::kSame;
2564   op_params.padding_values.width = pad_width;
2565   op_params.padding_values.height = pad_height;
2566   op_params.stride_width = stride_width;
2567   op_params.stride_height = stride_height;
2568   op_params.float_activation_min = output_activation_min;
2569   op_params.float_activation_max = output_activation_max;
2570 
2571   HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
2572              input_data, DimsToShape(filter_dims), filter_data,
2573              DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims),
2574              scratch_data, DimsToShape(output_dims), output_data,
2575              DimsToShape(im2col_dims), im2col_data, context);
2576 }
2577 
2578 template <FusedActivationFunctionType Ac>
2579 void Conv(const float* input_data, const Dims<4>& input_dims,
2580           const float* filter_data, const Dims<4>& filter_dims,
2581           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2582           int stride_height, int dilation_width_factor,
2583           int dilation_height_factor, int pad_width, int pad_height,
2584           float* output_data, const Dims<4>& output_dims, float* im2col_data,
2585           const Dims<4>& im2col_dims) {
2586   float output_activation_min, output_activation_max;
2587   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2588   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2589        stride_width, stride_height, dilation_width_factor,
2590        dilation_height_factor, pad_width, pad_height, output_activation_min,
2591        output_activation_max, output_data, output_dims, im2col_data,
2592        im2col_dims);
2593 }
2594 
2595 // legacy, for compatibility with old checked-in code
2596 template <FusedActivationFunctionType Ac>
2597 void Conv(const float* input_data, const Dims<4>& input_dims,
2598           const float* filter_data, const Dims<4>& filter_dims,
2599           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2600           int stride_height, int pad_width, int pad_height, float* output_data,
2601           const Dims<4>& output_dims, float* im2col_data,
2602           const Dims<4>& im2col_dims) {
2603   float output_activation_min, output_activation_max;
2604   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2605   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2606        stride_width, stride_height, 1, 1, pad_width, pad_height,
2607        output_activation_min, output_activation_max, output_data, output_dims,
2608        im2col_data, im2col_dims);
2609 }
2610 
2611 // legacy, for compatibility with old checked-in code
2612 template <FusedActivationFunctionType Ac>
2613 void Conv(const float* input_data, const Dims<4>& input_dims,
2614           const float* filter_data, const Dims<4>& filter_dims,
2615           const float* bias_data, const Dims<4>& bias_dims, int stride,
2616           int pad_width, int pad_height, float* output_data,
2617           const Dims<4>& output_dims, float* im2col_data,
2618           const Dims<4>& im2col_dims) {
2619   Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
2620            bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
2621            output_dims, im2col_data, im2col_dims);
2622 }
2623 
2624 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2625                  const uint8* input_data, const RuntimeShape& filter_shape,
2626                  const uint8* filter_data, const RuntimeShape& bias_shape,
2627                  const int32* bias_data, const RuntimeShape& output_shape,
2628                  uint8* output_data, const RuntimeShape& im2col_shape,
2629                  uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
2630   ruy::profiler::ScopeLabel label("Conv/8bit");
2631   const int stride_width = params.stride_width;
2632   const int stride_height = params.stride_height;
2633   const int dilation_width_factor = params.dilation_width_factor;
2634   const int dilation_height_factor = params.dilation_height_factor;
2635   const int32 input_offset = params.input_offset;
2636   const int32 filter_offset = params.weights_offset;
2637   const int32 output_offset = params.output_offset;
2638   const int32 output_multiplier = params.output_multiplier;
2639   const int output_shift = params.output_shift;
2640   const int32 output_activation_min = params.quantized_activation_min;
2641   const int32 output_activation_max = params.quantized_activation_max;
2642   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2643   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2644   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2645 
2646   const uint8* gemm_input_data = nullptr;
2647   const RuntimeShape* gemm_input_shape = nullptr;
2648   const int filter_width = filter_shape.Dims(2);
2649   const int filter_height = filter_shape.Dims(1);
2650   const bool need_dilated_im2col =
2651       dilation_width_factor != 1 || dilation_height_factor != 1;
2652   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2653                            filter_width != 1 || filter_height != 1;
2654   if (need_dilated_im2col) {
2655     TFLITE_DCHECK(im2col_data);
2656     const int input_zero_point = -input_offset;
2657     TFLITE_DCHECK_GE(input_zero_point, 0);
2658     TFLITE_DCHECK_LE(input_zero_point, 255);
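    // Note (explanatory): in TFLite's quantization convention input_offset
    // is the negated zero point, so -input_offset recovers the uint8 zero
    // point, used here as the im2col padding byte.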
2659     DilatedIm2col(params, input_zero_point, input_shape, input_data,
2660                   filter_shape, output_shape, im2col_data);
2661     gemm_input_data = im2col_data;
2662     gemm_input_shape = &im2col_shape;
2663   } else if (need_im2col) {
2664     TFLITE_DCHECK(im2col_data);
2665     const int input_zero_point = -input_offset;
2666     TFLITE_DCHECK_GE(input_zero_point, 0);
2667     TFLITE_DCHECK_LE(input_zero_point, 255);
2668     Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2669            input_data, im2col_shape, im2col_data);
2670     gemm_input_data = im2col_data;
2671     gemm_input_shape = &im2col_shape;
2672   } else {
2673     TFLITE_DCHECK(!im2col_data);
2674     gemm_input_data = input_data;
2675     gemm_input_shape = &input_shape;
2676   }
2677 
2678   const int gemm_input_rows = gemm_input_shape->Dims(3);
2679   // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
2680   // The root cause has not yet been identified though. Same applies below for
2681   // the other calls commented out. This is a partial rollback of cl/196819423.
2682   // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
2683   const int gemm_input_cols = gemm_input_shape->Dims(0) *
2684                               gemm_input_shape->Dims(1) *
2685                               gemm_input_shape->Dims(2);
2686   const int filter_rows = filter_shape.Dims(0);
2687   // See b/79927784.
2688   // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2689   const int filter_cols =
2690       filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
2691   const int output_rows = output_shape.Dims(3);
2692   // See b/79927784.
2693   // const int output_cols = FlatSizeSkipDim(output_shape, 3);
2694   const int output_cols =
2695       output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
2696   TFLITE_DCHECK_EQ(output_rows, filter_rows);
2697   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
2698   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
2699   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2700 
2701 #ifdef USE_NEON
2702   if (gemm_input_cols == 1 && output_rows >= 4) {
2703     RuntimeShape fc_filter_shape{
2704         filter_shape.Dims(0),
2705         filter_shape.Dims(filter_shape.DimensionsCount() - 1)};
2706 
2707     return FullyConnectedAsGEMV(
2708         *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
2709         filter_data, filter_offset, bias_shape, bias_data, output_offset,
2710         output_multiplier, output_shift, output_activation_min,
2711         output_activation_max, output_shape, output_data, gemmlowp_context);
2712   }
2713 #endif
2714 
2715   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2716       filter_data, filter_rows, filter_cols);
2717   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2718       gemm_input_data, gemm_input_rows, gemm_input_cols);
2719   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2720       output_data, output_rows, output_cols);
2721   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2722       bias_data, output_rows, output_offset, output_multiplier, output_shift,
2723       output_activation_min, output_activation_max);
2724   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2725                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2726       gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2727       filter_offset, input_offset, output_pipeline);
2728 }
2729 
2730 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2731                  int32 input_offset, const uint8* filter_data,
2732                  const Dims<4>& filter_dims, int32 filter_offset,
2733                  const int32* bias_data, const Dims<4>& bias_dims,
2734                  int stride_width, int stride_height, int dilation_width_factor,
2735                  int dilation_height_factor, int pad_width, int pad_height,
2736                  int32 output_offset, int32 output_multiplier, int output_shift,
2737                  int32 output_activation_min, int32 output_activation_max,
2738                  uint8* output_data, const Dims<4>& output_dims,
2739                  uint8* im2col_data, const Dims<4>& im2col_dims,
2740                  gemmlowp::GemmContext* gemmlowp_context) {
2741   tflite::ConvParams op_params;
2742   // Padding type is ignored, but still set.
2743   op_params.padding_type = PaddingType::kSame;
2744   op_params.padding_values.width = pad_width;
2745   op_params.padding_values.height = pad_height;
2746   op_params.stride_width = stride_width;
2747   op_params.stride_height = stride_height;
2748   op_params.dilation_width_factor = dilation_width_factor;
2749   op_params.dilation_height_factor = dilation_height_factor;
2750   op_params.input_offset = input_offset;
2751   op_params.weights_offset = filter_offset;
2752   op_params.output_offset = output_offset;
2753   op_params.output_multiplier = output_multiplier;
2754   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
2755   op_params.output_shift = kReverseShift * output_shift;
2756   op_params.quantized_activation_min = output_activation_min;
2757   op_params.quantized_activation_max = output_activation_max;
2758 
2759   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2760        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2761        output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
2762 }
2763 
2764 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2765                  int32 input_offset, const uint8* filter_data,
2766                  const Dims<4>& filter_dims, int32 filter_offset,
2767                  const int32* bias_data, const Dims<4>& bias_dims,
2768                  int stride_width, int stride_height, int pad_width,
2769                  int pad_height, int32 output_offset, int32 output_multiplier,
2770                  int output_shift, int32 output_activation_min,
2771                  int32 output_activation_max, uint8* output_data,
2772                  const Dims<4>& output_dims, uint8* im2col_data,
2773                  const Dims<4>& im2col_dims,
2774                  gemmlowp::GemmContext* gemmlowp_context) {
2775   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2776        filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
2777        pad_width, pad_height, output_offset, output_multiplier, output_shift,
2778        output_activation_min, output_activation_max, output_data, output_dims,
2779        im2col_data, im2col_dims, gemmlowp_context);
2780 }
2781 
2782 // legacy, for compatibility with old checked-in code
2783 template <FusedActivationFunctionType Ac>
2784 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2785                  int32 input_offset, const uint8* filter_data,
2786                  const Dims<4>& filter_dims, int32 filter_offset,
2787                  const int32* bias_data, const Dims<4>& bias_dims,
2788                  int stride_width, int stride_height, int pad_width,
2789                  int pad_height, int32 output_offset, int32 output_multiplier,
2790                  int output_shift, int32 output_activation_min,
2791                  int32 output_activation_max, uint8* output_data,
2792                  const Dims<4>& output_dims, uint8* im2col_data,
2793                  const Dims<4>& im2col_dims,
2794                  gemmlowp::GemmContext* gemmlowp_context) {
2795   static_assert(Ac == FusedActivationFunctionType::kNone ||
2796                     Ac == FusedActivationFunctionType::kRelu ||
2797                     Ac == FusedActivationFunctionType::kRelu6 ||
2798                     Ac == FusedActivationFunctionType::kRelu1,
2799                 "");
2800   if (Ac == FusedActivationFunctionType::kNone) {
2801     TFLITE_DCHECK_EQ(output_activation_min, 0);
2802     TFLITE_DCHECK_EQ(output_activation_max, 255);
2803   }
2804   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2805        filter_offset, bias_data, bias_dims, stride_width, stride_height,
2806        pad_width, pad_height, output_offset, output_multiplier, output_shift,
2807        output_activation_min, output_activation_max, output_data, output_dims,
2808        im2col_data, im2col_dims, gemmlowp_context);
2809 }
2810 
2811 // legacy, for compatibility with old checked-in code
2812 template <FusedActivationFunctionType Ac>
2813 void Conv(const uint8* input_data, const Dims<4>& input_dims,
2814           int32 input_offset, const uint8* filter_data,
2815           const Dims<4>& filter_dims, int32 filter_offset,
2816           const int32* bias_data, const Dims<4>& bias_dims, int stride,
2817           int pad_width, int pad_height, int32 output_offset,
2818           int32 output_multiplier, int output_shift,
2819           int32 output_activation_min, int32 output_activation_max,
2820           uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
2821           const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
2822   static_assert(Ac == FusedActivationFunctionType::kNone ||
2823                     Ac == FusedActivationFunctionType::kRelu ||
2824                     Ac == FusedActivationFunctionType::kRelu6 ||
2825                     Ac == FusedActivationFunctionType::kRelu1,
2826                 "");
2827   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2828        filter_offset, bias_data, bias_dims, stride, stride, pad_width,
2829        pad_height, output_offset, output_multiplier, output_shift,
2830        output_activation_min, output_activation_max, output_data, output_dims,
2831        im2col_data, im2col_dims, gemmlowp_context);
2832 }
2833 
2834 // legacy, for compatibility with old checked-in code
2835 template <FusedActivationFunctionType Ac, typename T>
2836 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2837             int pad_width, int pad_height, int kheight, int kwidth,
2838             uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2839   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2840          kwidth, zero_byte, output_data, output_dims);
2841 }
2842 
2843 // legacy, for compatibility with old checked-in code
2844 template <FusedActivationFunctionType Ac>
2845 void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
2846                 const float* filter_data, const Dims<4>& filter_dims,
2847                 const float* bias_data, const Dims<4>& bias_dims,
2848                 float* output_data, const Dims<4>& output_dims) {
2849   ruy::profiler::ScopeLabel label("ConvAsGemm");
2850 
2851   const auto input_matrix_map =
2852       MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
2853   const auto filter_matrix_map =
2854       MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
2855   auto output_matrix_map =
2856       MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
2857 
2858   Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
2859 
2860   AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
2861                                        output_dims);
2862 }
2863 
2864 // legacy, for compatibility with old checked-in code
2865 template <FusedActivationFunctionType Ac>
2866 void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
2867                 int32 input_offset, const uint8* filter_data,
2868                 const Dims<4>& filter_dims, int32 filter_offset,
2869                 const int32* bias_data, const Dims<4>& bias_dims,
2870                 int32 output_offset, int32 output_multiplier, int output_shift,
2871                 int32 output_activation_min, int32 output_activation_max,
2872                 uint8* output_data, const Dims<4>& output_dims,
2873                 gemmlowp::GemmContext* gemmlowp_context) {
2874   ruy::profiler::ScopeLabel label("ConvAsGemm/8bit");
2875   static_assert(Ac == FusedActivationFunctionType::kNone ||
2876                     Ac == FusedActivationFunctionType::kRelu ||
2877                     Ac == FusedActivationFunctionType::kRelu6 ||
2878                     Ac == FusedActivationFunctionType::kRelu1,
2879                 "");
2880   const int input_rows = input_dims.sizes[0];
2881   const int input_cols = FlatSizeSkipDim(input_dims, 0);
2882   const int filter_rows = filter_dims.sizes[3];
2883   const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
2884   const int output_rows = output_dims.sizes[0];
2885   const int output_cols = FlatSizeSkipDim(output_dims, 0);
2886   TFLITE_DCHECK_EQ(output_rows, filter_rows);
2887   TFLITE_DCHECK_EQ(output_cols, input_cols);
2888   TFLITE_DCHECK_EQ(filter_cols, input_rows);
2889   TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
2890   TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
2891   TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
2892   TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
2893   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2894       filter_data, output_rows, filter_cols, filter_cols);
2895   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2896       input_data, filter_cols, output_cols, filter_cols);
2897   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2898       output_data, output_rows, output_cols, output_rows);
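       // MakeExp takes an exponent (negative means right shift), while the
       // legacy output_shift is a right-shift amount, hence the negation.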
2899   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2900       bias_data, output_rows, output_offset, output_multiplier, -output_shift,
2901       output_activation_min, output_activation_max);
2902   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2903                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2904       gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2905       filter_offset, input_offset, output_pipeline);
2906 }
2907 
2908 inline void TransposeConv(
2909     const ConvParams& params, const RuntimeShape& input_shape,
2910     const float* input_data, const RuntimeShape& filter_shape,
2911     const float* filter_data, const RuntimeShape& output_shape,
2912     float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
2913   ruy::profiler::ScopeLabel label("TransposeConv");
2914   // Note: we could use transposed weights with a forward conv for the
2915   // unstrided cases, but we are already getting good performance as-is.
2916   TFLITE_DCHECK(im2col_data);
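       // TransposeIm2col scatters the input into im2col patches so the
       // transpose conv can then be computed as a single GEMM with the filter.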
2917   TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
2918                   output_shape, im2col_data);
2919 
2920   const auto im2col_matrix_map =
2921       MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
2922   const auto filter_matrix_map =
2923       MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
2924   auto output_matrix_map =
2925       MapAsMatrixWithLastDimAsRows(output_data, output_shape);
2926 
2927   Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
2928 }
2929 
2930 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
2931                           const float* filter_data, const Dims<4>& filter_dims,
2932                           int stride_width, int stride_height, int pad_width,
2933                           int pad_height, float* output_data,
2934                           const Dims<4>& output_dims, float* im2col_data,
2935                           const Dims<4>& im2col_dims) {
2936   tflite::ConvParams op_params;
2937   // Padding type is ignored, but still set.
2938   op_params.padding_type = PaddingType::kSame;
2939   op_params.padding_values.width = pad_width;
2940   op_params.padding_values.height = pad_height;
2941   op_params.stride_width = stride_width;
2942   op_params.stride_height = stride_height;
2943 
2944   TransposeConv(op_params, DimsToShape(input_dims), input_data,
2945                 DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
2946                 output_data, DimsToShape(im2col_dims), im2col_data);
2947 }
2948 
2949 inline void TransposeConvV2(
2950     const ConvParams& params, const RuntimeShape& input_shape,
2951     const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
2952     const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
2953     float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
2954     CpuBackendContext* cpu_backend_context) {
2955   TransposeConvV2(params, input_shape, input_data, hwoi_ordered_filter_shape,
2956                   hwoi_ordered_filter_data, /*bias_shape*/ RuntimeShape(),
2957                   /*bias_data*/ nullptr, output_shape, output_data,
2958                   col2im_shape, col2im_data, cpu_backend_context);
2959 }
2960 
2961 template <typename T>
2962 void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
2963                      const Dims<4>& filter_dims, int stride_width,
2964                      int stride_height, int pad_width, int pad_height,
2965                      const Dims<4>& output_dims, uint8 zero_byte,
2966                      T* im2col_data) {
2967   tflite::ConvParams op_params;
2968   // Padding type is ignored, but still set.
2969   op_params.padding_type = PaddingType::kSame;
2970   op_params.padding_values.width = pad_width;
2971   op_params.padding_values.height = pad_height;
2972   op_params.stride_width = stride_width;
2973   op_params.stride_height = stride_height;
2974 
2975   TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2976                   DimsToShape(filter_dims), DimsToShape(output_dims),
2977                   im2col_data);
2978 }
2979 
2980 inline void LstmCell(
2981     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2982     const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2983     const float* prev_activ_data, const RuntimeShape& weights_shape,
2984     const float* weights_data, const RuntimeShape& unextended_bias_shape,
2985     const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2986     const float* prev_state_data,
2987     const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2988     const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2989     const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2990     const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
2991   ruy::profiler::ScopeLabel label("LstmCell");
2992   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2993   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2994   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2995   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2996   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2997   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2998   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2999   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3000   const RuntimeShape input_shape =
3001       RuntimeShape::ExtendedShape(4, unextended_input_shape);
3002   const RuntimeShape prev_activ_shape =
3003       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3004   const RuntimeShape bias_shape =
3005       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3006   const RuntimeShape prev_state_shape =
3007       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3008   const RuntimeShape output_state_shape =
3009       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3010   const RuntimeShape output_activ_shape =
3011       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3012   const RuntimeShape concat_temp_shape =
3013       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3014   const RuntimeShape activ_temp_shape =
3015       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3016   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3017 
3018   const int weights_dim_count = weights_shape.DimensionsCount();
3019   MatchingDim(  // batches
3020       input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
3021       output_state_shape, 0, output_activ_shape, 0);
3022   MatchingDim(  // height
3023       input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
3024       output_state_shape, 1, output_activ_shape, 1);
3025   MatchingDim(  // width
3026       input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
3027       output_state_shape, 2, output_activ_shape, 2);
3028   const int input_depth = input_shape.Dims(3);
3029   const int prev_activ_depth = prev_activ_shape.Dims(3);
3030   const int total_input_depth = prev_activ_depth + input_depth;
3031   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3032                    total_input_depth);
3033   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3034   const int intern_activ_depth =
3035       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3036   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3037                    intern_activ_depth * total_input_depth);
3038   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3039   const int output_depth =
3040       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3041                   3, output_activ_shape, 3);
3042   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
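       // The fully connected node produces four gate pre-activations (input,
       // input modulation, forget, output) per output channel, hence the
       // factor of 4 between intern_activ_depth and output_depth.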
3043 
3044   // Concatenate prev_activ and input data together.
3045   std::vector<float const*> concat_input_arrays_data;
3046   std::vector<RuntimeShape const*> concat_input_arrays_shapes;
3047   concat_input_arrays_data.push_back(input_data);
3048   concat_input_arrays_data.push_back(prev_activ_data);
3049   concat_input_arrays_shapes.push_back(&input_shape);
3050   concat_input_arrays_shapes.push_back(&prev_activ_shape);
3051   tflite::ConcatenationParams concat_params;
3052   concat_params.axis = 3;
3053   concat_params.inputs_count = concat_input_arrays_data.size();
3054   Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
3055                 &(concat_input_arrays_data[0]), concat_temp_shape,
3056                 concat_temp_data);
3057 
3058   // Fully connected
3059   tflite::FullyConnectedParams fc_params;
3060   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
3061   fc_params.float_activation_max = std::numeric_limits<float>::max();
3062   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
3063                  weights_data, bias_shape, bias_data, activ_temp_shape,
3064                  activ_temp_data);
3065 
3066   // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
3067   // operations.
3068   ArrayMap<float> activ_temp_map =
3069       MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
3070   auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
3071                                             activ_temp_map.cols());
3072   auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
3073                                            activ_temp_map.cols());
3074   auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
3075                                              activ_temp_map.cols());
3076   auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
3077                                              activ_temp_map.cols());
3078   ArrayMap<const float> prev_state_map =
3079       MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
3080   ArrayMap<float> output_state_map =
3081       MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
3082   ArrayMap<float> output_activ_map =
3083       MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
3084 
3085   // Combined memory state and final output calculation
3086   ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
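       // This evaluates the standard LSTM update, with i/g/f/o the input,
       // new-input, forget and output gate pre-activations:
       //   state  = logistic(i) * tanh(g) + logistic(f) * prev_state
       //   output = logistic(o) * tanh(state)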
3087   output_state_map =
3088       input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3089           new_input_sm.tanh() +
3090       forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3091           prev_state_map;
3092   output_activ_map =
3093       output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3094       output_state_map.tanh();
3095 }
3096 
3097 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
3098                      const float* prev_activ_data,
3099                      const Dims<4>& prev_activ_dims, const float* weights_data,
3100                      const Dims<4>& weights_dims, const float* bias_data,
3101                      const Dims<4>& bias_dims, const float* prev_state_data,
3102                      const Dims<4>& prev_state_dims, float* output_state_data,
3103                      const Dims<4>& output_state_dims, float* output_activ_data,
3104                      const Dims<4>& output_activ_dims, float* concat_temp_data,
3105                      const Dims<4>& concat_temp_dims, float* activ_temp_data,
3106                      const Dims<4>& activ_temp_dims) {
3107   tflite::LstmCellParams op_params;
3108   // Float LSTM cell does not need parameters to be set: leave untouched.
3109 
3110   LstmCell(op_params, DimsToShape(input_dims), input_data,
3111            DimsToShape(prev_activ_dims), prev_activ_data,
3112            DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
3113            bias_data, DimsToShape(prev_state_dims), prev_state_data,
3114            DimsToShape(output_state_dims), output_state_data,
3115            DimsToShape(output_activ_dims), output_activ_data,
3116            DimsToShape(concat_temp_dims), concat_temp_data,
3117            DimsToShape(activ_temp_dims), activ_temp_data);
3118 }
3119 
3120 template <int StateIntegerBits>
3121 inline void LstmCell(
3122     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3123     const uint8* input_data_uint8,
3124     const RuntimeShape& unextended_prev_activ_shape,
3125     const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
3126     const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
3127     const int32* bias_data_int32,
3128     const RuntimeShape& unextended_prev_state_shape,
3129     const int16* prev_state_data_int16,
3130     const RuntimeShape& unextended_output_state_shape,
3131     int16* output_state_data_int16,
3132     const RuntimeShape& unextended_output_activ_shape,
3133     uint8* output_activ_data_uint8,
3134     const RuntimeShape& unextended_concat_temp_shape,
3135     uint8* concat_temp_data_uint8,
3136     const RuntimeShape& unextended_activ_temp_shape,
3137     int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
3138   ruy::profiler::ScopeLabel label(
3139       "LstmCell/quantized (8bit external, 16bit internal)");
3140   int32 weights_zero_point = params.weights_zero_point;
3141   int32 accum_multiplier = params.accum_multiplier;
3142   int accum_shift = params.accum_shift;
3143   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3144   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3145   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3146   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3147   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3148   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3149   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3150   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3151   const RuntimeShape input_shape =
3152       RuntimeShape::ExtendedShape(4, unextended_input_shape);
3153   const RuntimeShape prev_activ_shape =
3154       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3155   const RuntimeShape bias_shape =
3156       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3157   const RuntimeShape prev_state_shape =
3158       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3159   const RuntimeShape output_state_shape =
3160       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3161   const RuntimeShape output_activ_shape =
3162       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3163   const RuntimeShape concat_temp_shape =
3164       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3165   const RuntimeShape activ_temp_shape =
3166       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3167   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3168 
3169   // Gather dimension information and perform consistency checks.
3170   const int weights_dim_count = weights_shape.DimensionsCount();
3171   const int outer_size = MatchingFlatSizeSkipDim(
3172       input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
3173       output_activ_shape);
3174   const int input_depth = input_shape.Dims(3);
3175   const int prev_activ_depth = prev_activ_shape.Dims(3);
3176   const int total_input_depth = prev_activ_depth + input_depth;
3177   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3178                    total_input_depth);
3179   const int intern_activ_depth =
3180       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3181   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3182                    intern_activ_depth * total_input_depth);
3183   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3184   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3185   const int output_depth =
3186       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3187                   3, output_activ_shape, 3);
3188   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3189   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
3190   const int fc_output_depth =
3191       MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
3192   const int fc_accum_depth = total_input_depth;
3193   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
3194 
3195   // Depth-concatenate prev_activ and input data together.
3196   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
3197                                               prev_activ_data_uint8};
3198   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
3199                                                        &prev_activ_shape};
3200   tflite::ConcatenationParams concat_params;
3201   concat_params.axis = 3;
3202   concat_params.inputs_count = 2;
3203   Concatenation(concat_params, concat_input_arrays_shapes,
3204                 concat_input_arrays_data, concat_temp_shape,
3205                 concat_temp_data_uint8);
3206 
3207   // Implementation of the fully connected node inside the LSTM cell.
3208   // The operands are 8-bit integers, the accumulators are internally 32-bit
3209   // integers, and the output is 16-bit fixed-point with 3 integer bits, so
3210   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
3211   // is explained in the function comment above.
3212   bool gemm_already_performed = false;
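       // With a single batch, the fully connected node is a matrix*vector
       // product (GEMV), for which a dedicated NEON routine is used below.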
3213 #ifdef GEMMLOWP_NEON
3214   if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
3215     GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
3216                     weights_data_uint8, weights_zero_point, bias_shape,
3217                     bias_data_int32, accum_multiplier, accum_shift,
3218                     activ_temp_shape, activ_temp_data_int16);
3219     gemm_already_performed = true;
3220   }
3221 #endif
3222   if (!gemm_already_performed) {
3223     gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
3224         weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
3225     gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
3226         concat_temp_data_uint8, fc_accum_depth, fc_batches);
3227     gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
3228         activ_temp_data_int16, fc_output_depth, fc_batches);
3229     typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
3230         ColVectorMap;
3231     ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
3232     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
3233     bias_addition_stage.bias_vector = bias_vector;
3234     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
3235     scale_stage.result_offset_after_shift = 0;
3236     scale_stage.result_fixedpoint_multiplier = accum_multiplier;
3237     scale_stage.result_exponent = accum_shift;
3238     gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
3239     auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
3240                                            saturating_cast_int16_stage);
3241     gemmlowp::GemmWithOutputPipeline<
3242         uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
3243         gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
3244         -weights_zero_point, -128, output_pipeline);
3245   }
3246 
3247   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3248   // and muls, all done in 16-bit fixed-point.
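       // activ_temp holds, per batch element, four contiguous blocks of
       // output_depth values: the input, input modulation, forget and output
       // gate pre-activations, in that order.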
3249   const int16* input_gate_input_ptr = activ_temp_data_int16;
3250   const int16* input_modulation_gate_input_ptr =
3251       activ_temp_data_int16 + output_depth;
3252   const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3253   const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3254   const int16* prev_state_ptr = prev_state_data_int16;
3255   int16* output_state_data_ptr = output_state_data_int16;
3256   uint8* output_activ_data_ptr = output_activ_data_uint8;
3257 
3258   for (int b = 0; b < outer_size; ++b) {
3259     int c = 0;
3260 #ifdef GEMMLOWP_NEON
3261     for (; c <= output_depth - 8; c += 8) {
3262       // Define the fixed-point data types that we will use here. All use
3263       // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
3264       // They only differ by the number of integral vs. fractional bits,
3265       // determining the range of values that they can represent.
3266       //
3267       // F0 uses 0 integer bits, range [-1, 1].
3268       // This is the return type of math functions such as tanh, logistic,
3269       // whose range is in [-1, 1].
3270       using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3271       // F3 uses 3 integer bits, range [-8, 8].
3272       // This is the range of the previous fully-connected node's output,
3273       // which is our input here.
3274       using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3275       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3276       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3277       // number of integer bits is currently dictated by the model. See comment
3278       // on the StateIntegerBits template parameter above.
3279       using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3280       // Implementation of input gate, using fixed-point logistic function.
3281       F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3282       input_gate_input_ptr += 8;
3283       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3284       // Implementation of input modulation gate, using fixed-point tanh
3285       // function.
3286       F3 input_modulation_gate_input =
3287           F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3288       input_modulation_gate_input_ptr += 8;
3289       F0 input_modulation_gate_output =
3290           gemmlowp::tanh(input_modulation_gate_input);
3291       // Implementation of forget gate, using fixed-point logistic function.
3292       F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3293       forget_gate_input_ptr += 8;
3294       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3295       // Implementation of output gate, using fixed-point logistic function.
3296       F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3297       output_gate_input_ptr += 8;
3298       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3299       // Implementation of internal multiplication nodes, still in fixed-point.
3300       F0 input_times_input_modulation =
3301           input_gate_output * input_modulation_gate_output;
3302       FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3303       prev_state_ptr += 8;
3304       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3305       // Implementation of internal addition node, saturating.
3306       FS new_state = gemmlowp::SaturatingAdd(
3307           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3308           prev_state_times_forget_state);
3309       // Implementation of last internal Tanh node, still in fixed-point.
3310       // A fixed-point Tanh implementation is specialized for a given number
3311       // of integer bits, and each specialization can have a substantial
3312       // code size. We already used a Tanh on a 3-integer-bit input above,
3313       // and per the table in the function comment above there is no
3314       // significant accuracy loss from clamping to [-8, +8] in a
3315       // 3-integer-bit representation, so we just do that. This helps people
3316       // porting this code to targets where code footprint must be minimized.
3317       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3318       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3319       // Store the new internal state back to memory, as 16-bit integers.
3320       // Note: here we store the original value with StateIntegerBits, not
3321       // the rescaled 3-integer-bits value fed to tanh.
3322       vst1q_s16(output_state_data_ptr, new_state.raw());
3323       output_state_data_ptr += 8;
3324       // Down-scale the output activations to 8-bit integers, saturating,
3325       // and store back to memory.
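           // The F0 raw value is Q0.15 in [-1, 1); a rounding right shift by 8
           // rescales it to [-128, 128), i.e. uint8 with zero point 128.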
3326       int16x8_t rescaled_output_activ =
3327           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3328       int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3329       uint8x8_t uint8_output_activ =
3330           vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3331       vst1_u8(output_activ_data_ptr, uint8_output_activ);
3332       output_activ_data_ptr += 8;
3333     }
3334 #endif
3335     for (; c < output_depth; ++c) {
3336       // Define the fixed-point data types that we will use here. All use
3337       // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
3338       // They only differ by the number of integral vs. fractional bits,
3339       // determining the range of values that they can represent.
3340       //
3341       // F0 uses 0 integer bits, range [-1, 1].
3342       // This is the return type of math functions such as tanh, logistic,
3343       // whose range is in [-1, 1].
3344       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3345       // F3 uses 3 integer bits, range [-8, 8].
3346       // This is the range of the previous fully-connected node's output,
3347       // which is our input here.
3348       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3349       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3350       // 2^StateIntegerBits]. It's used to represent the internal state, whose
3351       // number of integer bits is currently dictated by the model. See comment
3352       // on the StateIntegerBits template parameter above.
3353       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3354       // Implementation of input gate, using fixed-point logistic function.
3355       F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3356       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3357       // Implementation of input modulation gate, using fixed-point tanh
3358       // function.
3359       F3 input_modulation_gate_input =
3360           F3::FromRaw(*input_modulation_gate_input_ptr++);
3361       F0 input_modulation_gate_output =
3362           gemmlowp::tanh(input_modulation_gate_input);
3363       // Implementation of forget gate, using fixed-point logistic function.
3364       F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3365       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3366       // Implementation of output gate, using fixed-point logistic function.
3367       F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3368       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3369       // Implementation of internal multiplication nodes, still in fixed-point.
3370       F0 input_times_input_modulation =
3371           input_gate_output * input_modulation_gate_output;
3372       FS prev_state = FS::FromRaw(*prev_state_ptr++);
3373       FS prev_state_times_forget_state = forget_gate_output * prev_state;
3374       // Implementation of internal addition node, saturating.
3375       FS new_state = gemmlowp::SaturatingAdd(
3376           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3377           prev_state_times_forget_state);
3378       // Implementation of last internal Tanh node, still in fixed-point.
3379       // A fixed-point Tanh implementation is specialized for a given number
3380       // of integer bits, and each specialization can have a substantial
3381       // code size. We already used a Tanh on a 3-integer-bit input above,
3382       // and per the table in the function comment above there is no
3383       // significant accuracy loss from clamping to [-8, +8] in a
3384       // 3-integer-bit representation, so we just do that. This helps people
3385       // porting this code to targets where code footprint must be minimized.
3386       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3387       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3388       // Store the new internal state back to memory, as 16-bit integers.
3389       // Note: here we store the original value with StateIntegerBits, not
3390       // the rescaled 3-integer-bits value fed to tanh.
3391       *output_state_data_ptr++ = new_state.raw();
3392       // Down-scale the output activations to 8-bit integers, saturating,
3393       // and store back to memory.
3394       int16 rescaled_output_activ =
3395           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3396       int16 clamped_output_activ =
3397           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3398       *output_activ_data_ptr++ = 128 + clamped_output_activ;
3399     }
3400     input_gate_input_ptr += 3 * output_depth;
3401     input_modulation_gate_input_ptr += 3 * output_depth;
3402     forget_gate_input_ptr += 3 * output_depth;
3403     output_gate_input_ptr += 3 * output_depth;
3404   }
3405 }
3406 
3407 template <int StateIntegerBits>
3408 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
3409               const uint8* prev_activ_data_uint8,
3410               const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
3411               const Dims<4>& weights_dims, const int32* bias_data_int32,
3412               const Dims<4>& bias_dims, const int16* prev_state_data_int16,
3413               const Dims<4>& prev_state_dims, int16* output_state_data_int16,
3414               const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
3415               const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
3416               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
3417               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
3418               int32 accum_multiplier, int accum_shift,
3419               gemmlowp::GemmContext* gemmlowp_context) {
3420   tflite::LstmCellParams op_params;
3421   op_params.weights_zero_point = weights_zero_point;
3422   op_params.accum_multiplier = accum_multiplier;
3423   op_params.accum_shift = accum_shift;
3424 
3425   LstmCell<StateIntegerBits>(
3426       op_params, DimsToShape(input_dims), input_data_uint8,
3427       DimsToShape(prev_activ_dims), prev_activ_data_uint8,
3428       DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
3429       bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
3430       DimsToShape(output_state_dims), output_state_data_int16,
3431       DimsToShape(output_activ_dims), output_activ_data_uint8,
3432       DimsToShape(concat_temp_dims), concat_temp_data_uint8,
3433       DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
3434 }
3435 
3436 template <typename T>
3437 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
3438                   const T* input2_data, const Dims<4>& input2_dims,
3439                   T output_activation_min, T output_activation_max,
3440                   T* output_data, const Dims<4>& output_dims) {
3441   tflite::ArithmeticParams op_params;
3442   SetActivationParams(output_activation_min, output_activation_max, &op_params);
3443 
3444   BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
3445                    DimsToShape(input2_dims), input2_data,
3446                    DimsToShape(output_dims), output_data);
3447 }
3448 
3449 template <FusedActivationFunctionType Ac>
3450 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
3451                      float* output_data, const RuntimeShape& output_shape) {
3452   static_assert(Ac == FusedActivationFunctionType::kNone, "");
3453   tflite::L2NormalizationParams op_params;
3454   // No params need to be set for float, but reserved in signature for future
3455   // activations.
3456 
3457   L2Normalization(op_params, input_shape, input_data, output_shape,
3458                   output_data);
3459 }
3460 
3461 inline void L2Normalization(const uint8* input_data,
3462                             const RuntimeShape& input_shape,
3463                             int32 input_zero_point, uint8* output_data,
3464                             const RuntimeShape& output_shape) {
3465   tflite::L2NormalizationParams op_params;
3466   op_params.input_zero_point = input_zero_point;
3467 
3468   L2Normalization(op_params, input_shape, input_data, output_shape,
3469                   output_data);
3470 }
3471 
3472 template <FusedActivationFunctionType Ac>
3473 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
3474                      float* output_data, const Dims<4>& output_dims) {
3475   L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
3476                       DimsToShape(output_dims));
3477 }
3478 
3479 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
3480                             int32 input_zero_point, uint8* output_data,
3481                             const Dims<4>& output_dims) {
3482   L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
3483                   output_data, DimsToShape(output_dims));
3484 }
3485 
3486 inline void Relu(const float* input_data, const Dims<4>& input_dims,
3487                  float* output_data, const Dims<4>& output_dims) {
3488   Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3489        output_data);
3490 }
3491 
3492 // legacy, for compatibility with old checked-in code
3493 template <FusedActivationFunctionType Ac>
3494 void Add(const float* input1_data, const Dims<4>& input1_dims,
3495          const float* input2_data, const Dims<4>& input2_dims,
3496          float* output_data, const Dims<4>& output_dims) {
3497   float output_activation_min, output_activation_max;
3498   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3499 
3500   tflite::ArithmeticParams op_params;
3501   op_params.float_activation_min = output_activation_min;
3502   op_params.float_activation_max = output_activation_max;
3503   Add(op_params, DimsToShape(input1_dims), input1_data,
3504       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3505       output_data);
3506 }
3507 
3508 template <FusedActivationFunctionType Ac>
3509 inline void Add(int left_shift, const uint8* input1_data,
3510                 const Dims<4>& input1_dims, int32 input1_offset,
3511                 int32 input1_multiplier, int input1_shift,
3512                 const uint8* input2_data, const Dims<4>& input2_dims,
3513                 int32 input2_offset, int32 input2_multiplier, int input2_shift,
3514                 int32 output_offset, int32 output_multiplier, int output_shift,
3515                 int32 output_activation_min, int32 output_activation_max,
3516                 uint8* output_data, const Dims<4>& output_dims) {
3517   constexpr int kReverseShift = -1;
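       // Legacy callers pass shifts as right-shift amounts, while
       // ArithmeticParams stores exponents (negative = right shift), hence
       // the kReverseShift factor applied below.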
3518   static_assert(Ac == FusedActivationFunctionType::kNone ||
3519                     Ac == FusedActivationFunctionType::kRelu ||
3520                     Ac == FusedActivationFunctionType::kRelu6 ||
3521                     Ac == FusedActivationFunctionType::kRelu1,
3522                 "");
3523   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3524   if (Ac == FusedActivationFunctionType::kNone) {
3525     TFLITE_DCHECK_EQ(output_activation_min, 0);
3526     TFLITE_DCHECK_EQ(output_activation_max, 255);
3527   }
3528 
3529   tflite::ArithmeticParams op_params;
3530   op_params.left_shift = left_shift;
3531   op_params.input1_offset = input1_offset;
3532   op_params.input1_multiplier = input1_multiplier;
3533   op_params.input1_shift = kReverseShift * input1_shift;
3534   op_params.input2_offset = input2_offset;
3535   op_params.input2_multiplier = input2_multiplier;
3536   op_params.input2_shift = kReverseShift * input2_shift;
3537   op_params.output_offset = output_offset;
3538   op_params.output_multiplier = output_multiplier;
3539   op_params.output_shift = kReverseShift * output_shift;
3540   op_params.quantized_activation_min = output_activation_min;
3541   op_params.quantized_activation_max = output_activation_max;
3542   Add(op_params, DimsToShape(input1_dims), input1_data,
3543       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3544       output_data);
3545 }
3546 
3547 template <FusedActivationFunctionType Ac>
3548 void Add(const int32* input1_data, const Dims<4>& input1_dims,
3549          const int32* input2_data, const Dims<4>& input2_dims,
3550          int32* output_data, const Dims<4>& output_dims) {
3551   ruy::profiler::ScopeLabel label("Add/int32");
3552   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
3553 
3554   tflite::ArithmeticParams op_params;
3555   op_params.quantized_activation_min = std::numeric_limits<int32>::min();
3556   op_params.quantized_activation_max = std::numeric_limits<int32>::max();
3557   Add(op_params, DimsToShape(input1_dims), input1_data,
3558       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3559       output_data);
3560 }
3561 
3562 template <typename T>
3563 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3564                   const T* input2_data, const Dims<4>& input2_dims,
3565                   T output_activation_min, T output_activation_max,
3566                   T* output_data, const Dims<4>& output_dims) {
3567   tflite::ArithmeticParams op_params;
3568   op_params.float_activation_min = output_activation_min;
3569   op_params.float_activation_max = output_activation_max;
3570   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3571                      DimsToShape(input2_dims), input2_data,
3572                      DimsToShape(output_dims), output_data);
3573 }
3574 
3575 template <FusedActivationFunctionType Ac>
3576 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
3577                          const Dims<4>& input1_dims, int32 input1_offset,
3578                          int32 input1_multiplier, int input1_shift,
3579                          const uint8* input2_data, const Dims<4>& input2_dims,
3580                          int32 input2_offset, int32 input2_multiplier,
3581                          int input2_shift, int32 output_offset,
3582                          int32 output_multiplier, int output_shift,
3583                          int32 output_activation_min,
3584                          int32 output_activation_max, uint8* output_data,
3585                          const Dims<4>& output_dims) {
3586   constexpr int kReverseShift = -1;
3587   static_assert(Ac == FusedActivationFunctionType::kNone ||
3588                     Ac == FusedActivationFunctionType::kRelu ||
3589                     Ac == FusedActivationFunctionType::kRelu6 ||
3590                     Ac == FusedActivationFunctionType::kRelu1,
3591                 "");
3592   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3593   if (Ac == FusedActivationFunctionType::kNone) {
3594     TFLITE_DCHECK_EQ(output_activation_min, 0);
3595     TFLITE_DCHECK_EQ(output_activation_max, 255);
3596   }
3597 
3598   tflite::ArithmeticParams op_params;
3599   op_params.left_shift = left_shift;
3600   op_params.input1_offset = input1_offset;
3601   op_params.input1_multiplier = input1_multiplier;
3602   op_params.input1_shift = kReverseShift * input1_shift;
3603   op_params.input2_offset = input2_offset;
3604   op_params.input2_multiplier = input2_multiplier;
3605   op_params.input2_shift = kReverseShift * input2_shift;
3606   op_params.output_offset = output_offset;
3607   op_params.output_multiplier = output_multiplier;
3608   op_params.output_shift = kReverseShift * output_shift;
3609   op_params.quantized_activation_min = output_activation_min;
3610   op_params.quantized_activation_max = output_activation_max;
3611   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3612                      DimsToShape(input2_dims), input2_data,
3613                      DimsToShape(output_dims), output_data);
3614 }
3615 
3616 template <FusedActivationFunctionType Ac>
3617 inline void BroadcastAddFivefold(
3618     int y0, int y1, int y2, int y3, int y4, int left_shift,
3619     const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
3620     int32 input1_multiplier, int input1_shift, const uint8* input2_data,
3621     const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
3622     int input2_shift, int32 output_offset, int32 output_multiplier,
3623     int output_shift, int32 output_activation_min, int32 output_activation_max,
3624     uint8* output_data, const Dims<4>& output_dims) {
3625   constexpr int kReverseShift = -1;
3626   static_assert(Ac == FusedActivationFunctionType::kNone ||
3627                     Ac == FusedActivationFunctionType::kRelu ||
3628                     Ac == FusedActivationFunctionType::kRelu6 ||
3629                     Ac == FusedActivationFunctionType::kRelu1,
3630                 "");
3631   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3632   if (Ac == FusedActivationFunctionType::kNone) {
3633     TFLITE_DCHECK_EQ(output_activation_min, 0);
3634     TFLITE_DCHECK_EQ(output_activation_max, 255);
3635   }
3636   tflite::ArithmeticParams op_params;
3637   op_params.broadcast_category =
3638       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
3639   op_params.left_shift = left_shift;
3640   op_params.input1_offset = input1_offset;
3641   op_params.input1_multiplier = input1_multiplier;
3642   op_params.input1_shift = kReverseShift * input1_shift;
3643   op_params.input2_offset = input2_offset;
3644   op_params.input2_multiplier = input2_multiplier;
3645   op_params.input2_shift = kReverseShift * input2_shift;
3646   op_params.output_offset = output_offset;
3647   op_params.output_multiplier = output_multiplier;
3648   op_params.output_shift = kReverseShift * output_shift;
3649   op_params.quantized_activation_min = output_activation_min;
3650   op_params.quantized_activation_max = output_activation_max;
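       // Note the reversed indexing: broadcast_shape[4] receives y0 and
       // broadcast_shape[0] receives y4.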
3651   op_params.broadcast_shape[4] = y0;
3652   op_params.broadcast_shape[3] = y1;
3653   op_params.broadcast_shape[2] = y2;
3654   op_params.broadcast_shape[1] = y3;
3655   op_params.broadcast_shape[0] = y4;
3656   BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
3657                        DimsToShape(input2_dims), input2_data,
3658                        DimsToShape(output_dims), output_data);
3659 }
3660 
3661 // legacy, for compatibility with old checked-in code
3662 template <FusedActivationFunctionType Ac, typename T>
3663 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3664                   const T* input2_data, const Dims<4>& input2_dims,
3665                   T* output_data, const Dims<4>& output_dims) {
3666   T output_activation_min, output_activation_max;
3667   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3668 
3669   BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
3670                output_activation_min, output_activation_max, output_data,
3671                output_dims);
3672 }
3673 
3674 template <FusedActivationFunctionType Ac>
3675 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
3676                 int input1_shift, const int16* input2_data,
3677                 const Dims<4>& input2_dims, int input2_shift,
3678                 int16 output_activation_min, int16 output_activation_max,
3679                 int16* output_data, const Dims<4>& output_dims) {
3680   constexpr int kReverseShift = -1;
3681   static_assert(Ac == FusedActivationFunctionType::kNone ||
3682                     Ac == FusedActivationFunctionType::kRelu ||
3683                     Ac == FusedActivationFunctionType::kRelu6 ||
3684                     Ac == FusedActivationFunctionType::kRelu1,
3685                 "");
3686   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3687   if (Ac == FusedActivationFunctionType::kNone) {
3688     TFLITE_DCHECK_EQ(output_activation_min, -32768);
3689     TFLITE_DCHECK_EQ(output_activation_max, 32767);
3690   }
3691 
3692   tflite::ArithmeticParams op_params;
3693   op_params.input1_shift = kReverseShift * input1_shift;
3694   op_params.input2_shift = kReverseShift * input2_shift;
3695   op_params.quantized_activation_min = output_activation_min;
3696   op_params.quantized_activation_max = output_activation_max;
3697   Add(op_params, DimsToShape(input1_dims), input1_data,
3698       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3699       output_data);
3700 }
3701 
3702 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
3703                 const float* input2_data, const Dims<4>& input2_dims,
3704                 float* output_data, const Dims<4>& output_dims) {
3705   float output_activation_min, output_activation_max;
3706   GetActivationMinMax(FusedActivationFunctionType::kNone,
3707                       &output_activation_min, &output_activation_max);
3708   tflite::ArithmeticParams op_params;
3709   op_params.float_activation_min = output_activation_min;
3710   op_params.float_activation_max = output_activation_max;
3711   Sub(op_params, DimsToShape(input1_dims), input1_data,
3712       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3713       output_data);
3714 }
3715 
3716 template <typename T>
3717 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
3718          const Dims<4>& input2_dims, T* output_data,
3719          const Dims<4>& output_dims) {
3720   T output_activation_min, output_activation_max;
3721   GetActivationMinMax(FusedActivationFunctionType::kNone,
3722                       &output_activation_min, &output_activation_max);
3723   tflite::ArithmeticParams op_params;
3724   op_params.quantized_activation_min = output_activation_min;
3725   op_params.quantized_activation_max = output_activation_max;
3726   Sub(op_params, DimsToShape(input1_dims), input1_data,
3727       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3728       output_data);
3729 }
3730 
3731 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
3732                          int32 input1_offset, const uint8* input2_data,
3733                          const Dims<4>& input2_dims, int32 input2_offset,
3734                          int32 output_offset, int32 output_multiplier,
3735                          int output_shift, int32 output_activation_min,
3736                          int32 output_activation_max, uint8* output_data,
3737                          const Dims<4>& output_dims) {
3738   tflite::ArithmeticParams op_params;
3739   SetActivationParams(output_activation_min, output_activation_max, &op_params);
3740   op_params.input1_offset = input1_offset;
3741   op_params.input2_offset = input2_offset;
3742   op_params.output_offset = output_offset;
3743   op_params.output_multiplier = output_multiplier;
3744   op_params.output_shift = kReverseShift * output_shift;
3745 
3746   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3747                      DimsToShape(input2_dims), input2_data,
3748                      DimsToShape(output_dims), output_data);
3749 }
3750 
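// A hedged note on the quantized-multiplier arguments above (illustrative
// numbers, not from the original): in the usual TFLite scheme the real
// rescale factor input1_scale * input2_scale / output_scale is decomposed as
// M * 2^(-n), with M a Q0.31 fixed-point multiplier and n a right-shift
// count. The legacy API passes n as a positive output_shift, which
// kReverseShift negates into the newer signed convention.
//
//   // real rescale 0.00256 ~= 0.65536 * 2^-8
//   int32 output_multiplier = 1407374884;  // round(0.65536 * 2^31)
//   int output_shift = 8;                  // legacy right-shift count
//   // op_params.output_shift becomes -8 inside BroadcastMul above.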
3751 // legacy, for compatibility with old checked-in code
3752 template <FusedActivationFunctionType Ac>
3753 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
3754                          int32 input1_offset, const uint8* input2_data,
3755                          const Dims<4>& input2_dims, int32 input2_offset,
3756                          int32 output_offset, int32 output_multiplier,
3757                          int output_shift, int32 output_activation_min,
3758                          int32 output_activation_max, uint8* output_data,
3759                          const Dims<4>& output_dims) {
3760   BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
3761                input2_dims, input2_offset, output_offset, output_multiplier,
3762                output_shift, output_activation_min, output_activation_max,
3763                output_data, output_dims);
3764 }
3765 
3766 inline bool AveragePool(const float* input_data, const Dims<4>& input_dims,
3767                         int stride_width, int stride_height, int pad_width,
3768                         int pad_height, int kwidth, int kheight,
3769                         float output_activation_min,
3770                         float output_activation_max, float* output_data,
3771                         const Dims<4>& output_dims) {
3772   tflite::PoolParams params;
3773   params.stride_height = stride_height;
3774   params.stride_width = stride_width;
3775   params.filter_height = kheight;
3776   params.filter_width = kwidth;
3777   params.padding_values.height = pad_height;
3778   params.padding_values.width = pad_width;
3779   params.float_activation_min = output_activation_min;
3780   params.float_activation_max = output_activation_max;
3781   return AveragePool(params, DimsToShape(input_dims), input_data,
3782                      DimsToShape(output_dims), output_data);
3783 }
3784 
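// A minimal usage sketch (hypothetical shapes and bounds) for the float
// AveragePool wrapper above, which packs its positional arguments into
// tflite::PoolParams and forwards to the shape-based overload:
//
//   // 2x2 window, stride 2, no padding, no activation clamp:
//   bool ok = AveragePool(
//       input, input_dims, /*stride_width=*/2, /*stride_height=*/2,
//       /*pad_width=*/0, /*pad_height=*/0, /*kwidth=*/2, /*kheight=*/2,
//       /*output_activation_min=*/std::numeric_limits<float>::lowest(),
//       /*output_activation_max=*/std::numeric_limits<float>::max(),
//       output, output_dims);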
3785 // legacy, for compatibility with old checked-in code
3786 template <FusedActivationFunctionType Ac>
3787 bool AveragePool(const float* input_data, const Dims<4>& input_dims,
3788                  int stride_width, int stride_height, int pad_width,
3789                  int pad_height, int kwidth, int kheight, float* output_data,
3790                  const Dims<4>& output_dims) {
3791   float output_activation_min, output_activation_max;
3792   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3793 
3794   return AveragePool(input_data, input_dims, stride_width, stride_height,
3795                      pad_width, pad_height, kwidth, kheight,
3796                      output_activation_min, output_activation_max, output_data,
3797                      output_dims);
3798 }
3799 
3800 // legacy, for compatibility with old checked-in code
3801 template <FusedActivationFunctionType Ac>
3802 bool AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
3803                  int pad_width, int pad_height, int filter_width,
3804                  int filter_height, float* output_data,
3805                  const Dims<4>& output_dims) {
3806   return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
3807                          pad_height, filter_width, filter_height, output_data,
3808                          output_dims);
3809 }
3810 
3811 inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
3812                         int stride_width, int stride_height, int pad_width,
3813                         int pad_height, int filter_width, int filter_height,
3814                         int32 output_activation_min,
3815                         int32 output_activation_max, uint8* output_data,
3816                         const Dims<4>& output_dims) {
3817   tflite::PoolParams params;
3818   params.stride_height = stride_height;
3819   params.stride_width = stride_width;
3820   params.filter_height = filter_height;
3821   params.filter_width = filter_width;
3822   params.padding_values.height = pad_height;
3823   params.padding_values.width = pad_width;
3824   params.quantized_activation_min = output_activation_min;
3825   params.quantized_activation_max = output_activation_max;
3826   return AveragePool(params, DimsToShape(input_dims), input_data,
3827                      DimsToShape(output_dims), output_data);
3828 }
3829 
3830 // legacy, for compatibility with old checked-in code
3831 template <FusedActivationFunctionType Ac>
3832 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
3833                  int stride_width, int stride_height, int pad_width,
3834                  int pad_height, int filter_width, int filter_height,
3835                  int32 output_activation_min, int32 output_activation_max,
3836                  uint8* output_data, const Dims<4>& output_dims) {
3837   static_assert(Ac == FusedActivationFunctionType::kNone ||
3838                     Ac == FusedActivationFunctionType::kRelu ||
3839                     Ac == FusedActivationFunctionType::kRelu6 ||
3840                     Ac == FusedActivationFunctionType::kRelu1,
3841                 "");
3842   if (Ac == FusedActivationFunctionType::kNone) {
3843     TFLITE_DCHECK_EQ(output_activation_min, 0);
3844     TFLITE_DCHECK_EQ(output_activation_max, 255);
3845   }
3846   return AveragePool(input_data, input_dims, stride_width, stride_height,
3847                      pad_width, pad_height, filter_width, filter_height,
3848                      output_activation_min, output_activation_max, output_data,
3849                      output_dims);
3850 }
3851 
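// As the DCHECKs above encode, a kNone activation on the uint8 path must be
// given the full quantized range [0, 255]; a tighter range expresses a fused
// clamp in the output's quantized domain. For example (hypothetical scale),
// with output scale 1/32 and zero point 0, a fused ReLU6 would pass
// output_activation_min = 0 and output_activation_max = 6 * 32 = 192.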
3852 // legacy, for compatibility with old checked-in code
3853 template <FusedActivationFunctionType Ac>
3854 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
3855                  int pad_width, int pad_height, int filter_width,
3856                  int filter_height, int32 output_activation_min,
3857                  int32 output_activation_max, uint8* output_data,
3858                  const Dims<4>& output_dims) {
3859   return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
3860                          pad_height, filter_width, filter_height,
3861                          output_activation_min, output_activation_max,
3862                          output_data, output_dims);
3863 }
3864 
3865 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
3866                     int stride_width, int stride_height, int pad_width,
3867                     int pad_height, int kwidth, int kheight,
3868                     float output_activation_min, float output_activation_max,
3869                     float* output_data, const Dims<4>& output_dims) {
3870   tflite::PoolParams params;
3871   params.stride_height = stride_height;
3872   params.stride_width = stride_width;
3873   params.filter_height = kheight;
3874   params.filter_width = kwidth;
3875   params.padding_values.height = pad_height;
3876   params.padding_values.width = pad_width;
3877   params.float_activation_min = output_activation_min;
3878   params.float_activation_max = output_activation_max;
3879   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3880           output_data);
3881 }
3882 
3883 // legacy, for compatibility with old checked-in code
3884 template <FusedActivationFunctionType Ac>
3885 void MaxPool(const float* input_data, const Dims<4>& input_dims,
3886              int stride_width, int stride_height, int pad_width, int pad_height,
3887              int kwidth, int kheight, float* output_data,
3888              const Dims<4>& output_dims) {
3889   float output_activation_min, output_activation_max;
3890   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3891   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
3892           pad_height, kwidth, kheight, output_activation_min,
3893           output_activation_max, output_data, output_dims);
3894 }
3895 
3896 // legacy, for compatibility with old checked-in code
3897 template <FusedActivationFunctionType Ac>
3898 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
3899              int pad_width, int pad_height, int filter_width, int filter_height,
3900              float* output_data, const Dims<4>& output_dims) {
3901   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3902               filter_width, filter_height, output_data, output_dims);
3903 }
3904 
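// Illustrative only: the single-stride legacy overload above simply fans out
// to the two-stride form with stride_width == stride_height, e.g.
//
//   MaxPool<FusedActivationFunctionType::kRelu>(
//       input, input_dims, /*stride=*/2, /*pad_width=*/0, /*pad_height=*/0,
//       /*filter_width=*/2, /*filter_height=*/2, output, output_dims);
//   // behaves like MaxPool<kRelu> with stride_width = stride_height = 2.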
3905 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
3906                     int stride_width, int stride_height, int pad_width,
3907                     int pad_height, int filter_width, int filter_height,
3908                     int32 output_activation_min, int32 output_activation_max,
3909                     uint8* output_data, const Dims<4>& output_dims) {
3910   PoolParams params;
3911   params.stride_height = stride_height;
3912   params.stride_width = stride_width;
3913   params.filter_height = filter_height;
3914   params.filter_width = filter_width;
3915   params.padding_values.height = pad_height;
3916   params.padding_values.width = pad_width;
3917   params.quantized_activation_min = output_activation_min;
3918   params.quantized_activation_max = output_activation_max;
3919   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3920           output_data);
3921 }
3922 
3923 // legacy, for compatibility with old checked-in code
3924 template <FusedActivationFunctionType Ac>
3925 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
3926              int stride_width, int stride_height, int pad_width, int pad_height,
3927              int filter_width, int filter_height, int32 output_activation_min,
3928              int32 output_activation_max, uint8* output_data,
3929              const Dims<4>& output_dims) {
3930   static_assert(Ac == FusedActivationFunctionType::kNone ||
3931                     Ac == FusedActivationFunctionType::kRelu ||
3932                     Ac == FusedActivationFunctionType::kRelu6 ||
3933                     Ac == FusedActivationFunctionType::kRelu1,
3934                 "");
3935   if (Ac == FusedActivationFunctionType::kNone) {
3936     TFLITE_DCHECK_EQ(output_activation_min, 0);
3937     TFLITE_DCHECK_EQ(output_activation_max, 255);
3938   }
3939   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
3940           pad_height, filter_width, filter_height, output_activation_min,
3941           output_activation_max, output_data, output_dims);
3942 }
3943 
3944 // legacy, for compatibility with old checked-in code
3945 template <FusedActivationFunctionType Ac>
3946 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
3947              int pad_width, int pad_height, int filter_width, int filter_height,
3948              int32 output_activation_min, int32 output_activation_max,
3949              uint8* output_data, const Dims<4>& output_dims) {
3950   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3951               filter_width, filter_height, output_activation_min,
3952               output_activation_max, output_data, output_dims);
3953 }
3954 
3955 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
3956                    int stride_width, int stride_height, int pad_width,
3957                    int pad_height, int filter_width, int filter_height,
3958                    float output_activation_min, float output_activation_max,
3959                    float* output_data, const Dims<4>& output_dims) {
3960   PoolParams params;
3961   params.stride_height = stride_height;
3962   params.stride_width = stride_width;
3963   params.filter_height = filter_height;
3964   params.filter_width = filter_width;
3965   params.padding_values.height = pad_height;
3966   params.padding_values.width = pad_width;
3967   params.float_activation_min = output_activation_min;
3968   params.float_activation_max = output_activation_max;
3969   L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3970          output_data);
3971 }
3972 
3973 // legacy, for compatibility with old checked-in code
3974 template <FusedActivationFunctionType Ac>
3975 void L2Pool(const float* input_data, const Dims<4>& input_dims,
3976             int stride_width, int stride_height, int pad_width, int pad_height,
3977             int filter_width, int filter_height, float* output_data,
3978             const Dims<4>& output_dims) {
3979   float output_activation_min, output_activation_max;
3980   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3981   L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
3982          pad_height, filter_width, filter_height, output_activation_min,
3983          output_activation_max, output_data, output_dims);
3984 }
3985 
3986 // legacy, for compatibility with old checked-in code
3987 template <FusedActivationFunctionType Ac>
3988 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
3989             int pad_width, int pad_height, int filter_width, int filter_height,
3990             float* output_data, const Dims<4>& output_dims) {
3991   L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3992              filter_width, filter_height, output_data, output_dims);
3993 }
3994 
3995 inline void Softmax(const SoftmaxParams& params,
3996                     const RuntimeShape& input_shape, const uint8* input_data,
3997                     const RuntimeShape& output_shape, uint8* output_data) {
3998   const int32 input_beta_multiplier = params.input_multiplier;
3999   const int32 input_beta_left_shift = params.input_left_shift;
4000   const int diff_min = params.diff_min;
4001   // The representation chosen for the input to the exp() function is Q5.26.
4002   // We need to leave extra space since values that we skip might be as large as
4003   // -32 before multiplying by input_beta_multiplier, and therefore as large as
4004   // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
4005   // accumulation, but exp(-16) definitely is.
4006   static const int kScaledDiffIntegerBits = 5;
4007   static const int kAccumulationIntegerBits = 12;
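  // A worked reading of the representation choice: with
  // kScaledDiffIntegerBits = 5, a raw int32 x in FixedPointScaledDiff denotes
  // x * 2^-26 (Q5.26), covering roughly (-32, 32), so a clamped diff of -32
  // still fits before the multiply and about -16 after it. FixedPointAccum
  // (Q12.19) keeps 12 integer bits, so summing up to ~2^12 exponentials,
  // each at most 1.0, cannot overflow the accumulator.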
4008   using FixedPointScaledDiff =
4009       gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4010   using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4011   using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4012 
4013   ruy::profiler::ScopeLabel label("Softmax/8bit");
4014   const int trailing_dim = input_shape.DimensionsCount() - 1;
4015   const int outer_size =
4016       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4017   const int depth =
4018       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4019 
4020   for (int b = 0; b < outer_size; ++b) {
4021     const uint8* input_data_ptr = input_data + b * depth;
4022     uint8* output_data_ptr = output_data + b * depth;
4023 
4024     // Determine the largest entry in the current row
4025     uint8 max_in_row = 0;
4026     {
4027       int c = 0;
4028 #ifdef USE_NEON
4029       uint8x16_t max16_0 = vdupq_n_u8(0);
4030       uint8x16_t max16_1 = vdupq_n_u8(0);
4031       for (; c <= depth - 32; c += 32) {
4032         max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
4033         max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
4034       }
4035       uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
4036       if (c <= depth - 16) {
4037         max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
4038         c += 16;
4039       }
4040       uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
4041       if (c <= depth - 8) {
4042         max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
4043         c += 8;
4044       }
4045       uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
4046       uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
4047       uint8x8_t max1 = vpmax_u8(max2, max2);
4048       max_in_row = vget_lane_u8(max1, 0);
4049 #endif
4050       for (; c < depth; ++c) {
4051         max_in_row = std::max(max_in_row, input_data_ptr[c]);
4052       }
4053     }
4054 
4055 #ifdef USE_NEON
4056     using FixedPointAccumInt32x4 =
4057         gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
4058     using FixedPointScaledDiffInt32x4 =
4059         gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
4060     using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
4061     FixedPoint0Int32x4 input_beta_multiplier_f0 =
4062         FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
4063     int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
4064 #endif
4065 
4066     // Compute the sum of exponentials of the differences of entries in the
4067     // current row from the largest entry in the current row.
4068     FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4069     {
4070       int c = 0;
4071 #ifdef USE_NEON
4072       int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
4073       FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
4074       FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
4075       FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
4076       for (; c <= depth - 8; c += 8) {
4077         uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4078         int16x8_t input_diff_s16 =
4079             vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4080         int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4081         int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4082         int32x4_t mask_0 =
4083             gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
4084         int32x4_t mask_1 =
4085             gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
4086         FixedPointScaledDiffInt32x4 scaled_diff_0 =
4087             input_beta_multiplier_f0 *
4088             FixedPointScaledDiffInt32x4::FromRaw(
4089                 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4090         FixedPointScaledDiffInt32x4 scaled_diff_1 =
4091             input_beta_multiplier_f0 *
4092             FixedPointScaledDiffInt32x4::FromRaw(
4093                 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4094         FixedPointAccumInt32x4 exps_0 =
4095             gemmlowp::Rescale<kAccumulationIntegerBits>(
4096                 exp_on_negative_values(scaled_diff_0));
4097         FixedPointAccumInt32x4 exps_1 =
4098             gemmlowp::Rescale<kAccumulationIntegerBits>(
4099                 exp_on_negative_values(scaled_diff_1));
4100         FixedPointAccumInt32x4 masked_exps_0 =
4101             SelectUsingMask(mask_0, exps_0, zeros);
4102         FixedPointAccumInt32x4 masked_exps_1 =
4103             SelectUsingMask(mask_1, exps_1, zeros);
4104         sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
4105         sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
4106       }
4107       int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
4108       int32x2_t sum_of_exps_reduced_2 =
4109           vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
4110                    vget_high_s32(sum_of_exps_reduced_4));
4111       int32x2_t sum_of_exps_reduced_1 =
4112           vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
4113       sum_of_exps =
4114           FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
4115 #endif
4116       for (; c < depth; ++c) {
4117         int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4118         if (input_diff >= diff_min) {
4119           const int32 input_diff_rescaled =
4120               MultiplyByQuantizedMultiplierGreaterThanOne(
4121                   input_diff, input_beta_multiplier, input_beta_left_shift);
4122           const FixedPointScaledDiff scaled_diff_f8 =
4123               FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4124           sum_of_exps =
4125               sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4126                                 exp_on_negative_values(scaled_diff_f8));
4127         }
4128       }
4129     }
4130 
4131     // Compute the fixed-point multiplier and shift that we need to apply to
4132     // perform a division by the above-computed sum-of-exponentials.
4133     int num_bits_over_unit = 0;
4134     FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
4135         sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
4136 
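    // A clarifying note on the reciprocal above: GetReciprocal returns
    // 1 / sum_of_exps as a Q0.31 value plus num_bits_over_unit, the extra
    // exponent owed to sum_of_exps >= 1. Each quotient below is then
    // exp(diff) * shifted_scale rescaled by 2^-(num_bits_over_unit + 31 - 8),
    // so a probability of 1.0 maps to 256 and is saturated into uint8.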
4137     // Compute the quotients of exponentials of differences of entries in the
4138     // current row from the largest entry, over the previously-computed sum of
4139     // exponentials.
4140     {
4141       int c = 0;
4142 #ifdef USE_NEON
4143       int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
4144       for (; c <= depth - 8; c += 8) {
4145         uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4146         int16x8_t input_diff_s16 =
4147             vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4148         int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4149         int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4150         uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
4151         FixedPointScaledDiffInt32x4 scaled_diff_0 =
4152             input_beta_multiplier_f0 *
4153             FixedPointScaledDiffInt32x4::FromRaw(
4154                 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4155         FixedPointScaledDiffInt32x4 scaled_diff_1 =
4156             input_beta_multiplier_f0 *
4157             FixedPointScaledDiffInt32x4::FromRaw(
4158                 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4159         FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
4160         FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
4161         int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
4162             vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
4163             num_bits_over_unit + 31 - 8);
4164         int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
4165             vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
4166             num_bits_over_unit + 31 - 8);
4167         int16x8_t output_s16 =
4168             vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
4169         uint8x8_t output_u8 = vqmovun_s16(output_s16);
4170         uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
4171         vst1_u8(output_data_ptr + c, masked_output);
4172       }
4173 #endif
4174       for (; c < depth; ++c) {
4175         int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4176         if (input_diff >= diff_min) {
4177           const int32 input_diff_rescaled =
4178               MultiplyByQuantizedMultiplierGreaterThanOne(
4179                   input_diff, input_beta_multiplier, input_beta_left_shift);
4180           const FixedPointScaledDiff scaled_diff_f8 =
4181               FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4182 
4183           FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
4184           int32 unsat_output = gemmlowp::RoundingDivideByPOT(
4185               (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
4186 
4187           output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
4188 
4189         } else {
4190           output_data_ptr[c] = 0;
4191         }
4192       }
4193     }
4194   }
4195 }
4196 
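// A hypothetical invocation of the uint8 Softmax above; the quantized
// parameters are normally derived offline from the float beta and the input
// scale (e.g. PreprocessSoftmaxScaling in quantization_util.h), and the
// numbers below are placeholders rather than derived values.
//
//   SoftmaxParams params;
//   params.input_multiplier = 1717986918;  // placeholder Q0.31 multiplier
//   params.input_left_shift = 1;           // placeholder
//   params.diff_min = -31;                 // placeholder cutoff
//   Softmax(params, input_shape, input_uint8, output_shape, output_uint8);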
4197 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
4198                     float beta, float* output_data,
4199                     const RuntimeShape& output_shape) {
4200   SoftmaxParams params;
4201   params.beta = beta;
4202   Softmax(params, input_shape, input_data, output_shape, output_data);
4203 }
4204 
4205 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
4206                     float beta, float* output_data,
4207                     const Dims<4>& output_dims) {
4208   Softmax(input_data, DimsToShape(input_dims), beta, output_data,
4209           DimsToShape(output_dims));
4210 }
4211 
4212 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
4213                     int32 input_beta_multiplier, int32 input_beta_left_shift,
4214                     int diff_min, uint8* output_data,
4215                     const RuntimeShape& output_shape) {
4216   SoftmaxParams params;
4217   params.input_multiplier = input_beta_multiplier;
4218   params.input_left_shift = input_beta_left_shift;
4219   params.diff_min = diff_min;
4220   Softmax(params, input_shape, input_data, output_shape, output_data);
4221 }
4222 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
4223                     int32 input_beta_multiplier, int32 input_beta_left_shift,
4224                     int diff_min, uint8* output_data,
4225                     const Dims<4>& output_dims) {
4226   Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
4227           input_beta_left_shift, diff_min, output_data,
4228           DimsToShape(output_dims));
4229 }
4230 
4231 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
4232                        float* output_data, const RuntimeShape& output_shape) {
4233   SoftmaxParams params;
4234   // No params currently used for float LogSoftmax.
4235   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
4236 }
4237 
4238 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
4239                        float* output_data, const Dims<4>& output_dims) {
4240   LogSoftmax(input_data, DimsToShape(input_dims), output_data,
4241              DimsToShape(output_dims));
4242 }
4243 
4244 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
4245                        int32 input_multiplier, int32 input_left_shift,
4246                        int32 reverse_scaling_divisor,
4247                        int32 reverse_scaling_right_shift, int diff_min,
4248                        uint8* output_data, const RuntimeShape& output_shape) {
4249   SoftmaxParams params;
4250   params.input_multiplier = input_multiplier;
4251   params.input_left_shift = input_left_shift;
4252   params.reverse_scaling_divisor = reverse_scaling_divisor;
4253   params.reverse_scaling_right_shift = reverse_scaling_right_shift;
4254   params.diff_min = diff_min;
4255   reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4256                             output_data);
4257 }
4258 
4259 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
4260                        int32 input_multiplier, int32 input_left_shift,
4261                        int32 reverse_scaling_divisor,
4262                        int32 reverse_scaling_right_shift, int diff_min,
4263                        uint8* output_data, const Dims<4>& output_dims) {
4264   reference_ops::LogSoftmax(
4265       input_data, DimsToShape(input_dims), input_multiplier, input_left_shift,
4266       reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
4267       output_data, DimsToShape(output_dims));
4268 }
4269 
4270 inline void Logistic(const LogisticParams& params,
4271                      const RuntimeShape& input_shape, const uint8* input_data,
4272                      const RuntimeShape& output_shape, uint8* output_data) {
4273   ruy::profiler::ScopeLabel label("Logistic/Uint8");
4274   const int32 input_zero_point = params.input_zero_point;
4275   const int32 input_range_radius = params.input_range_radius;
4276   const int32 input_multiplier = params.input_multiplier;
4277   const int input_left_shift = params.input_left_shift;
4278   const int size = MatchingFlatSize(input_shape, output_shape);
4279 
4280   int c = 0;
4281 #ifdef USE_NEON
4282   // Handle 16 values at a time
4283   for (; c <= size - 16; c += 16) {
4284     // Read input uint8 values, cast to int16 and subtract input_zero_point
4285     uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4286     int16x8_t input_val_centered_0 =
4287         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4288                   vdupq_n_s16(input_zero_point));
4289     int16x8_t input_val_centered_1 =
4290         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4291                   vdupq_n_s16(input_zero_point));
4292 
4293     // Prepare the bit masks that we will use at the end to implement the logic
4294     // that was expressed in the scalar code with branching:
4295     //   if (input_val_centered < -input_range_radius) {
4296     //     output_val = 0;
4297     //   } else if (input_val_centered > input_range_radius) {
4298     //     output_val = 255;
4299     //   } else {
4300     //     ...
4301     uint16x8_t mask_rightclamp_0 =
4302         vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4303     uint16x8_t mask_rightclamp_1 =
4304         vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4305     uint16x8_t mask_leftclamp_0 =
4306         vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4307     uint16x8_t mask_leftclamp_1 =
4308         vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4309     uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4310                                              vshrn_n_u16(mask_rightclamp_1, 8));
4311     uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4312                                             vshrn_n_u16(mask_leftclamp_1, 8));
4313 
4314     // This performs what is expressed in the scalar code as
4315     // const int32 input_val_rescaled =
4316     //     MultiplyByQuantizedMultiplierGreaterThanOne(
4317     //         input_val_centered, input_multiplier, input_left_shift);
4318     int32x4_t input_val_rescaled_0 =
4319         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4320                   vdupq_n_s32(input_left_shift));
4321     int32x4_t input_val_rescaled_1 =
4322         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4323                   vdupq_n_s32(input_left_shift));
4324     int32x4_t input_val_rescaled_2 =
4325         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4326                   vdupq_n_s32(input_left_shift));
4327     int32x4_t input_val_rescaled_3 =
4328         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4329                   vdupq_n_s32(input_left_shift));
4330     input_val_rescaled_0 =
4331         vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4332     input_val_rescaled_1 =
4333         vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4334     input_val_rescaled_2 =
4335         vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4336     input_val_rescaled_3 =
4337         vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4338 
4339     // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
4340     using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4341     using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4342     const FixedPoint4 input_val_f4_0 =
4343         FixedPoint4::FromRaw(input_val_rescaled_0);
4344     const FixedPoint4 input_val_f4_1 =
4345         FixedPoint4::FromRaw(input_val_rescaled_1);
4346     const FixedPoint4 input_val_f4_2 =
4347         FixedPoint4::FromRaw(input_val_rescaled_2);
4348     const FixedPoint4 input_val_f4_3 =
4349         FixedPoint4::FromRaw(input_val_rescaled_3);
4350     const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
4351     const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
4352     const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
4353     const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
4354 
4355     // Divide by 2^23 as in the scalar code
4356     using gemmlowp::RoundingDivideByPOT;
4357     int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
4358     int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
4359     int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
4360     int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
4361 
4362     // Cast output values to uint8, saturating
4363     int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4364                                               vqmovn_s32(output_val_s32_1));
4365     int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4366                                               vqmovn_s32(output_val_s32_3));
4367     uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4368                                            vqmovun_s16(output_val_s16_1));
4369 
4370     // Perform the bit-masking with the bit masks computed at the beginning,
4371     // see the comment there.
4372     output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4373     output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4374 
4375     // Store back to memory
4376     vst1q_u8(output_data + c, output_val_u8);
4377   }
4378 #endif
4379   // Leftover loop: handle one value at a time with scalar code.
4380   for (; c < size; ++c) {
4381     const uint8 input_val_u8 = input_data[c];
4382     const int32 input_val_centered =
4383         static_cast<int32>(input_val_u8) - input_zero_point;
4384     uint8 output_val;
4385     if (input_val_centered < -input_range_radius) {
4386       output_val = 0;
4387     } else if (input_val_centered > input_range_radius) {
4388       output_val = 255;
4389     } else {
4390       const int32 input_val_rescaled =
4391           MultiplyByQuantizedMultiplierGreaterThanOne(
4392               input_val_centered, input_multiplier, input_left_shift);
4393       using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4394       using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4395       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4396       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
4397       using gemmlowp::RoundingDivideByPOT;
4398       int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
4399       if (output_val_s32 == 256) {
4400         output_val_s32 = 255;
4401       }
4402       TFLITE_DCHECK_GE(output_val_s32, 0);
4403       TFLITE_DCHECK_LE(output_val_s32, 255);
4404       output_val = static_cast<uint8>(output_val_s32);
4405     }
4406     output_data[c] = output_val;
4407   }
4408 }
4409 
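// A worked note on the scalar tail above: gemmlowp::logistic yields a Q0.31
// value in [0, 1), so RoundingDivideByPOT(raw, 23) lands in [0, 256], with
// 256 reachable only when the result rounds up to 1.0; hence the explicit
// output_val_s32 == 256 -> 255 clamp before the uint8 cast.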
4410 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
4411                      int32 input_zero_point, int32 input_range_radius,
4412                      int32 input_multiplier, int input_left_shift,
4413                      uint8* output_data, const RuntimeShape& output_shape) {
4414   LogisticParams params;
4415   params.input_zero_point = input_zero_point;
4416   params.input_range_radius = input_range_radius;
4417   params.input_multiplier = input_multiplier;
4418   params.input_left_shift = input_left_shift;
4419   Logistic(params, input_shape, input_data, output_shape, output_data);
4420 }
4421 
4422 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
4423                      float* output_data, const Dims<4>& output_dims) {
4424   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4425            output_data);
4426 }
4427 
4428 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
4429                      int32 input_zero_point, int32 input_range_radius,
4430                      int32 input_multiplier, int input_left_shift,
4431                      uint8* output_data, const Dims<4>& output_dims) {
4432   Logistic(input_data, DimsToShape(input_dims), input_zero_point,
4433            input_range_radius, input_multiplier, input_left_shift, output_data,
4434            DimsToShape(output_dims));
4435 }
4436 
4437 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
4438                      const RuntimeShape& output_shape, int16* output_data) {
4439   LogisticParams params;
4440   // No params currently needed by int16 Logistic.
4441   Logistic(params, input_shape, input_data, output_shape, output_data);
4442 }
4443 
4444 inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
4445                      int16* output_data, const RuntimeShape& output_shape) {
4446   LogisticParams params;
4447   // No params currently needed by int16 Logistic.
4448   Logistic(params, input_shape, input_data, output_shape, output_data);
4449 }
4450 
4451 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
4452                      int16* output_data, const Dims<4>& output_dims) {
4453   Logistic(input_data, DimsToShape(input_dims), output_data,
4454            DimsToShape(output_dims));
4455 }
4456 
4457 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
4458                  float* output_data, const Dims<4>& output_dims) {
4459   Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4460        output_data);
4461 }
4462 
4463 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4464                  const uint8* input_data, const RuntimeShape& output_shape,
4465                  uint8* output_data) {
4466   // Note that this is almost the exact same code as in Logistic().
4467   ruy::profiler::ScopeLabel label("Tanh");
4468   const int32 input_zero_point = params.input_zero_point;
4469   const int32 input_range_radius = params.input_range_radius;
4470   const int32 input_multiplier = params.input_multiplier;
4471   const int input_left_shift = params.input_left_shift;
4472   const int size = MatchingFlatSize(input_shape, output_shape);
4473 
4474   int c = 0;
4475   int32_t output_zero_point = 128;
4476 #ifdef USE_NEON
4477   // Handle 16 values at a time
4478   for (; c <= size - 16; c += 16) {
4479     // Read input uint8 values, cast to int16 and subtract input_zero_point
4480     uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4481     int16x8_t input_val_centered_0 =
4482         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4483                   vdupq_n_s16(input_zero_point));
4484     int16x8_t input_val_centered_1 =
4485         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4486                   vdupq_n_s16(input_zero_point));
4487 
4488     // Prepare the bit masks that we will use at the end to implement the logic
4489     // that was expressed in the scalar code with branching:
4490     //   if (input_val_centered < -input_range_radius) {
4491     //     output_val = 0;
4492     //   } else if (input_val_centered > input_range_radius) {
4493     //     output_val = 255;
4494     //   } else {
4495     //     ...
4496     uint16x8_t mask_rightclamp_0 =
4497         vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4498     uint16x8_t mask_rightclamp_1 =
4499         vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4500     uint16x8_t mask_leftclamp_0 =
4501         vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4502     uint16x8_t mask_leftclamp_1 =
4503         vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4504     uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4505                                              vshrn_n_u16(mask_rightclamp_1, 8));
4506     uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4507                                             vshrn_n_u16(mask_leftclamp_1, 8));
4508 
4509     // This performs what is expressed in the scalar code as
4510     // const int32 input_val_rescaled =
4511     //     MultiplyByQuantizedMultiplierGreaterThanOne(
4512     //         input_val_centered, input_multiplier, input_left_shift);
4513     int32x4_t input_val_rescaled_0 =
4514         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4515                   vdupq_n_s32(input_left_shift));
4516     int32x4_t input_val_rescaled_1 =
4517         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4518                   vdupq_n_s32(input_left_shift));
4519     int32x4_t input_val_rescaled_2 =
4520         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4521                   vdupq_n_s32(input_left_shift));
4522     int32x4_t input_val_rescaled_3 =
4523         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4524                   vdupq_n_s32(input_left_shift));
4525     input_val_rescaled_0 =
4526         vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4527     input_val_rescaled_1 =
4528         vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4529     input_val_rescaled_2 =
4530         vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4531     input_val_rescaled_3 =
4532         vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4533 
4534     // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
4535     using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4536     using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4537     const FixedPoint4 input_val_f4_0 =
4538         FixedPoint4::FromRaw(input_val_rescaled_0);
4539     const FixedPoint4 input_val_f4_1 =
4540         FixedPoint4::FromRaw(input_val_rescaled_1);
4541     const FixedPoint4 input_val_f4_2 =
4542         FixedPoint4::FromRaw(input_val_rescaled_2);
4543     const FixedPoint4 input_val_f4_3 =
4544         FixedPoint4::FromRaw(input_val_rescaled_3);
4545     const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
4546     const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
4547     const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
4548     const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
4549 
4550     // Divide by 2^24 as in the scalar code
4551     using gemmlowp::RoundingDivideByPOT;
4552     int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
4553     int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
4554     int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
4555     int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
4556 
4557     // Add the output zero point
4558     int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
4559     output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
4560     output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
4561     output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
4562     output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
4563 
4564     // Cast output values to uint8, saturating
4565     int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4566                                               vqmovn_s32(output_val_s32_1));
4567     int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4568                                               vqmovn_s32(output_val_s32_3));
4569     uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4570                                            vqmovun_s16(output_val_s16_1));
4571 
4572     // Perform the bit-masking with the bit masks computed at the beginning,
4573     // see the comment there.
4574     output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4575     output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4576 
4577     // Store back to memory
4578     vst1q_u8(output_data + c, output_val_u8);
4579   }
4580 #endif
4581   // Leftover loop: handle one value at a time with scalar code.
4582   for (; c < size; ++c) {
4583     const uint8 input_val_u8 = input_data[c];
4584     const int32 input_val_centered =
4585         static_cast<int32>(input_val_u8) - input_zero_point;
4586     uint8 output_val;
4587     if (input_val_centered < -input_range_radius) {
4588       output_val = 0;
4589     } else if (input_val_centered > input_range_radius) {
4590       output_val = 255;
4591     } else {
4592       const int32 input_val_rescaled =
4593           MultiplyByQuantizedMultiplierGreaterThanOne(
4594               input_val_centered, input_multiplier, input_left_shift);
4595       using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4596       using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4597       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4598       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
4599       using gemmlowp::RoundingDivideByPOT;
4600       int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
4601       output_val_s32 += output_zero_point;
4602       if (output_val_s32 == 256) {
4603         output_val_s32 = 255;
4604       }
4605       TFLITE_DCHECK_GE(output_val_s32, 0);
4606       TFLITE_DCHECK_LE(output_val_s32, 255);
4607       output_val = static_cast<uint8>(output_val_s32);
4608     }
4609     output_data[c] = output_val;
4610   }
4611 }
4612 
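// A worked note on the uint8 Tanh above: gemmlowp::tanh yields a Q0.31 value
// in (-1, 1); RoundingDivideByPOT(raw, 24) maps it into (-128, 128), and
// adding output_zero_point = 128 recenters it into (0, 256), matching an
// asymmetric uint8 output quantization (scale 1/128, zero point 128) that
// the output_val_s32 == 256 -> 255 clamp then saturates.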
4613 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
4614                  int32 input_zero_point, int32 input_range_radius,
4615                  int32 input_multiplier, int input_left_shift,
4616                  uint8* output_data, const RuntimeShape& output_shape) {
4617   TanhParams params;
4618   params.input_zero_point = input_zero_point;
4619   params.input_range_radius = input_range_radius;
4620   params.input_multiplier = input_multiplier;
4621   params.input_left_shift = input_left_shift;
4622   Tanh(params, input_shape, input_data, output_shape, output_data);
4623 }
4624 
4625 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
4626                  int32 input_zero_point, int32 input_range_radius,
4627                  int32 input_multiplier, int input_left_shift,
4628                  uint8* output_data, const Dims<4>& output_dims) {
4629   Tanh(input_data, DimsToShape(input_dims), input_zero_point,
4630        input_range_radius, input_multiplier, input_left_shift, output_data,
4631        DimsToShape(output_dims));
4632 }
4633 
4634 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
4635                  int input_left_shift, int16* output_data,
4636                  const RuntimeShape& output_shape) {
4637   TanhParams params;
4638   params.input_left_shift = input_left_shift;
4639   Tanh(params, input_shape, input_data, output_shape, output_data);
4640 }
4641 
4642 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
4643                  int input_left_shift, int16* output_data,
4644                  const Dims<4>& output_dims) {
4645   Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
4646        DimsToShape(output_dims));
4647 }
4648 
4649 template <typename T>
4650 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
4651                          int block_size, T* output_data,
4652                          const Dims<4>& output_dims) {
4653   tflite::DepthToSpaceParams op_params;
4654   op_params.block_size = block_size;
4655 
4656   DepthToSpace(op_params, DimsToShape(input_dims), input_data,
4657                DimsToShape(output_dims), output_data);
4658 }
4659 
4660 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)4661 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
4662                          int block_size, T* output_data,
4663                          const Dims<4>& output_dims) {
4664   tflite::SpaceToDepthParams op_params;
4665   op_params.block_size = block_size;
4666 
4667   SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
4668                DimsToShape(output_dims), output_data);
4669 }
4670 
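// The Mul shims below differ only in which ArithmeticParams fields they
// populate: the float and int32 overloads set activation bounds, the
// int16->int16 overload sets none, and the int16->uint8 overload also sets
// output_offset.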
inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
                const float* input2_data, const Dims<4>& input2_dims,
                float output_activation_min, float output_activation_max,
                float* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <FusedActivationFunctionType Ac>
void Mul(const float* input1_data, const Dims<4>& input1_dims,
         const float* input2_data, const Dims<4>& input2_dims,
         float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);

  Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
      output_activation_max, output_data, output_dims);
}

inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
                const int32* input2_data, const Dims<4>& input2_dims,
                int32 output_activation_min, int32 output_activation_max,
                int32* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

template <FusedActivationFunctionType Ac>
void Mul(const int32* input1_data, const Dims<4>& input1_dims,
         const int32* input2_data, const Dims<4>& input2_dims,
         int32* output_data, const Dims<4>& output_dims) {
  TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
  tflite::ArithmeticParams op_params;
  // No parameters needed.

  MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
                  DimsToShape(input2_dims), input2_data,
                  DimsToShape(output_dims), output_data);
}

inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
                const int16* input2_data, const Dims<4>& input2_dims,
                int16* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  // No parameters needed.

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
                const int16* input2_data, const Dims<4>& input2_dims,
                int32 output_offset, int32 output_activation_min,
                int32 output_activation_max, uint8* output_data,
                const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  op_params.output_offset = output_offset;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;

  Mul(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}

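// Both BroadcastMul overloads lower onto BroadcastMul4DSlow, the slow
// reference broadcast kernel; there is no optimized legacy broadcast path.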
template <typename T>
void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
                  const T* input2_data, const Dims<4>& input2_dims,
                  T output_activation_min, T output_activation_max,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);

  BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

// For compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
                         const float* input2_data, const Dims<4>& input2_dims,
                         float* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  float float_activation_min;
  float float_activation_max;
  GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
  SetActivationParams(float_activation_min, float_activation_max, &op_params);

  BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}

inline void LocalResponseNormalization(const float* input_data,
                                       const Dims<4>& input_dims, int range,
                                       float bias, float alpha, float beta,
                                       float* output_data,
                                       const Dims<4>& output_dims) {
  tflite::LocalResponseNormalizationParams op_params;
  op_params.range = range;
  op_params.bias = bias;
  op_params.alpha = alpha;
  op_params.beta = beta;

  LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
                             DimsToShape(output_dims), output_data);
}

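// Cast and Floor take no op-specific parameters; these shims only convert
// Dims<4> to RuntimeShape.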
template <typename SrcT, typename DstT>
void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
          const Dims<4>& output_dims) {
  Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
       output_data);
}

inline void Floor(const float* input_data, const Dims<4>& input_dims,
                  float* output_data, const Dims<4>& output_dims) {
  Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
        output_data);
}

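// Legacy ResizeBilinear never used half-pixel centers; only align_corners was
// exposed, and the oldest signatures below default it to false.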
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_dims, bool align_corners) {
  tflite::ResizeBilinearParams op_params;
  op_params.align_corners = align_corners;
  op_params.half_pixel_centers = false;
  ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(output_size_dims), output_size_data,
                 DimsToShape(output_dims), output_data);
}

inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims, bool align_corners) {
  tflite::ResizeBilinearParams op_params;
  op_params.align_corners = align_corners;
  op_params.half_pixel_centers = false;
  ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(output_size_dims), output_size_data,
                 DimsToShape(output_dims), output_data);
}

// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
                 output_data, output_dims, /*align_corners=*/false);
}

// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
                 output_data, output_dims, /*align_corners=*/false);
}

template <typename T>
inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
                           const int32* block_shape_data,
                           const Dims<4>& block_shape_dims,
                           const int32* crops_data, const Dims<4>& crops_dims,
                           T* output_data, const Dims<4>& output_dims) {
  BatchToSpaceND(DimsToShape(input_dims), input_data,
                 DimsToShape(block_shape_dims), block_shape_data,
                 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
                 output_data);
}

// Legacy signature; this function covered both Pad and PadV2.
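// Dims<4> lists dimensions in reverse order relative to RuntimeShape, so the
// padding vectors are flipped with [3 - i] when filling PadParams.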
template <typename T>
inline void PadV2(const T* input_data, const Dims<4>& input_dims,
                  const std::vector<int>& left_paddings,
                  const std::vector<int>& right_paddings, T* output_data,
                  const Dims<4>& output_dims, const T pad_value) {
  TFLITE_DCHECK_EQ(left_paddings.size(), 4);
  TFLITE_DCHECK_EQ(right_paddings.size(), 4);
  tflite::PadParams op_params;
  op_params.left_padding_count = 4;
  op_params.right_padding_count = 4;
  for (int i = 0; i < 4; ++i) {
    op_params.left_padding[i] = left_paddings[3 - i];
    op_params.right_padding[i] = right_paddings[3 - i];
  }
  const T pad_value_copy = pad_value;

  Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
      DimsToShape(output_dims), output_data);
}

// Old Pad that calls legacy PadV2.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
                const std::vector<int>& left_paddings,
                const std::vector<int>& right_paddings, T* output_data,
                const Dims<4>& output_dims, const int32_t pad_value) {
  const T converted_pad_value = static_cast<T>(pad_value);
  PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
           output_dims, converted_pad_value);
}

// Old Pad that only padded with 0.
template <typename T>
inline void Pad(const T* input_data, const Dims<4>& input_dims,
                const std::vector<int>& left_paddings,
                const std::vector<int>& right_paddings, T* output_data,
                const Dims<4>& output_dims) {
  const T pad_value = static_cast<T>(0);
  PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
           output_dims, pad_value);
}

template <typename T>
inline void Slice(const T* input_data, const Dims<4>& input_dims,
                  const std::vector<int>& begin, const std::vector<int>& size,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::SliceParams op_params;
  op_params.begin_count = 4;
  op_params.size_count = 4;
  for (int i = 0; i < 4; ++i) {
    op_params.begin[i] = begin[3 - i];
    op_params.size[i] = size[3 - i];
  }

  Slice(op_params, DimsToShape(input_dims), input_data,
        DimsToShape(output_dims), output_data);
}

template <typename T>
void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
                       const T* input2_data, T* output_data,
                       const Dims<4>& output_dims) {
  Minimum(DimsToShape(input1_dims), input1_data, input2_data,
          DimsToShape(output_dims), output_data);
}

template <typename T>
void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
                       const T* input2_data, T* output_data,
                       const Dims<4>& output_dims) {
  Maximum(DimsToShape(input1_dims), input1_data, input2_data,
          DimsToShape(output_dims), output_data);
}

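// Dequantize computes output[i] = scale * (input[i] - zero_point).
// Example (illustrative only; buffers, shape, and quantization parameters
// are hypothetical):
//   Dims<4> dims = ...;
//   Dequantize(quantized_in, dims, /*zero_point=*/128, /*scale=*/0.5,
//              float_out, dims);  // float_out[i] = (quantized_in[i] - 128) * 0.5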
inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
                       int32 zero_point, double scale, float* output_data,
                       const Dims<4>& output_dims) {
  tflite::DequantizationParams op_params;
  op_params.zero_point = zero_point;
  op_params.scale = scale;

  Dequantize(op_params, DimsToShape(input_dims), input_data,
             DimsToShape(output_dims), output_data);
}

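// Transpose must remap the permutation twice: both the positions and the
// values of permuted_axes use reversed Dims<4> axis numbering, hence
// params.perm[i] = 3 - permuted_axes[3 - i].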
template <typename T>
void Transpose(const T* input, const Dims<4>& input_dims, T* output,
               const Dims<4>& output_dims, const int* permuted_axes) {
  TransposeParams params;
  params.perm_count = 4;
  for (int i = 0; i < 4; ++i) {
    params.perm[i] = 3 - permuted_axes[3 - i];
  }
  Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
            output);
}

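// StridedSlice: the legacy index vectors arrive in reversed Dims<4> order;
// after BuildStridedSliceParams, StridedSliceReverseIndices flips the params
// into RuntimeShape order before dispatch.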
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
                         int begin_mask, int end_mask, int shrink_axis_mask,
                         const std::vector<int>& start_indices,
                         const std::vector<int>& stop_indices,
                         const std::vector<int>& strides, T* output_data,
                         const Dims<4>& output_dims) {
  TFLITE_DCHECK_EQ(start_indices.size(), 4);
  auto op_params = strided_slice::BuildStridedSliceParams(
      begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
      strides);
  reference_ops::StridedSliceReverseIndices(&op_params);

  StridedSlice(op_params, DimsToShape(input_dims), input_data,
               DimsToShape(output_dims), output_data);
}

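// The legacy ArgMax only supported reducing over the last axis (axis[0] == 3
// in RuntimeShape numbering), which the DCHECK below enforces.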
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore the output
  // always has three dimensions.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[3] is always 1.
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this.
  TFLITE_DCHECK_EQ(axis[0], 3);
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, /*is_arg_max=*/true);
}

template <typename T1, typename T2, typename T3>
void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
               T2* output_data, const Dims<4>& output_dims,
               const bool is_arg_max) {
  ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
            output_data, is_arg_max);
}

}  // namespace optimized_ops
}  // namespace tflite
#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_