/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/resize_bilinear.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

// Unoptimized reference ops:
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
using reference_ops::Broadcast4DSlowGreaterWithScaling;
using reference_ops::Broadcast4DSlowLess;
using reference_ops::Broadcast4DSlowLessEqual;
using reference_ops::Broadcast4DSlowLessEqualWithScaling;
using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSubSlow;
using reference_ops::Concatenation;
using reference_ops::ConcatenationWithScaling;
using reference_ops::DepthConcatenation;
using reference_ops::Div;
using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
using reference_ops::GreaterEqualWithScaling;
using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
using reference_ops::LessEqualWithScaling;
using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::TensorFlowSplit;

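// Negates legacy right-shift amounts so they follow the new convention used
// by DepthwiseParams, where a positive output_shift means a left shift (see
// the quantized DepthwiseConv wrappers below).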
static constexpr int kDepthwiseReverseShift = -1;

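// Helpers that map raw buffers described by legacy Dims<N> metadata onto
// Eigen vector/matrix/array views without copying.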
template <typename Scalar, int N>
VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
  const int size = FlatSize(dims);
  return VectorMap<Scalar>(data, size, 1);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
                                                const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
                                               const Dims<N>& dims) {
  const int cols = dims.sizes[N - 1];
  int rows = 1;
  for (int d = 0; d < N - 1; d++) {
    rows *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
                                              const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return ArrayMap<Scalar>(data, rows, cols);
}

// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
                                                   const Dims<N>& dims,
                                                   int rows) {
  const int flatsize = FlatSize(dims);
  TFLITE_DCHECK((flatsize % rows) == 0);
  const int cols = flatsize / rows;
  return MatrixMap<Scalar>(data, rows, cols);
}

inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
  for (int i = 0; i < 4; i++) {
    if (dims1.sizes[i] != dims2.sizes[i]) {
      return false;
    }
  }
  return true;
}

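// Legacy Dims<4> entry point for float DepthwiseConv. It repacks the loose
// arguments into DepthwiseParams and runs the multithread-capable
// DepthwiseConvImpl on a single thread over all output rows.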
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
                    DimsToShape(filter_dims), filter_data,
                    DimsToShape(bias_dims), bias_data, output_shape,
                    output_data, CpuFlags(), /*thread_start=*/0,
                    /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, 1, 1, pad_width,
                pad_height, depth_multiplier, output_activation_min,
                output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, float* output_data,
                   const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, pad_width, pad_height,
                depth_multiplier, output_activation_min, output_activation_max,
                output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}

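// Quantized depthwise convolution with the output-rounding mode selected at
// compile time. On aarch64 (except L4T builds) it dispatches to the
// specialized 3x3-filter kernel when the parameters allow it; otherwise it
// falls back to the general kernel.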
template <DepthwiseConvOutputRounding kOutputRounding>
inline void LegacyDepthwiseConvWithRounding(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  ruy::profiler::ScopeLabel label("DepthwiseConv/8bit");
  const int depth_multiplier = params.depth_multiplier;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  TFLITE_DCHECK_GE(dilation_width_factor, 1);
  TFLITE_DCHECK_GE(dilation_height_factor, 1);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
  const int input_depth = input_shape.Dims(3);
  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);

// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  const int output_shift = params.output_shift;

  // Call kernel optimized for depthwise convolutions using 3x3 filters if
  // parameters are supported.
  if (depthwise_conv::Fast3x3FilterKernelSupported(
          input_shape, filter_shape, stride_width, stride_height,
          dilation_width_factor, dilation_height_factor, pad_width, pad_height,
          depth_multiplier, output_shape, output_shift)) {
    ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/3x3");
    depthwise_conv::DepthwiseConv3x3Filter<kOutputRounding>(
        params, input_shape, input_data, filter_shape, filter_data, bias_shape,
        bias_data, output_shape, output_data, thread_start, thread_end,
        thread_dim);
    return;
  }
#endif

  ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/General");
  depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data,
                                       filter_shape, filter_data, bias_shape,
                                       bias_data, output_shape, output_data,
                                       thread_start, thread_end, thread_dim);
}

inline void LegacyDepthwiseConvImpl(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  return LegacyDepthwiseConvWithRounding<
      DepthwiseConvOutputRounding::kAwayFromZero>(
      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
      bias_data, output_shape, output_data, thread_start, thread_end,
      thread_dim);
}

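// Legacy Dims<4> entry point for quantized DepthwiseConv. Callers pass
// `output_shift` in the old convention (a right shift); it is negated via
// kDepthwiseReverseShift below to match the new positive-means-left-shift
// convention used by DepthwiseParams.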
inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = kDepthwiseReverseShift * output_shift;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  LegacyDepthwiseConvImpl(
      op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
      filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
      output_data, /*thread_start=*/0,
      /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}

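// gemmlowp::Task that computes one slice of the quantized depthwise
// convolution (rows [thread_start, thread_end) along thread_dim) on a worker
// thread.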
template <typename T, typename TS>
struct LegacyDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    LegacyDepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
                            filter_data_, bias_shape_, bias_data_,
                            output_shape_, output_data_, thread_start_,
                            thread_end_, thread_dim_);
  }

 private:
  const DepthwiseParams& params_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

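// Heuristic for how many threads to use when splitting a depthwise
// convolution along `thread_dim`: size the per-thread work so each thread
// performs at least roughly kMinMulPerThread multiply-accumulates, given the
// filter size and the work per unit of that dimension.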
inline int HowManyConvThreads(const RuntimeShape& output_shape,
                              const RuntimeShape& filter_shape,
                              int thread_dim) {
  constexpr int kMinMulPerThread = 8;
  const int output_units = output_shape.Dims(thread_dim);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int num_mul_per_unit =
      FlatSizeSkipDim(output_shape, thread_dim) * filter_height * filter_width;
  const int min_units_per_thread = kMinMulPerThread / num_mul_per_unit + 1;
  int thread_count = output_units / min_units_per_thread;
  return thread_count;
}

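// Multithreaded quantized DepthwiseConv entry point. It chooses whether to
// split the work across the batch dimension or the row dimension (whichever
// admits more threads), caps the thread count at the gemmlowp context's
// maximum, and either runs inline or fans out to worker tasks.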
inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConv");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    LegacyDepthwiseConvImpl(params, input_shape, input_data, filter_shape,
                            filter_data, bias_shape, bias_data, output_shape,
                            output_data, /*thread_start=*/0,
                            /*thread_end=*/output_rows, /*thread_dim=*/1);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyDepthwiseConvWorkerTask<uint8, int32>(
          params, input_shape, input_data, filter_shape, filter_data,
          bias_shape, bias_data, output_shape, output_data, thread_start,
          thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

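// gemmlowp::Task that computes one slice of the per-channel-quantized (int8)
// depthwise convolution on a worker thread.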
template <typename T, typename TS>
struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyPerChannelDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const int32* output_multiplier,
      const int32* output_shift, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params_, output_multiplier_, output_shift_, input_shape_, input_data_,
        filter_shape_, filter_data_, bias_shape_, bias_data_, output_shape_,
        output_data_, thread_start_, thread_end_, thread_dim_, backend_context);
  }

 private:
  const DepthwiseParams& params_;
  const int32* output_multiplier_;
  const int32* output_shift_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

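// Multithreaded per-channel-quantized (int8) DepthwiseConv. The threading
// strategy mirrors the uint8 DepthwiseConv above: split across batches or
// rows, whichever allows more threads, capped by the gemmlowp context.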
inline void DepthwiseConvPerChannel(
    const DepthwiseParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int8* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
    gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConvInt8");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params, output_multiplier, output_shift, input_shape, input_data,
        filter_shape, filter_data, bias_shape, bias_data, output_shape,
        output_data, /*thread_start=*/0,
        /*thread_end=*/output_rows, /*thread_dim=*/1, backend_context);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyPerChannelDepthwiseConvWorkerTask<int8, int32>(
          params, output_multiplier, output_shift, input_shape, input_data,
          filter_shape, filter_data, bias_shape, bias_data, output_shape,
          output_data, thread_start, thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& bias_shape,
    const float* bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
                    bias_shape, bias_data, output_shape, output_data,
                    CpuFlags(),
                    /*thread_start=*/0,
                    /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
}

inline void AddBiasAndEvalActivationFunction(const float* bias_data,
                                             const Dims<4>& bias_dims,
                                             float* array_data,
                                             const Dims<4>& array_dims,
                                             float output_activation_min,
                                             float output_activation_max) {
  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
                                   DimsToShape(bias_dims), bias_data,
                                   DimsToShape(array_dims), array_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AddBiasAndEvalActivationFunction(const float* bias_data,
                                      const Dims<4>& bias_dims,
                                      float* array_data,
                                      const Dims<4>& array_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
                                   output_activation_min,
                                   output_activation_max);
}

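// Thin wrapper over Eigen matrix multiplication. When the right-hand side has
// a single column the product is computed as a GEMV, which is the common
// batch-size-1 fully-connected case; otherwise a full GEMM is used.
// noalias() tells Eigen the result does not overlap the operands, avoiding a
// temporary.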
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
          Eigen::MatrixBase<Result>* result) {
  if (rhs.cols() == 1) {
    ruy::profiler::ScopeLabel label("GEMV");
    result->col(0).noalias() = lhs * rhs.col(0);
  } else {
    ruy::profiler::ScopeLabel label("GEMM");
    result->noalias() = lhs * rhs;
  }
}

inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& bias_shape,
    const float* optional_bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  // TODO(b/62193649): this convoluted shape computation (determining
  // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
  // is because the current --variable_batch hack consists in overwriting the
  // 3rd dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  // When that is fixed, this should become:
  // const auto input_matrix_map =
  //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
  const int dims_count = weights_shape.DimensionsCount();
  const int input_rows = weights_shape.Dims(dims_count - 1);
  const auto input_matrix_map =
      MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
  const auto filter_matrix_map =
      MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
  auto output_matrix_map =
      MapAsMatrixWithLastDimAsRows(output_data, output_shape);

  Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);

  if (optional_bias_data != nullptr) {
    AddBiasAndEvalActivationFunction(
        output_activation_min, output_activation_max, bias_shape,
        optional_bias_data, output_shape, output_data);
  } else {
    const int flat_size = output_shape.FlatSize();
    for (int i = 0; i < flat_size; ++i) {
      output_data[i] = ActivationFunctionWithMinMax(
          output_data[i], output_activation_min, output_activation_max);
    }
  }
}

inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                           const float* weights_data,
                           const Dims<4>& weights_dims, const float* bias_data,
                           const Dims<4>& bias_dims,
                           float output_activation_min,
                           float output_activation_max, float* output_data,
                           const Dims<4>& output_dims) {
  tflite::FullyConnectedParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  FullyConnected(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(weights_dims), weights_data,
                 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
                 output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                    const float* weights_data, const Dims<4>& weights_dims,
                    const float* bias_data, const Dims<4>& bias_dims,
                    float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
                 bias_dims, output_activation_min, output_activation_max,
                 output_data, output_dims);
}

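// gemmlowp output pipeline for uint8 kernels: bias addition, fixed-point
// requantization (multiplier and exponent), clamping to the activation range,
// and a saturating cast to uint8.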
struct GemmlowpOutputPipeline {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToUint8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

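// Same output pipeline as above, but ending in a saturating cast to int8 for
// int8-quantized kernels.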
struct GemmlowpOutputPipelineInt8 {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToInt8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

#ifdef USE_NEON
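// NEON GEMV kernel for the quantized fully-connected fast path. It processes
// kPeel (4) output rows at a time: uint8 inputs and filters are widened to
// int16, the zero-point offsets are added, products are accumulated into
// int32 lanes with vmlal_s16, and the accumulators are then horizontally
// reduced, bias-added, requantized, clamped and stored as uint8.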
inline void LegacyFullyConnectedAsGEMVWorkerImpl(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, int row_start, int row_end) {
  ruy::profiler::ScopeLabel label("FullyConnectedAsGEMV/8bit");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kPeel = 4;
  const bool shift_left = (output_shift > 0);
  for (int k = 0; k < input_size; k += 64) {
    optimized_ops_preload_l1_stream(input_data + k);
  }
  for (int k = 0; k < kPeel * input_size; k += 64) {
    optimized_ops_preload_l1_stream(filter_data + k);
  }

  TFLITE_DCHECK_GE(row_end - row_start, kPeel);

  for (int out = row_start; out < row_end; out += kPeel) {
    out = std::min(out, row_end - kPeel);
    int32x4_t acc0 = vdupq_n_s32(0);
    int32x4_t acc1 = acc0;
    int32x4_t acc2 = acc0;
    int32x4_t acc3 = acc0;
    const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
    const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
    int in = 0;
    for (; in <= input_size - 16; in += 16) {
      const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
      const uint8* filter_ptr = filter_data + in + out * input_size;
      uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      int16x8_t input_val_0, input_val_1;
      uint8x8_t low = vget_low_u8(input_val_u8);
      uint8x8_t high = vget_high_u8(input_val_u8);
      input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
      input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
      low = vget_low_u8(filter_val_u8_0);
      high = vget_high_u8(filter_val_u8_0);
      int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
      filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_1);
      high = vget_high_u8(filter_val_u8_1);
      int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
      filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_2);
      high = vget_high_u8(filter_val_u8_2);
      int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
      filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_3);
      high = vget_high_u8(filter_val_u8_3);
      int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
      filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
      acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
                       vget_low_s16(input_val_0));
      acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
                       vget_low_s16(input_val_0));
      acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
                       vget_low_s16(input_val_0));
      acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
                       vget_low_s16(input_val_0));
      acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
                       vget_low_s16(input_val_1));
      acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
                       vget_low_s16(input_val_1));
      acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
                       vget_low_s16(input_val_1));
      acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
                       vget_low_s16(input_val_1));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
                       vget_high_s16(input_val_0));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
                       vget_high_s16(input_val_0));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
                       vget_high_s16(input_val_0));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
                       vget_high_s16(input_val_0));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
                       vget_high_s16(input_val_1));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
                       vget_high_s16(input_val_1));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
                       vget_high_s16(input_val_1));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
                       vget_high_s16(input_val_1));
    }
    for (; in <= input_size - 8; in += 8) {
      const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
      const uint8* filter_ptr = filter_data + in + out * input_size;
      uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr);
      int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
      input_val = vaddq_s16(input_val, input_offset_vec);
      int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0));
      filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
      int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1));
      filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
      int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2));
      filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
      int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3));
      filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
      acc0 =
          vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
      acc1 =
          vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
      acc2 =
          vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
      acc3 =
          vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                       vget_high_s16(input_val));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                       vget_high_s16(input_val));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                       vget_high_s16(input_val));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                       vget_high_s16(input_val));
    }
    if (in < input_size) {
      int32 buf[16];
      vst1q_s32(buf + 0, acc0);
      vst1q_s32(buf + 4, acc1);
      vst1q_s32(buf + 8, acc2);
      vst1q_s32(buf + 12, acc3);
      for (; in < input_size; in++) {
        int lane = (in + 8 - input_size) % 4;
        const int32 input_val = input_data[in] + input_offset;
        for (int k = 0; k < kPeel; k++) {
          int32 filter_val =
              filter_data[in + (out + k) * input_size] + filter_offset;
          buf[lane + 4 * k] += filter_val * input_val;
        }
      }
      acc0 = vld1q_s32(buf + 0);
      acc1 = vld1q_s32(buf + 4);
      acc2 = vld1q_s32(buf + 8);
      acc3 = vld1q_s32(buf + 12);
    }

    // Horizontally reduce accumulators
    int32x2_t pairwise_reduced_acc_0 =
        vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
    int32x2_t pairwise_reduced_acc_1 =
        vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
    int32x2_t pairwise_reduced_acc_2 =
        vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
    int32x2_t pairwise_reduced_acc_3 =
        vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
    const int32x2_t reduced_lo =
        vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
    const int32x2_t reduced_hi =
        vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
    int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
    // Add bias values.
    int32x4_t bias_vec = vld1q_s32(bias_data + out);
    reduced = vaddq_s32(reduced, bias_vec);
    if (shift_left) {
      const int32 multiplier_power_of_two = 1 << output_shift;
      reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
      reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
    } else {
      // Multiply by the fixed-point multiplier.
      reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
      // Rounding-shift-right.
      using gemmlowp::RoundingDivideByPOT;
      reduced = RoundingDivideByPOT(reduced, -output_shift);
    }
    // Add the output offset.
    const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
    reduced = vaddq_s32(reduced, output_offset_vec);
    // Narrow values down to 16 bit signed.
    const int16x4_t res16 = vqmovn_s32(reduced);
    // Narrow values down to 8 bit unsigned, saturating.
    uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
    // Apply the clamping from the activation function
    res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
    res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
    // Store results to destination.
    vst1_lane_u8(output_data + out + 0, res8, 0);
    vst1_lane_u8(output_data + out + 1, res8, 1);
    vst1_lane_u8(output_data + out + 2, res8, 2);
    vst1_lane_u8(output_data + out + 3, res8, 3);
  }
}

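// gemmlowp::Task that runs the GEMV kernel above on a contiguous range of
// output rows [row_start, row_end).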
struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
  LegacyFullyConnectedAsGEMVWorkerTask(
      const RuntimeShape& input_shape, const uint8* input_data,
      int32 input_offset, const RuntimeShape& filter_shape,
      const uint8* filter_data, int32 filter_offset,
      const RuntimeShape& bias_shape, const int32* bias_data,
      int32 output_offset, int32 output_multiplier, int output_shift,
      int32 output_activation_min, int32 output_activation_max,
      const RuntimeShape& output_shape, uint8* output_data, int row_start,
      int row_end)
      : input_shape_(input_shape),
        input_data_(input_data),
        input_offset_(input_offset),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        filter_offset_(filter_offset),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_offset_(output_offset),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        output_activation_min_(output_activation_min),
        output_activation_max_(output_activation_max),
        output_shape_(output_shape),
        output_data_(output_data),
        row_start_(row_start),
        row_end_(row_end) {}

  void Run() override {
    LegacyFullyConnectedAsGEMVWorkerImpl(
        input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
        filter_offset_, bias_shape_, bias_data_, output_offset_,
        output_multiplier_, output_shift_, output_activation_min_,
        output_activation_max_, output_shape_, output_data_, row_start_,
        row_end_);
  }

  const RuntimeShape& input_shape_;
  const uint8* input_data_;
  int32 input_offset_;
  const RuntimeShape& filter_shape_;
  const uint8* filter_data_;
  int32 filter_offset_;
  const RuntimeShape& bias_shape_;
  const int32* bias_data_;
  int32 output_offset_;
  int32 output_multiplier_;
  int output_shift_;
  int32 output_activation_min_;
  int32 output_activation_max_;
  const RuntimeShape& output_shape_;
  uint8* output_data_;
  int row_start_;
  int row_end_;
};

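// Multithreaded driver for the GEMV fast path: estimates a thread count from
// the problem size, runs on the current thread when one thread suffices, and
// otherwise partitions output rows (in multiples of kKernelRows) across the
// gemmlowp worker pool.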
inline void FullyConnectedAsGEMV(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
  const int output_dim_count = output_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kKernelRows = 4;
  const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
      gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
  if (thread_count == 1) {
    // Single-thread case: do the computation on the current thread, don't
    // use a threadpool
    LegacyFullyConnectedAsGEMVWorkerImpl(
        input_shape, input_data, input_offset, filter_shape, filter_data,
        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
        output_shift, output_activation_min, output_activation_max,
        output_shape, output_data, 0, output_rows);
    return;
  }

  // Multi-threaded case: use the gemmlowp context's threadpool.
  TFLITE_DCHECK_GT(thread_count, 1);
  std::vector<gemmlowp::Task*> tasks(thread_count);
  const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
      gemmlowp::CeilQuotient(output_rows, thread_count));
  int row_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    int row_end = std::min(output_rows, row_start + kRowsPerWorker);
    tasks[i] = new LegacyFullyConnectedAsGEMVWorkerTask(
        input_shape, input_data, input_offset, filter_shape, filter_data,
        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
        output_shift, output_activation_min, output_activation_max,
        output_shape, output_data, row_start, row_end);
    row_start = row_end;
  }
  TFLITE_DCHECK_EQ(row_start, output_rows);
  gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
}
#endif  // USE_NEON

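// Quantized (uint8) FullyConnected. For batch size 1 with at least four
// output rows it takes the NEON GEMV fast path above; otherwise it runs a
// gemmlowp GEMM with the standard output pipeline.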
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,uint8 * output_data,gemmlowp::GemmContext * gemmlowp_context)1091 inline void FullyConnected(
1092 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1093 const uint8* input_data, const RuntimeShape& filter_shape,
1094 const uint8* filter_data, const RuntimeShape& bias_shape,
1095 const int32* bias_data, const RuntimeShape& output_shape,
1096 uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1097 ruy::profiler::ScopeLabel label("FullyConnected/8bit");
1098 const int32 input_offset = params.input_offset;
1099 const int32 filter_offset = params.weights_offset;
1100 const int32 output_offset = params.output_offset;
1101 const int32 output_multiplier = params.output_multiplier;
1102 const int output_shift = params.output_shift;
1103 const int32 output_activation_min = params.quantized_activation_min;
1104 const int32 output_activation_max = params.quantized_activation_max;
1105 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1106 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1107 // TODO(b/62193649): This really should be:
1108 // const int batches = ArraySize(output_dims, 1);
1109 // but the current --variable_batch hack consists in overwriting the 3rd
1110 // dimension with the runtime batch size, as we don't keep track for each
1111 // array of which dimension is the batch dimension in it.
1112 const int output_dim_count = output_shape.DimensionsCount();
1113 const int filter_dim_count = filter_shape.DimensionsCount();
1114 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1115 #ifdef USE_NEON
1116 if (batches == 1) {
1117 const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
1118 output_shape, output_dim_count - 1);
1119 if (output_size >= 4) {
1120 return FullyConnectedAsGEMV(
1121 input_shape, input_data, input_offset, filter_shape, filter_data,
1122 filter_offset, bias_shape, bias_data, output_offset,
1123 output_multiplier, output_shift, output_activation_min,
1124 output_activation_max, output_shape, output_data, gemmlowp_context);
1125 }
1126 }
1127 #endif // USE_NEON
1128 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
1129 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
1130 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
1131 const int output_rows = output_shape.Dims(output_dim_count - 1);
1132 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1133 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1134
1135 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
1136 filter_data, output_rows, filter_cols, filter_cols);
1137 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1138 input_data, filter_cols, batches, filter_cols);
1139 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
1140 output_data, output_rows, batches, output_rows);
1141 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
1142 bias_data, output_rows, output_offset, output_multiplier, output_shift,
1143 output_activation_min, output_activation_max);
1144 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
1145 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1146 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
1147 filter_offset, input_offset, output_pipeline);
1148 }
1149
1150 #ifdef GEMMLOWP_NEON
1151 // In the common case of batch size 1, a fully-connected node degenerates
1152 // to a matrix*vector product. LSTM cells contain a fully-connected node;
1153 // when quantized, this becomes a special type of GEMV operation where
1154 // the output is 16-bit quantized and thus needs its own special path.
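//
// Sketch of the per-output-row computation, added for clarity; the NEON code
// below is a vectorized and unrolled form of this (names are informal):
//   acc[i]   = sum_j (weights[i][j] - weights_zero_point) * (input[j] - 128)
//   out16[i] = saturate_int16(rescale(acc[i] + bias[i],
//                                     accum_multiplier, accum_shift))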
1155 inline void GEMVForLstmCell(const RuntimeShape& input_shape,
1156 const uint8* input_data,
1157 const RuntimeShape& weights_shape,
1158 const uint8* weights_data, uint8 weights_zero_point,
1159 const RuntimeShape& bias_shape,
1160 const int32* bias_data, int32 accum_multiplier,
1161 int accum_shift, const RuntimeShape& output_shape,
1162 int16* output_data) {
1163 ruy::profiler::ScopeLabel label("GEMVForLstmCell");
1164 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1165 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1166 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1167 const int output_dim_count = output_shape.DimensionsCount();
1168 const int weights_dim_count = weights_shape.DimensionsCount();
1169 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1170 const int input_size = FlatSizeSkipDim(input_shape, 0);
1171 const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
1172 output_shape, output_dim_count - 1);
1173 // This special fast path for quantized LSTM cells does not try to support
1174 // odd sizes that we have not encountered in any LSTM cell; supporting them
1175 // would require special code that would go untested until some LSTM cell
1176 // exercised it. We simply guard our assumptions about size evenness with
1177 // the following assertions.
1178 TFLITE_DCHECK(!(output_size % 4));
1179 TFLITE_DCHECK(!(input_size % 8));
1180 const int32* bias_ptr = bias_data;
1181 int16* output_ptr = output_data;
1182 for (int out = 0; out < output_size; out += 4) {
1183 int32x4_t acc_0 = vdupq_n_s32(0);
1184 int32x4_t acc_1 = vdupq_n_s32(0);
1185 int32x4_t acc_2 = vdupq_n_s32(0);
1186 int32x4_t acc_3 = vdupq_n_s32(0);
1187 const int16x8_t input_offset_vec = vdupq_n_s16(-128);
1188 const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
1189 int in = 0;
1190 // Handle 16 levels of depth at a time.
1191 for (; in <= input_size - 16; in += 16) {
1192 const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
1193 const uint8* weights_ptr = weights_data + in + out * input_size;
1194 uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
1195 uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
1196 uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
1197 uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
1198 int16x8_t input_val_0, input_val_1;
1199 const uint8x8_t low = vget_low_u8(input_val_u8);
1200 const uint8x8_t high = vget_high_u8(input_val_u8);
1201 input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
1202 input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
1203 input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1204 input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1205 int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
1206 weights_val_3_0;
1207 int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
1208 weights_val_3_1;
1209 weights_val_0_0 = vaddq_s16(
1210 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
1211 weights_offset_vec);
1212 weights_val_0_1 = vaddq_s16(
1213 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
1214 weights_offset_vec);
1215 weights_val_1_0 = vaddq_s16(
1216 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
1217 weights_offset_vec);
1218 weights_val_1_1 = vaddq_s16(
1219 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
1220 weights_offset_vec);
1221 weights_val_2_0 = vaddq_s16(
1222 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
1223 weights_offset_vec);
1224 weights_val_2_1 = vaddq_s16(
1225 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
1226 weights_offset_vec);
1227 weights_val_3_0 = vaddq_s16(
1228 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
1229 weights_offset_vec);
1230 weights_val_3_1 = vaddq_s16(
1231 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
1232 weights_offset_vec);
1233 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
1234 vget_low_s16(input_val_0));
1235 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
1236 vget_low_s16(input_val_0));
1237 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
1238 vget_low_s16(input_val_0));
1239 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
1240 vget_low_s16(input_val_0));
1241 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
1242 vget_high_s16(input_val_0));
1243 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
1244 vget_high_s16(input_val_0));
1245 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
1246 vget_high_s16(input_val_0));
1247 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
1248 vget_high_s16(input_val_0));
1249 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
1250 vget_low_s16(input_val_1));
1251 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
1252 vget_low_s16(input_val_1));
1253 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
1254 vget_low_s16(input_val_1));
1255 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
1256 vget_low_s16(input_val_1));
1257 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
1258 vget_high_s16(input_val_1));
1259 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
1260 vget_high_s16(input_val_1));
1261 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
1262 vget_high_s16(input_val_1));
1263 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
1264 vget_high_s16(input_val_1));
1265 }
1266 // Handle 8 levels of depth at a time.
1267 for (; in < input_size; in += 8) {
1268 const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
1269 const uint8* weights_ptr = weights_data + in + out * input_size;
1270 uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
1271 uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
1272 uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
1273 uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
1274 int16x8_t input_val;
1275 input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
1276 input_val = vaddq_s16(input_val, input_offset_vec);
1277 int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
1278 weights_val_0 =
1279 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
1280 weights_offset_vec);
1281 weights_val_1 =
1282 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
1283 weights_offset_vec);
1284 weights_val_2 =
1285 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
1286 weights_offset_vec);
1287 weights_val_3 =
1288 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
1289 weights_offset_vec);
1290 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
1291 vget_low_s16(input_val));
1292 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
1293 vget_low_s16(input_val));
1294 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
1295 vget_low_s16(input_val));
1296 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
1297 vget_low_s16(input_val));
1298 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
1299 vget_high_s16(input_val));
1300 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
1301 vget_high_s16(input_val));
1302 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
1303 vget_high_s16(input_val));
1304 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
1305 vget_high_s16(input_val));
1306 }
1307 // Horizontally reduce accumulators
1308 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1309 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1310 pairwise_reduced_acc_0 =
1311 vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
1312 pairwise_reduced_acc_1 =
1313 vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
1314 pairwise_reduced_acc_2 =
1315 vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
1316 pairwise_reduced_acc_3 =
1317 vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
1318 const int32x2_t reduced_lo =
1319 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1320 const int32x2_t reduced_hi =
1321 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1322 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1323 // Add bias values.
1324 int32x4_t bias_vec = vld1q_s32(bias_ptr);
1325 bias_ptr += 4;
1326 reduced = vaddq_s32(reduced, bias_vec);
1327 int left_shift = accum_shift > 0 ? accum_shift : 0;
1328 int right_shift = accum_shift > 0 ? 0 : -accum_shift;
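// Illustrative example: accum_shift == 2 gives left_shift == 2 and
// right_shift == 0, while accum_shift == -3 gives left_shift == 0 and
// right_shift == 3. Positive shifts are applied before the fixed-point
// multiply below; negative shifts become a rounding right shift after it.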
1329 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1330 // Multiply by the fixed-point multiplier.
1331 reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
1332 // Rounding-shift-right.
1333 using gemmlowp::RoundingDivideByPOT;
1334 reduced = RoundingDivideByPOT(reduced, right_shift);
1335 // Narrow values down to 16 bit signed.
1336 const int16x4_t res16 = vqmovn_s32(reduced);
1337 vst1_s16(output_ptr, res16);
1338 output_ptr += 4;
1339 }
1340 }
1341 #endif  // GEMMLOWP_NEON
1342
1343 #ifdef GEMMLOWP_NEON
1344 inline void GEMVForLstmCellWithSymmetricRange(
1345 const RuntimeShape& input_shape, const uint8* input_data,
1346 const RuntimeShape& weights_shape, const uint8* weights_data,
1347 const RuntimeShape& bias_shape, const int32* bias_data,
1348 int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
1349 int16* output_data) {
1350 ruy::profiler::ScopeLabel label("GEMVForLstmCellWithSymmetricRange");
1351 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1352 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1353 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1354 const int output_dim_count = output_shape.DimensionsCount();
1355 const int weights_dim_count = weights_shape.DimensionsCount();
1356 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1357 const int input_size = FlatSizeSkipDim(input_shape, 0);
1358 const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
1359 output_shape, output_dim_count - 1);
1360 // This special fast path for quantized LSTM cells does not try to support
1361 // odd sizes that we have not encountered in any LSTM cell; supporting them
1362 // would require special code that would go untested until some LSTM cell
1363 // exercised it. We simply guard our assumptions about size evenness with
1364 // the following assertions.
1365 TFLITE_DCHECK(!(output_size % 4));
1366 TFLITE_DCHECK(!(input_size % 64));
1367 const int32* bias_ptr = bias_data;
1368 int16* output_ptr = output_data;
1369 const uint8x16_t signbit = vdupq_n_u8(0x80);
1370 for (int in = 0; in < input_size; in += 32) {
1371 optimized_ops_preload_l1_keep(input_data + in);
1372 }
1373 const int left_shift = accum_shift > 0 ? accum_shift : 0;
1374 const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
1375 for (int out = 0; out < output_size; out += 4) {
1376 // Load the bias values
1377 int32x4_t bias_vec = vld1q_s32(bias_ptr);
1378 bias_ptr += 4;
1379
1380 // Clear accumulators. We use 2 accumulator registers per row,
1381 // for 4 rows. row_accumRN is the N-th accumulator for row R.
1382 int32x4_t row_accum00 = vdupq_n_s32(0);
1383 int32x4_t row_accum01 = vdupq_n_s32(0);
1384 int32x4_t row_accum10 = vdupq_n_s32(0);
1385 int32x4_t row_accum11 = vdupq_n_s32(0);
1386 int32x4_t row_accum20 = vdupq_n_s32(0);
1387 int32x4_t row_accum21 = vdupq_n_s32(0);
1388 int32x4_t row_accum30 = vdupq_n_s32(0);
1389 int32x4_t row_accum31 = vdupq_n_s32(0);
1390
1391 // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
1392 const int kReadAhead = 512;
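// (Added note: 512 bytes corresponds to 8 cache lines of 64 bytes on typical
// ARM cores, so the loop below stays roughly 8 lines ahead of its loads.)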
1393 // Prefetch the first weights values.
1394 for (int k = 0; k < kReadAhead; k += 64) {
1395 optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
1396 k);
1397 optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
1398 k);
1399 optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
1400 k);
1401 optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
1402 k);
1403 }
1404 // Loop along the rows, handling 64 bytes per iteration because that's
1405 // the cache line size on most current ARM-architecture CPUs.
1406 for (int in = 0; in < input_size; in += 64) {
1407 // Prefetch some future weights values.
1408 optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
1409 in + kReadAhead);
1410 optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
1411 in + kReadAhead);
1412 optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
1413 in + kReadAhead);
1414 optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
1415 in + kReadAhead);
1416
1417 // We will use 2 local 16-bit accumulators per row, for 2 rows.
1418 // See below (*) for the rationale of processing only 2 rows at a time.
1419 // local_accumRN is the N-th local accumulator for row R.
1420 int16x8_t local_accum00;
1421 int16x8_t local_accum01;
1422 int16x8_t local_accum10;
1423 int16x8_t local_accum11;
1424
1425 // Load 64 bytes of input activations values. Convert to signed int8
1426 // by flipping the sign bit (i.e. subtracting 128, the required
1427 // zero_point value).
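// For example, a uint8 activation of 0 becomes 0 ^ 0x80 == 0x80, which read
// as int8 is -128 == 0 - 128; likewise 255 becomes 0x7F == 127 == 255 - 128.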
1428 int8x16_t input0 = vreinterpretq_s8_u8(
1429 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
1430 int8x16_t input1 = vreinterpretq_s8_u8(
1431 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
1432 int8x16_t input2 = vreinterpretq_s8_u8(
1433 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
1434 int8x16_t input3 = vreinterpretq_s8_u8(
1435 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
1436
1437 // Beginning of the core accumulation. Notice how while we have 4
1438 // rows to process, this code is taking care of only 2 rows at a time,
1439 // thus being divided into two parts looking similar ("Rows 0 and 1" and
1440 // "Rows 2 and 3").
1441 //
1442 // (*) The rationale for handling only 2 rows at a time is to avoid
1443 // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
1444 // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
1445 // we may find ourselves in a situation where rows alias each other in
1446 // the L1 cache, and moreover may also mutually alias with the input
1447 // activations. If we try to load 4 rows at a time, together with the
1448 // input activations, that may be 5 mutually-aliasing vectors, resulting
1449 // in constant mutual eviction from L1 cache. Handling 2 rows at a time
1450 // here largely mitigates these issues, and seems at least to be very
1451 // effective on Cortex-A53:
1452 // Before After
1453 // big (Cortex-A73) 2.85 ms 2.85 ms
1454 // little (Cortex-A53) 11.0 ms 5.16 ms
1455
1456 // Rows 0 and 1:
1457 // Load 64 bytes of weights values from each row. Convert to signed int8
1458 // by flipping the sign bit (i.e. subtracting 128, the required
1459 // zero_point value).
1460 int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
1461 signbit,
1462 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
1463 int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
1464 signbit,
1465 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
1466 int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
1467 signbit,
1468 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
1469 int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
1470 signbit,
1471 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
1472 int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
1473 signbit,
1474 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
1475 int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
1476 signbit,
1477 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
1478 int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
1479 signbit,
1480 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
1481 int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
1482 signbit,
1483 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
1484 // Multiply-accumulate into local 16-bit accumulators.
1485 // We can accumulate two products without overflow because weights are
1486 // required to never be -128, so each product is at most 127^2 in absolute
1487 // value.
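// Concretely (added for clarity): even a +/-127 weight against a -128 input
// gives a product of magnitude at most 127 * 128 == 16256, so the sum of two
// such products, at most 32512, still fits in the int16 range [-32768, 32767].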
1488 local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
1489 local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
1490 local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
1491 local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
1492 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
1493 vget_high_s8(input0));
1494 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
1495 vget_high_s8(input1));
1496 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
1497 vget_high_s8(input0));
1498 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
1499 vget_high_s8(input1));
1500 // Pairwise add and accumulate into 32-bit accumulators
1501 row_accum00 = vpadalq_s16(row_accum00, local_accum00);
1502 row_accum01 = vpadalq_s16(row_accum01, local_accum01);
1503 row_accum10 = vpadalq_s16(row_accum10, local_accum10);
1504 row_accum11 = vpadalq_s16(row_accum11, local_accum11);
1505 // Multiply-accumulate into local 16-bit accumulators.
1506 // We can accumulate two products without overflow because weights are
1507 // required to never be -128, so each product is at most 127^2 in absolute
1508 // value.
1509 local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
1510 local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
1511 local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
1512 local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
1513 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
1514 vget_high_s8(input2));
1515 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
1516 vget_high_s8(input3));
1517 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
1518 vget_high_s8(input2));
1519 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
1520 vget_high_s8(input3));
1521 // Pairwise add and accumulate into 32-bit accumulators
1522 row_accum00 = vpadalq_s16(row_accum00, local_accum00);
1523 row_accum01 = vpadalq_s16(row_accum01, local_accum01);
1524 row_accum10 = vpadalq_s16(row_accum10, local_accum10);
1525 row_accum11 = vpadalq_s16(row_accum11, local_accum11);
1526
1527 // Rows 2 and 3:
1528 // Load 64 bytes of weights values from each row. Convert to signed int8
1529 // by flipping the sign bit (i.e. subtracting 128, the required
1530 // zero_point value).
1531 weights00 = vreinterpretq_s8_u8(veorq_u8(
1532 signbit,
1533 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
1534 weights01 = vreinterpretq_s8_u8(veorq_u8(
1535 signbit,
1536 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
1537 weights02 = vreinterpretq_s8_u8(veorq_u8(
1538 signbit,
1539 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
1540 weights03 = vreinterpretq_s8_u8(veorq_u8(
1541 signbit,
1542 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
1543 weights10 = vreinterpretq_s8_u8(veorq_u8(
1544 signbit,
1545 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
1546 weights11 = vreinterpretq_s8_u8(veorq_u8(
1547 signbit,
1548 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
1549 weights12 = vreinterpretq_s8_u8(veorq_u8(
1550 signbit,
1551 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
1552 weights13 = vreinterpretq_s8_u8(veorq_u8(
1553 signbit,
1554 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
1555 // Multiply-accumulate into local 16-bit accumulators.
1556 // We can accumulate two products without overflow because weights are
1557 // required to never be -128, so each product is at most 127^2 in absolute
1558 // value.
1559 local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
1560 local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
1561 local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
1562 local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
1563 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
1564 vget_high_s8(input0));
1565 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
1566 vget_high_s8(input1));
1567 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
1568 vget_high_s8(input0));
1569 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
1570 vget_high_s8(input1));
1571 // Pairwise add and accumulate into 32-bit accumulators
1572 row_accum20 = vpadalq_s16(row_accum20, local_accum00);
1573 row_accum21 = vpadalq_s16(row_accum21, local_accum01);
1574 row_accum30 = vpadalq_s16(row_accum30, local_accum10);
1575 row_accum31 = vpadalq_s16(row_accum31, local_accum11);
1576 // Multiply-accumulate into local 16-bit accumulators.
1577 // We can accumulate two products without overflow because weights are
1578 // required to never be -128, so each product is at most 127^2 in absolute
1579 // value.
1580 local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
1581 local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
1582 local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
1583 local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
1584 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
1585 vget_high_s8(input2));
1586 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
1587 vget_high_s8(input3));
1588 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
1589 vget_high_s8(input2));
1590 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
1591 vget_high_s8(input3));
1592 // Pairwise add and accumulate into 32-bit accumulators
1593 row_accum20 = vpadalq_s16(row_accum20, local_accum00);
1594 row_accum21 = vpadalq_s16(row_accum21, local_accum01);
1595 row_accum30 = vpadalq_s16(row_accum30, local_accum10);
1596 row_accum31 = vpadalq_s16(row_accum31, local_accum11);
1597 }
1598
1599 row_accum00 = vaddq_s32(row_accum00, row_accum01);
1600 row_accum10 = vaddq_s32(row_accum10, row_accum11);
1601 row_accum20 = vaddq_s32(row_accum20, row_accum21);
1602 row_accum30 = vaddq_s32(row_accum30, row_accum31);
1603 // Horizontally reduce accumulators
1604 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1605 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1606 pairwise_reduced_acc_0 =
1607 vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
1608 pairwise_reduced_acc_1 =
1609 vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
1610 pairwise_reduced_acc_2 =
1611 vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
1612 pairwise_reduced_acc_3 =
1613 vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
1614 const int32x2_t reduced_lo =
1615 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1616 const int32x2_t reduced_hi =
1617 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1618 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1619 // Add bias values.
1620 reduced = vaddq_s32(reduced, bias_vec);
1621 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1622 // Multiply by the fixed-point multiplier.
1623 reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
1624 // Rounding-shift-right.
1625 using gemmlowp::RoundingDivideByPOT;
1626 reduced = RoundingDivideByPOT(reduced, right_shift);
1627 // Narrow values down to 16 bit signed.
1628 const int16x4_t res16 = vqmovn_s32(reduced);
1629 vst1_s16(output_ptr, res16);
1630 output_ptr += 4;
1631 }
1632 }
1633 #endif  // GEMMLOWP_NEON
1634
1635 inline void FullyConnected(
1636 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1637 const uint8* input_data, const RuntimeShape& filter_shape,
1638 const uint8* filter_data, const RuntimeShape& bias_shape,
1639 const int32* bias_data_int32, const RuntimeShape& output_shape,
1640 int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1641 ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
1642 const int32 input_offset = params.input_offset;
1643 const int32 filter_offset = params.weights_offset;
1644 const int32 output_offset = params.output_offset;
1645 const int32 output_multiplier = params.output_multiplier;
1646 const int output_shift = params.output_shift;
1647 const int32 output_activation_min = params.quantized_activation_min;
1648 const int32 output_activation_max = params.quantized_activation_max;
1649 // This is a copy of the reference implementation. We do not currently have a
1650 // properly optimized version.
1651 (void)gemmlowp_context; // only used in properly optimized code.
1652 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1653 TFLITE_DCHECK_EQ(output_offset, 0);
1654 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1655 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1656
1657 // TODO(b/62193649): This really should be:
1658 // const int batches = ArraySize(output_dims, 1);
1659 // but the current --variable_batch hack consists in overwriting the 3rd
1660 // dimension with the runtime batch size, as we don't keep track for each
1661 // array of which dimension is the batch dimension in it.
1662 const int output_dim_count = output_shape.DimensionsCount();
1663 const int filter_dim_count = filter_shape.DimensionsCount();
1664 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1665 const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
1666 output_shape, output_dim_count - 1);
1667 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
1668
1669 // Implementation of the fully connected node suited to the inside of an LSTM
1670 // cell. The operands are 8-bit integers, the accumulators are internally
1671 // 32-bit integers, and the output is 16-bit fixed-point with 3 integer bits, so
1672 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
1673 // is explained in the function comment above.
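// (Illustrative: in this Q3.12 output format the int16 value 4096 represents
// 1.0, and the saturation bounds +/-32767 correspond to just under +/-8.0.)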
1674 #ifdef GEMMLOWP_NEON
1675 if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
1676 output_activation_max == 32767) {
1677 if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
1678 GEMVForLstmCellWithSymmetricRange(
1679 input_shape, input_data, filter_shape, filter_data, bias_shape,
1680 bias_data_int32, output_multiplier, output_shift, output_shape,
1681 output_data);
1682 return;
1683 }
1684 if (!(output_depth % 4) && !(accum_depth % 8)) {
1685 GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
1686 filter_offset, bias_shape, bias_data_int32,
1687 output_multiplier, output_shift, output_shape,
1688 output_data);
1689 return;
1690 }
1691 }
1692 #endif  // GEMMLOWP_NEON
1693 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
1694 filter_data, output_depth, accum_depth);
1695 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1696 input_data, accum_depth, batches);
1697 gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
1698 output_data, output_depth, batches);
1699 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1700 ColVectorMap;
1701 ColVectorMap bias_vector(bias_data_int32, output_depth);
1702 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1703 bias_addition_stage.bias_vector = bias_vector;
1704 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
1705 scale_stage.result_offset_after_shift = 0;
1706 scale_stage.result_fixedpoint_multiplier = output_multiplier;
1707 // Note that this shift is negated wrt ordinary FC.
1708 scale_stage.result_exponent = output_shift;
1709 gemmlowp::OutputStageClamp clamp_stage;
1710 clamp_stage.min = output_activation_min;
1711 clamp_stage.max = output_activation_max;
1712 gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1713 auto output_pipeline =
1714 std::make_tuple(bias_addition_stage, scale_stage, clamp_stage,
1715 saturating_cast_int16_stage);
1716 gemmlowp::GemmWithOutputPipeline<uint8, int16,
1717 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1718 gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
1719 filter_offset, input_offset, output_pipeline);
1720 }
1721
1722 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
1723 int32 input_offset, const uint8* filter_data,
1724 const Dims<4>& filter_dims, int32 filter_offset,
1725 const int32* bias_data, const Dims<4>& bias_dims,
1726 int32 output_offset, int32 output_multiplier,
1727 int output_shift, int32 output_activation_min,
1728 int32 output_activation_max, uint8* output_data,
1729 const Dims<4>& output_dims,
1730 gemmlowp::GemmContext* gemmlowp_context) {
1731 tflite::FullyConnectedParams op_params;
1732 op_params.input_offset = input_offset;
1733 op_params.weights_offset = filter_offset;
1734 op_params.output_offset = output_offset;
1735 op_params.output_multiplier = output_multiplier;
1736 // Legacy ops used mixed left and right shifts. Now all shifts are positive-means-left.
1737 op_params.output_shift = kReverseShift * output_shift;
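// For example (illustrative): a legacy right-shift amount of 3 is stored here
// as output_shift == -3, since kReverseShift is -1.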
1738 op_params.quantized_activation_min = output_activation_min;
1739 op_params.quantized_activation_max = output_activation_max;
1740
1741 FullyConnected(op_params, DimsToShape(input_dims), input_data,
1742 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
1743 bias_data, DimsToShape(output_dims), output_data,
1744 gemmlowp_context);
1745 }
1746
1747 inline void FullyConnected(
1748 const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
1749 const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
1750 const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
1751 int32 output_multiplier, int output_shift, int32 output_activation_min,
1752 int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
1753 gemmlowp::GemmContext* gemmlowp_context) {
1754 tflite::FullyConnectedParams op_params;
1755 op_params.input_offset = input_offset;
1756 op_params.weights_offset = filter_offset;
1757 op_params.output_offset = output_offset;
1758 op_params.output_multiplier = output_multiplier;
1759 // Legacy ops used mixed left and right shifts. Now all shifts are positive-means-left.
1760 op_params.output_shift = kReverseShift * output_shift;
1761 op_params.quantized_activation_min = output_activation_min;
1762 op_params.quantized_activation_max = output_activation_max;
1763
1764 FullyConnected(op_params, DimsToShape(input_dims), input_data,
1765 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
1766 bias_data_int32, DimsToShape(output_dims), output_data,
1767 gemmlowp_context);
1768 }
1769
1770 // legacy, for compatibility with old checked-in code
1771 template <FusedActivationFunctionType Ac>
1772 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
1773 int32 input_offset, const uint8* filter_data,
1774 const Dims<4>& filter_dims, int32 filter_offset,
1775 const int32* bias_data, const Dims<4>& bias_dims,
1776 int32 output_offset, int32 output_multiplier,
1777 int output_shift, int32 output_activation_min,
1778 int32 output_activation_max, uint8* output_data,
1779 const Dims<4>& output_dims,
1780 gemmlowp::GemmContext* gemmlowp_context) {
1781 static_assert(Ac == FusedActivationFunctionType::kNone ||
1782 Ac == FusedActivationFunctionType::kRelu ||
1783 Ac == FusedActivationFunctionType::kRelu6 ||
1784 Ac == FusedActivationFunctionType::kRelu1,
1785 "");
1786 FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
1787 filter_offset, bias_data, bias_dims, output_offset,
1788 output_multiplier, output_shift, output_activation_min,
1789 output_activation_max, output_data, output_dims,
1790 gemmlowp_context);
1791 }
1792
1793 #ifdef USE_NEON
1794 inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl(
1795 const RuntimeShape& input_shape, const int8_t* input_data,
1796 int32 input_offset, const RuntimeShape& filter_shape,
1797 const int8_t* filter_data, int32 filter_offset,
1798 const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
1799 int32 output_multiplier, int output_shift, int32 output_activation_min,
1800 int32 output_activation_max, const RuntimeShape& output_shape,
1801 int8_t* output_data, int row_start, int row_end) {
1802 ruy::profiler::ScopeLabel label("FullyConnectedAsGEMVInt8/8bit");
1803 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1804 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1805 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1806 const int output_dim_count = output_shape.DimensionsCount();
1807 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1808 const int input_size = FlatSizeSkipDim(input_shape, 0);
1809 static constexpr int kPeel = 4;
1810 const bool shift_left = (output_shift > 0);
1811 TFLITE_DCHECK_GE(row_end - row_start, kPeel);
1812
1813 for (int out = row_start; out < row_end; out += kPeel) {
1814 out = std::min(out, row_end - kPeel);
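// Note (added for clarity): when row_end - row_start is not a multiple of
// kPeel, this clamp makes the final iteration overlap rows already handled by
// the previous one rather than run past row_end; the overlapping outputs are
// simply recomputed with identical values.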
1815 int32x4_t acc0 = vdupq_n_s32(0);
1816 int32x4_t acc1 = acc0;
1817 int32x4_t acc2 = acc0;
1818 int32x4_t acc3 = acc0;
1819 const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
1820 const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
1821 int in = 0;
1822 for (; in <= input_size - 16; in += 16) {
1823 const int8x16_t input_val_s8 = vld1q_s8(input_data + in);
1824 const int8_t* filter_ptr = filter_data + in + out * input_size;
1825 int8x16_t filter_val_s8_0 = vld1q_s8(filter_ptr);
1826 filter_ptr += input_size;
1827 int8x16_t filter_val_s8_1 = vld1q_s8(filter_ptr);
1828 filter_ptr += input_size;
1829 int8x16_t filter_val_s8_2 = vld1q_s8(filter_ptr);
1830 filter_ptr += input_size;
1831 int8x16_t filter_val_s8_3 = vld1q_s8(filter_ptr);
1832 int16x8_t input_val_0, input_val_1;
1833 int8x8_t low = vget_low_s8(input_val_s8);
1834 int8x8_t high = vget_high_s8(input_val_s8);
1835 input_val_0 = vmovl_s8(low);
1836 input_val_1 = vmovl_s8(high);
1837 input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1838 input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1839 low = vget_low_s8(filter_val_s8_0);
1840 high = vget_high_s8(filter_val_s8_0);
1841 int16x8_t filter_val_0_0 = vmovl_s8(low);
1842 int16x8_t filter_val_0_1 = vmovl_s8(high);
1843 filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
1844 filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
1845 low = vget_low_s8(filter_val_s8_1);
1846 high = vget_high_s8(filter_val_s8_1);
1847 int16x8_t filter_val_1_0 = vmovl_s8(low);
1848 int16x8_t filter_val_1_1 = vmovl_s8(high);
1849 filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
1850 filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
1851 low = vget_low_s8(filter_val_s8_2);
1852 high = vget_high_s8(filter_val_s8_2);
1853 int16x8_t filter_val_2_0 = vmovl_s8(low);
1854 int16x8_t filter_val_2_1 = vmovl_s8(high);
1855 filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
1856 filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
1857 low = vget_low_s8(filter_val_s8_3);
1858 high = vget_high_s8(filter_val_s8_3);
1859 int16x8_t filter_val_3_0 = vmovl_s8(low);
1860 int16x8_t filter_val_3_1 = vmovl_s8(high);
1861 filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
1862 filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
1863 acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
1864 vget_low_s16(input_val_0));
1865 acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
1866 vget_low_s16(input_val_0));
1867 acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
1868 vget_low_s16(input_val_0));
1869 acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
1870 vget_low_s16(input_val_0));
1871 acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
1872 vget_low_s16(input_val_1));
1873 acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
1874 vget_low_s16(input_val_1));
1875 acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
1876 vget_low_s16(input_val_1));
1877 acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
1878 vget_low_s16(input_val_1));
1879 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
1880 vget_high_s16(input_val_0));
1881 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
1882 vget_high_s16(input_val_0));
1883 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
1884 vget_high_s16(input_val_0));
1885 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
1886 vget_high_s16(input_val_0));
1887 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
1888 vget_high_s16(input_val_1));
1889 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
1890 vget_high_s16(input_val_1));
1891 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
1892 vget_high_s16(input_val_1));
1893 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
1894 vget_high_s16(input_val_1));
1895 }
1896 for (; in <= input_size - 8; in += 8) {
1897 const int8x8_t input_val_s8 = vld1_s8(input_data + in);
1898 const int8_t* filter_ptr = filter_data + in + out * input_size;
1899 int8x8_t filter_val_s8_0 = vld1_s8(filter_ptr);
1900 filter_ptr += input_size;
1901 int8x8_t filter_val_s8_1 = vld1_s8(filter_ptr);
1902 filter_ptr += input_size;
1903 int8x8_t filter_val_s8_2 = vld1_s8(filter_ptr);
1904 filter_ptr += input_size;
1905 int8x8_t filter_val_s8_3 = vld1_s8(filter_ptr);
1906 int16x8_t input_val = vmovl_s8(input_val_s8);
1907 input_val = vaddq_s16(input_val, input_offset_vec);
1908 int16x8_t filter_val_0 = vmovl_s8(filter_val_s8_0);
1909 filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
1910 int16x8_t filter_val_1 = vmovl_s8(filter_val_s8_1);
1911 filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
1912 int16x8_t filter_val_2 = vmovl_s8(filter_val_s8_2);
1913 filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
1914 int16x8_t filter_val_3 = vmovl_s8(filter_val_s8_3);
1915 filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
1916 acc0 =
1917 vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
1918 acc1 =
1919 vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
1920 acc2 =
1921 vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
1922 acc3 =
1923 vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
1924 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
1925 vget_high_s16(input_val));
1926 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
1927 vget_high_s16(input_val));
1928 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
1929 vget_high_s16(input_val));
1930 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
1931 vget_high_s16(input_val));
1932 }
1933 if (in < input_size) {
1934 int32 buf[16];
1935 vst1q_s32(buf + 0, acc0);
1936 vst1q_s32(buf + 4, acc1);
1937 vst1q_s32(buf + 8, acc2);
1938 vst1q_s32(buf + 12, acc3);
1939 for (; in < input_size; in++) {
1940 int lane = (in + 8 - input_size) % 4;
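// (Added note: which of the four lanes of a row's accumulator receives each
// leftover product does not matter, since all four lanes are summed by the
// horizontal reduction below; the modulo merely spreads the scalar leftovers
// across lanes.)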
1941 const int32 input_val = input_data[in] + input_offset;
1942 for (int k = 0; k < kPeel; k++) {
1943 int32 filter_val =
1944 filter_data[in + (out + k) * input_size] + filter_offset;
1945 buf[lane + 4 * k] += filter_val * input_val;
1946 }
1947 }
1948 acc0 = vld1q_s32(buf + 0);
1949 acc1 = vld1q_s32(buf + 4);
1950 acc2 = vld1q_s32(buf + 8);
1951 acc3 = vld1q_s32(buf + 12);
1952 }
1953
1954 // Horizontally reduce accumulators
1955 int32x2_t pairwise_reduced_acc_0 =
1956 vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
1957 int32x2_t pairwise_reduced_acc_1 =
1958 vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
1959 int32x2_t pairwise_reduced_acc_2 =
1960 vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
1961 int32x2_t pairwise_reduced_acc_3 =
1962 vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
1963 const int32x2_t reduced_lo =
1964 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1965 const int32x2_t reduced_hi =
1966 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1967 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1968 // Add bias values.
1969 int32x4_t bias_vec = vld1q_s32(bias_data + out);
1970 reduced = vaddq_s32(reduced, bias_vec);
1971 if (shift_left) {
1972 const int32 multiplier_power_of_two = 1 << output_shift;
1973 reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
1974 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1975 } else {
1976 // Multiply by the fixed-point multiplier.
1977 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1978 // Rounding-shift-right.
1979 using gemmlowp::RoundingDivideByPOT;
1980 reduced = RoundingDivideByPOT(reduced, -output_shift);
1981 }
1982 // Add the output offset.
1983 const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1984 reduced = vaddq_s32(reduced, output_offset_vec);
1985 // Narrow values down to 16 bit signed.
1986 const int16x4_t res16 = vqmovn_s32(reduced);
1987 // Narrow values down to 8 bit signed, saturating.
1988 int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
1989 // Apply the clamping from the activation function
1990 res8 = vmax_s8(res8, vdup_n_s8(output_activation_min));
1991 res8 = vmin_s8(res8, vdup_n_s8(output_activation_max));
1992 // Store results to destination.
1993 vst1_lane_s8(output_data + out + 0, res8, 0);
1994 vst1_lane_s8(output_data + out + 1, res8, 1);
1995 vst1_lane_s8(output_data + out + 2, res8, 2);
1996 vst1_lane_s8(output_data + out + 3, res8, 3);
1997 }
1998 }
1999
2000 struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
2001 LegacyInt8FullyConnectedAsGEMVWorkerTask(
2002 const RuntimeShape& input_shape, const int8_t* input_data,
2003 int32 input_offset, const RuntimeShape& filter_shape,
2004 const int8_t* filter_data, int32 filter_offset,
2005 const RuntimeShape& bias_shape, const int32* bias_data,
2006 int32 output_offset, int32 output_multiplier, int output_shift,
2007 int32 output_activation_min, int32 output_activation_max,
2008 const RuntimeShape& output_shape, int8_t* output_data, int row_start,
2009 int row_end)
2010 : input_shape_(input_shape),
2011 input_data_(input_data),
2012 input_offset_(input_offset),
2013 filter_shape_(filter_shape),
2014 filter_data_(filter_data),
2015 filter_offset_(filter_offset),
2016 bias_shape_(bias_shape),
2017 bias_data_(bias_data),
2018 output_offset_(output_offset),
2019 output_multiplier_(output_multiplier),
2020 output_shift_(output_shift),
2021 output_activation_min_(output_activation_min),
2022 output_activation_max_(output_activation_max),
2023 output_shape_(output_shape),
2024 output_data_(output_data),
2025 row_start_(row_start),
2026 row_end_(row_end) {}
2027
2028 void Run() override {
2029 LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2030 input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
2031 filter_offset_, bias_shape_, bias_data_, output_offset_,
2032 output_multiplier_, output_shift_, output_activation_min_,
2033 output_activation_max_, output_shape_, output_data_, row_start_,
2034 row_end_);
2035 }
2036
2037 const RuntimeShape& input_shape_;
2038 const int8_t* input_data_;
2039 int32 input_offset_;
2040 const RuntimeShape& filter_shape_;
2041 const int8_t* filter_data_;
2042 int32 filter_offset_;
2043 const RuntimeShape& bias_shape_;
2044 const int32* bias_data_;
2045 int32 output_offset_;
2046 int32 output_multiplier_;
2047 int output_shift_;
2048 int32 output_activation_min_;
2049 int32 output_activation_max_;
2050 const RuntimeShape& output_shape_;
2051 int8_t* output_data_;
2052 int row_start_;
2053 int row_end_;
2054 };
2055
2056 inline void LegacyInt8FullyConnectedAsGEMV(
2057 const RuntimeShape& input_shape, const int8_t* input_data,
2058 int32 input_offset, const RuntimeShape& filter_shape,
2059 const int8_t* filter_data, int32 filter_offset,
2060 const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
2061 int32 output_multiplier, int output_shift, int32 output_activation_min,
2062 int32 output_activation_max, const RuntimeShape& output_shape,
2063 int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) {
2064 const int output_dim_count = output_shape.DimensionsCount();
2065 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2066 const int output_rows = output_shape.Dims(output_dim_count - 1);
2067 const int input_size = FlatSizeSkipDim(input_shape, 0);
2068 static constexpr int kKernelRows = 4;
2069 const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2070 gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
2071 if (thread_count == 1) {
2072 // Single-thread case: do the computation on the current thread, don't
2073 // use a threadpool
2074 LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2075 input_shape, input_data, input_offset, filter_shape, filter_data,
2076 filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
2077 output_shift, output_activation_min, output_activation_max,
2078 output_shape, output_data, 0, output_rows);
2079 return;
2080 }
2081
2082 // Multi-threaded case: use the gemmlowp context's threadpool.
2083 TFLITE_DCHECK_GT(thread_count, 1);
2084 std::vector<LegacyInt8FullyConnectedAsGEMVWorkerTask> tasks;
2085 // TODO(b/131746020) don't create new heap allocations every time.
2086 // At least we make it a single heap allocation by using reserve().
2087 tasks.reserve(thread_count);
2088 const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2089 gemmlowp::CeilQuotient(output_rows, thread_count));
2090 int row_start = 0;
2091 for (int i = 0; i < thread_count; ++i) {
2092 int row_end = std::min(output_rows, row_start + kRowsPerWorker);
2093 tasks.emplace_back(input_shape, input_data, input_offset, filter_shape,
2094 filter_data, filter_offset, bias_shape, bias_data,
2095 output_offset, output_multiplier, output_shift,
2096 output_activation_min, output_activation_max,
2097 output_shape, output_data, row_start, row_end);
2098 row_start = row_end;
2099 }
2100 TFLITE_DCHECK_EQ(row_start, output_rows);
2101 gemmlowp_context->workers_pool()->Execute(tasks.size(), tasks.data());
2102 }
2103 #endif // USE_NEON
2104
2105 inline void FullyConnected(
2106 const FullyConnectedParams& params, const RuntimeShape& input_shape,
2107 const int8* input_data, const RuntimeShape& filter_shape,
2108 const int8* filter_data, const RuntimeShape& bias_shape,
2109 const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
2110 gemmlowp::GemmContext* gemmlowp_context) {
2111 ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");
2112
2113 #ifdef USE_NEON
2114 const int32 input_offset = params.input_offset;
2115 const int32 filter_offset = params.weights_offset;
2116 const int32 output_offset = params.output_offset;
2117 const int32 output_multiplier = params.output_multiplier;
2118 const int output_shift = params.output_shift;
2119 const int32 output_activation_min = params.quantized_activation_min;
2120 const int32 output_activation_max = params.quantized_activation_max;
2121 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
2122 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2123 // TODO(b/62193649): This really should be:
2124 // const int batches = ArraySize(output_dims, 1);
2125 // but the current --variable_batch hack consists in overwriting the 3rd
2126 // dimension with the runtime batch size, as we don't keep track for each
2127 // array of which dimension is the batch dimension in it.
2128 const int output_dim_count = output_shape.DimensionsCount();
2129 const int filter_dim_count = filter_shape.DimensionsCount();
2130 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2131 if (batches == 1) {
2132 const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
2133 output_shape, output_dim_count - 1);
2134 if (output_size >= 4) {
2135 return LegacyInt8FullyConnectedAsGEMV(
2136 input_shape, input_data, input_offset, filter_shape, filter_data,
2137 filter_offset, bias_shape, bias_data, output_offset,
2138 output_multiplier, output_shift, output_activation_min,
2139 output_activation_max, output_shape, output_data, gemmlowp_context);
2140 }
2141 }
2142 #endif // USE_NEON
2143
2144 #ifdef GEMMLOWP_NEON
2145 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
2146 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
2147 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
2148 const int output_rows = output_shape.Dims(output_dim_count - 1);
2149 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2150 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2151
2152 gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2153 filter_data, output_rows, filter_cols, filter_cols);
2154 gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::ColMajor> input_matrix(
2155 input_data, filter_cols, batches, filter_cols);
2156 gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
2157 output_data, output_rows, batches, output_rows);
2158 const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp(
2159 bias_data, output_rows, output_offset, output_multiplier, output_shift,
2160 output_activation_min, output_activation_max);
2161
2162 gemmlowp::GemmWithOutputPipeline<
2163 int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
2164 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2165 filter_offset, input_offset, output_pipeline);
2166 return;
2167 #endif // GEMMLOWP_NEON
2168
2169 // If both the GEMMLOWP_NEON and USE_NEON paths are skipped above, fall back
2170 // to the reference implementation.
2171 reference_integer_ops::FullyConnected(params, input_shape, input_data,
2172 filter_shape, filter_data, bias_shape,
2173 bias_data, output_shape, output_data);
2174 }
2175
2176 struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task {
2177 LegacyShuffledFullyConnectedWorkerTask(const uint8* input_data,
2178 const int8* shuffled_weights_data,
2179 int batches, int output_depth,
2180 int output_stride, int accum_depth,
2181 const int32* bias_data,
2182 int32 output_multiplier,
2183 int output_shift, int16* output_data)
2184 : input_data_(input_data),
2185 shuffled_weights_data_(shuffled_weights_data),
2186 batches_(batches),
2187 output_depth_(output_depth),
2188 output_stride_(output_stride),
2189 accum_depth_(accum_depth),
2190 bias_data_(bias_data),
2191 output_multiplier_(output_multiplier),
2192 output_shift_(output_shift),
2193 output_data_(output_data) {}
2194
2195 void Run() override {
2196 ShuffledFullyConnectedWorkerImpl(
2197 input_data_, shuffled_weights_data_, batches_, output_depth_,
2198 output_stride_, accum_depth_, bias_data_, output_multiplier_,
2199 output_shift_, output_data_);
2200 }
2201
2202 const uint8* input_data_;
2203 const int8* shuffled_weights_data_;
2204 int batches_;
2205 int output_depth_;
2206 int output_stride_;
2207 int accum_depth_;
2208 const int32* bias_data_;
2209 int32 output_multiplier_;
2210 int output_shift_;
2211 int16* output_data_;
2212 };
2213
2214 inline void ShuffledFullyConnected(
2215 const FullyConnectedParams& params, const RuntimeShape& input_shape,
2216 const uint8* input_data, const RuntimeShape& weights_shape,
2217 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
2218 const int32* bias_data, const RuntimeShape& output_shape,
2219 int16* output_data, uint8* shuffled_input_workspace_data,
2220 gemmlowp::GemmContext* gemmlowp_context) {
2221 ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
2222 const int32 output_multiplier = params.output_multiplier;
2223 const int output_shift = params.output_shift;
2224 const int32 output_activation_min = params.quantized_activation_min;
2225 const int32 output_activation_max = params.quantized_activation_max;
2226 (void)gemmlowp_context; // only used in optimized code.
2227 TFLITE_DCHECK_EQ(output_activation_min, -32768);
2228 TFLITE_DCHECK_EQ(output_activation_max, 32767);
2229 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
2230 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2231 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2232 // TODO(b/62193649): This really should be:
2233 // const int batches = ArraySize(output_dims, 1);
2234 // but the current --variable_batch hack consists in overwriting the 3rd
2235 // dimension with the runtime batch size, as we don't keep track for each
2236 // array of which dimension is the batch dimension in it.
2237 const int output_dim_count = output_shape.DimensionsCount();
2238 const int weights_dim_count = weights_shape.DimensionsCount();
2239 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2240 const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
2241 output_shape, output_dim_count - 1);
2242 const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
2243 TFLITE_DCHECK((accum_depth % 16) == 0);
2244 TFLITE_DCHECK((output_depth % 4) == 0);
2245 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
2246 // so that just reinterpreting them as int8 values is equivalent to
2247 // subtracting 128 from them, thus implementing for free the subtraction of
2248 // the zero_point value 128.
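// For example, the uint8 weight 200 (0xC8) is stored as 0xC8 ^ 0x80 = 0x48;
// read back as an int8 that is +72, which is exactly 200 - 128.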
2249 const int8* int8_shuffled_weights_data =
2250 reinterpret_cast<const int8*>(shuffled_weights_data);
2251
2252 // Shuffle and XOR the input activations into the workspace buffer.
2253 if (batches == 1) {
2254 #ifdef USE_NEON
2255 const uint8x16_t signbit = vdupq_n_u8(0x80);
2256 for (int i = 0; i < accum_depth; i += 16) {
2257 uint8x16_t val = vld1q_u8(input_data + i);
2258 val = veorq_u8(val, signbit);
2259 vst1q_u8(shuffled_input_workspace_data + i, val);
2260 }
2261 #else
2262 for (int i = 0; i < accum_depth; i++) {
2263 shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
2264 }
2265 #endif
2266 } else if (batches == 4) {
2267 uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
2268 int c = 0;
2269 #ifdef USE_NEON
2270 const uint8x16_t signbit = vdupq_n_u8(0x80);
2271 for (c = 0; c < accum_depth; c += 16) {
2272 const uint8* src_data_ptr = input_data + c;
2273 uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
2274 uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
2275 uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
2276 uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
2277 val0 = veorq_u8(val0, signbit);
2278 val1 = veorq_u8(val1, signbit);
2279 val2 = veorq_u8(val2, signbit);
2280 val3 = veorq_u8(val3, signbit);
2281 vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
2282 vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
2283 vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
2284 vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
2285 shuffled_input_workspace_ptr += 64;
2286 }
2287 #else
2288 for (c = 0; c < accum_depth; c += 16) {
2289 for (int b = 0; b < 4; b++) {
2290 const uint8* src_data_ptr = input_data + b * accum_depth + c;
2291 for (int j = 0; j < 16; j++) {
2292 uint8 src_val = *src_data_ptr++;
2293 // Flip the sign bit, so that the kernel will only need to
2294 // reinterpret these uint8 values as int8, getting for free the
2295 // subtraction of the zero_point value 128.
2296 uint8 dst_val = src_val ^ 0x80;
2297 *shuffled_input_workspace_ptr++ = dst_val;
2298 }
2299 }
2300 }
2301 #endif
2302 } else {
2303 TFLITE_DCHECK(false);
2304 return;
2305 }
2306
2307 static constexpr int kKernelRows = 4;
2308 const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2309 gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth);
2310 if (thread_count == 1) {
2311 // Single-thread case: do the computation on the current thread; don't
2312 // use a threadpool.
2313 ShuffledFullyConnectedWorkerImpl(
2314 shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
2315 output_depth, output_depth, accum_depth, bias_data, output_multiplier,
2316 output_shift, output_data);
2317 return;
2318 }
2319
2320 // Multi-threaded case: use the gemmlowp context's threadpool.
2321 TFLITE_DCHECK_GT(thread_count, 1);
2322 std::vector<gemmlowp::Task*> tasks(thread_count);
2323 const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2324 gemmlowp::CeilQuotient(output_depth, thread_count));
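// Illustrative example: with output_depth = 1000 and thread_count = 3,
// CeilQuotient gives 334 and RoundUp<4> gives 336, so the workers cover rows
// [0, 336), [336, 672) and [672, 1000); every block is a multiple of
// kKernelRows because output_depth itself is a multiple of 4 (checked above).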
2325 int row_start = 0;
2326 for (int i = 0; i < thread_count; i++) {
2327 int row_end = std::min(output_depth, row_start + kRowsPerWorker);
2328 tasks[i] = new LegacyShuffledFullyConnectedWorkerTask(
2329 shuffled_input_workspace_data,
2330 int8_shuffled_weights_data + row_start * accum_depth, batches,
2331 row_end - row_start, output_depth, accum_depth, bias_data + row_start,
2332 output_multiplier, output_shift, output_data + row_start);
2333 row_start = row_end;
2334 }
2335 TFLITE_DCHECK_EQ(row_start, output_depth);
2336 gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
2337 }
2338
2339 inline void ShuffledFullyConnected(
2340 const uint8* input_data, const Dims<4>& input_dims,
2341 const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
2342 const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
2343 int output_shift, int32 output_activation_min, int32 output_activation_max,
2344 int16* output_data, const Dims<4>& output_dims,
2345 uint8* shuffled_input_workspace_data,
2346 gemmlowp::GemmContext* gemmlowp_context) {
2347 tflite::FullyConnectedParams op_params;
2348 op_params.output_multiplier = output_multiplier;
2349 // Legacy ops used mixed left and right shifts. Now positive always means left.
2350 op_params.output_shift = kReverseShift * output_shift;
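// kReverseShift is -1, so a legacy right shift of e.g. 3 is passed on as
// output_shift = -3 under the new convention.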
2351 op_params.quantized_activation_min = output_activation_min;
2352 op_params.quantized_activation_max = output_activation_max;
2353
2354 ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
2355 DimsToShape(weights_dims), shuffled_weights_data,
2356 DimsToShape(bias_dims), bias_data,
2357 DimsToShape(output_dims), output_data,
2358 shuffled_input_workspace_data, gemmlowp_context);
2359 }
2360
2361 template <typename T>
2362 inline void ExtractPatchIntoBufferColumn(
2363 const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
2364 int stride_width, int stride_height, int pad_width, int pad_height,
2365 int in_width, int in_height, int in_depth, int single_buffer_length,
2366 int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
2367 ExtractPatchIntoBufferColumn(
2368 DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
2369 stride_height, pad_width, pad_height, in_width, in_height, in_depth,
2370 single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
2371 }
2372
2373 template <typename T>
2374 void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
2375 const Dims<4>& filter_dims, int stride_width,
2376 int stride_height, int dilation_width_factor,
2377 int dilation_height_factor, int pad_width, int pad_height,
2378 const Dims<4>& output_dims, uint8 zero_byte,
2379 T* im2col_data) {
2380 tflite::ConvParams op_params;
2381 // Padding type is ignored, but still set.
2382 op_params.padding_type = PaddingType::kSame;
2383 op_params.padding_values.width = pad_width;
2384 op_params.padding_values.height = pad_height;
2385 op_params.stride_width = stride_width;
2386 op_params.stride_height = stride_height;
2387 op_params.dilation_width_factor = dilation_width_factor;
2388 op_params.dilation_height_factor = dilation_height_factor;
2389
2390 DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2391 DimsToShape(filter_dims), DimsToShape(output_dims),
2392 im2col_data);
2393 }
2394
2395 template <typename T>
2396 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
2397 int stride_height, int pad_width, int pad_height, int kheight,
2398 int kwidth, uint8 zero_byte, T* output_data,
2399 const Dims<4>& output_dims) {
2400 tflite::ConvParams op_params;
2401 // Padding type is ignored, but still set.
2402 op_params.padding_type = PaddingType::kSame;
2403 op_params.padding_values.width = pad_width;
2404 op_params.padding_values.height = pad_height;
2405 op_params.stride_width = stride_width;
2406 op_params.stride_height = stride_height;
2407 op_params.dilation_width_factor = 1;
2408 op_params.dilation_height_factor = 1;
2409
2410 Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
2411 input_data, DimsToShape(output_dims), output_data);
2412 }
2413
2414 // legacy, for compatibility with old checked-in code
2415 template <typename T>
2416 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2417 int pad_width, int pad_height, int kheight, int kwidth,
2418 uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2419 Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2420 kwidth, zero_byte, output_data, output_dims);
2421 }
2422
2423 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2424 const float* input_data, const RuntimeShape& filter_shape,
2425 const float* filter_data, const RuntimeShape& bias_shape,
2426 const float* bias_data, const RuntimeShape& output_shape,
2427 float* output_data, const RuntimeShape& im2col_shape,
2428 float* im2col_data) {
2429 const int stride_width = params.stride_width;
2430 const int stride_height = params.stride_height;
2431 const int dilation_width_factor = params.dilation_width_factor;
2432 const int dilation_height_factor = params.dilation_height_factor;
2433 const float output_activation_min = params.float_activation_min;
2434 const float output_activation_max = params.float_activation_max;
2435 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2436 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2437 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2438
2439 (void)im2col_data;
2440 (void)im2col_shape;
2441 ruy::profiler::ScopeLabel label("Conv");
2442
2443 // NB: the float 0.0f value is represented by all zero bytes.
2444 const uint8 float_zero_byte = 0x00;
2445 const float* gemm_input_data = nullptr;
2446 const RuntimeShape* gemm_input_shape = nullptr;
2447 const int filter_width = filter_shape.Dims(2);
2448 const int filter_height = filter_shape.Dims(1);
2449 const bool need_dilated_im2col =
2450 dilation_width_factor != 1 || dilation_height_factor != 1;
2451 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2452 filter_width != 1 || filter_height != 1;
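// For instance, a 1x1 filter with unit strides and no dilation can be fed to
// the GEMM below directly; any other configuration goes through the (dilated)
// im2col buffer first.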
2453 if (need_dilated_im2col) {
2454 DilatedIm2col(params, float_zero_byte, input_shape, input_data,
2455 filter_shape, output_shape, im2col_data);
2456 gemm_input_data = im2col_data;
2457 gemm_input_shape = &im2col_shape;
2458 } else if (need_im2col) {
2459 TFLITE_DCHECK(im2col_data);
2460 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
2461 input_data, im2col_shape, im2col_data);
2462 gemm_input_data = im2col_data;
2463 gemm_input_shape = &im2col_shape;
2464 } else {
2465 // TODO(aselle): We need to make sure to not send im2col if it is not
2466 // needed.
2467 TFLITE_DCHECK(!im2col_data);
2468 gemm_input_data = input_data;
2469 gemm_input_shape = &input_shape;
2470 }
2471
2472 // The following code computes matrix multiplication c = a * transpose(b)
2473 // with CBLAS, where:
2474 // * `a` is a matrix with dimensions (m, k).
2475 // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
2476 // * `c` is a matrix with dimensions (m, n).
2477 // The naming of the variables is aligned with the CBLAS specification here.
2478 const float* a = gemm_input_data;
2479 const float* b = filter_data;
2480 float* c = output_data;
2481 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
2482 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
2483 int n = output_shape.Dims(3);
2484 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
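// With an im2col buffer, m is batches * output_height * output_width (one row
// per output position), k is filter_height * filter_width * input_depth (one
// column per filter tap), and n is the number of output channels; in the
// 1x1/unit-stride case the input is used directly and k is just input_depth.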
2485
2486 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2487 // The stride of matrix a, b and c respectively.
2488 int stride_a = k;
2489 int stride_b = k;
2490 int stride_c = n;
2491
2492 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
2493 stride_a, b, stride_b, 0.0f, c, stride_c);
2494 #else
2495 // When an optimized CBLAS implementation is not available, fall back
2496 // to using Eigen.
2497 typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
2498 Matrix;
2499 typedef Eigen::Map<Matrix> MatrixRef;
2500 typedef Eigen::Map<const Matrix> ConstMatrixRef;
2501
2502 MatrixRef matrix_c(c, m, n);
2503 ConstMatrixRef matrix_a(a, m, k);
2504 ConstMatrixRef matrix_b(b, n, k);
2505
2506 // The following special casing for when a or b is a vector is required
2507 // as Eigen seems to fail to make this optimization on its own.
2508 if (n == 1) {
2509 ruy::profiler::ScopeLabel label("GEMV");
2510 matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
2511 } else if (m == 1) {
2512 ruy::profiler::ScopeLabel label("GEMV");
2513 matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
2514 } else {
2515 ruy::profiler::ScopeLabel label("GEMM");
2516 matrix_c.noalias() = matrix_a * matrix_b.transpose();
2517 }
2518
2519 #endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2520
2521 optimized_ops::AddBiasAndEvalActivationFunction(
2522 output_activation_min, output_activation_max, bias_shape, bias_data,
2523 output_shape, output_data);
2524 }
2525
2526 inline void Conv(const float* input_data, const Dims<4>& input_dims,
2527 const float* filter_data, const Dims<4>& filter_dims,
2528 const float* bias_data, const Dims<4>& bias_dims,
2529 int stride_width, int stride_height, int dilation_width_factor,
2530 int dilation_height_factor, int pad_width, int pad_height,
2531 float output_activation_min, float output_activation_max,
2532 float* output_data, const Dims<4>& output_dims,
2533 float* im2col_data, const Dims<4>& im2col_dims) {
2534 tflite::ConvParams op_params;
2535 // Padding type is ignored, but still set.
2536 op_params.padding_type = PaddingType::kSame;
2537 op_params.padding_values.width = pad_width;
2538 op_params.padding_values.height = pad_height;
2539 op_params.stride_width = stride_width;
2540 op_params.stride_height = stride_height;
2541 op_params.dilation_width_factor = dilation_width_factor;
2542 op_params.dilation_height_factor = dilation_height_factor;
2543 op_params.float_activation_min = output_activation_min;
2544 op_params.float_activation_max = output_activation_max;
2545
2546 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2547 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2548 output_data, DimsToShape(im2col_dims), im2col_data);
2549 }
2550
2551 inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
2552 const int8_t* filter_data, const Dims<4>& filter_dims,
2553 const float* bias_data, const Dims<4>& bias_dims,
2554 int stride_width, int stride_height, int pad_width,
2555 int pad_height, float* scaling_factors_ptr,
2556 float output_activation_min, float output_activation_max,
2557 int32_t* scratch_data, const Dims<4>& scratch_dims,
2558 float* output_data, const Dims<4>& output_dims,
2559 int8_t* im2col_data, const Dims<4>& im2col_dims,
2560 CpuBackendContext* context) {
2561 tflite::ConvParams op_params;
2562 // Padding type is ignored, but still set.
2563 op_params.padding_type = PaddingType::kSame;
2564 op_params.padding_values.width = pad_width;
2565 op_params.padding_values.height = pad_height;
2566 op_params.stride_width = stride_width;
2567 op_params.stride_height = stride_height;
2568 op_params.float_activation_min = output_activation_min;
2569 op_params.float_activation_max = output_activation_max;
2570
2571 HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
2572 input_data, DimsToShape(filter_dims), filter_data,
2573 DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims),
2574 scratch_data, DimsToShape(output_dims), output_data,
2575 DimsToShape(im2col_dims), im2col_data, context);
2576 }
2577
2578 template <FusedActivationFunctionType Ac>
2579 void Conv(const float* input_data, const Dims<4>& input_dims,
2580 const float* filter_data, const Dims<4>& filter_dims,
2581 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2582 int stride_height, int dilation_width_factor,
2583 int dilation_height_factor, int pad_width, int pad_height,
2584 float* output_data, const Dims<4>& output_dims, float* im2col_data,
2585 const Dims<4>& im2col_dims) {
2586 float output_activation_min, output_activation_max;
2587 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2588 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2589 stride_width, stride_height, dilation_width_factor,
2590 dilation_height_factor, pad_width, pad_height, output_activation_min,
2591 output_activation_max, output_data, output_dims, im2col_data,
2592 im2col_dims);
2593 }
2594
2595 // legacy, for compatibility with old checked-in code
2596 template <FusedActivationFunctionType Ac>
2597 void Conv(const float* input_data, const Dims<4>& input_dims,
2598 const float* filter_data, const Dims<4>& filter_dims,
2599 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2600 int stride_height, int pad_width, int pad_height, float* output_data,
2601 const Dims<4>& output_dims, float* im2col_data,
2602 const Dims<4>& im2col_dims) {
2603 float output_activation_min, output_activation_max;
2604 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2605 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2606 stride_width, stride_height, 1, 1, pad_width, pad_height,
2607 output_activation_min, output_activation_max, output_data, output_dims,
2608 im2col_data, im2col_dims);
2609 }
2610
2611 // legacy, for compatibility with old checked-in code
2612 template <FusedActivationFunctionType Ac>
2613 void Conv(const float* input_data, const Dims<4>& input_dims,
2614 const float* filter_data, const Dims<4>& filter_dims,
2615 const float* bias_data, const Dims<4>& bias_dims, int stride,
2616 int pad_width, int pad_height, float* output_data,
2617 const Dims<4>& output_dims, float* im2col_data,
2618 const Dims<4>& im2col_dims) {
2619 Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
2620 bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
2621 output_dims, im2col_data, im2col_dims);
2622 }
2623
2624 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2625 const uint8* input_data, const RuntimeShape& filter_shape,
2626 const uint8* filter_data, const RuntimeShape& bias_shape,
2627 const int32* bias_data, const RuntimeShape& output_shape,
2628 uint8* output_data, const RuntimeShape& im2col_shape,
2629 uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
2630 ruy::profiler::ScopeLabel label("Conv/8bit");
2631 const int stride_width = params.stride_width;
2632 const int stride_height = params.stride_height;
2633 const int dilation_width_factor = params.dilation_width_factor;
2634 const int dilation_height_factor = params.dilation_height_factor;
2635 const int32 input_offset = params.input_offset;
2636 const int32 filter_offset = params.weights_offset;
2637 const int32 output_offset = params.output_offset;
2638 const int32 output_multiplier = params.output_multiplier;
2639 const int output_shift = params.output_shift;
2640 const int32 output_activation_min = params.quantized_activation_min;
2641 const int32 output_activation_max = params.quantized_activation_max;
2642 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2643 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2644 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2645
2646 const uint8* gemm_input_data = nullptr;
2647 const RuntimeShape* gemm_input_shape = nullptr;
2648 const int filter_width = filter_shape.Dims(2);
2649 const int filter_height = filter_shape.Dims(1);
2650 const bool need_dilated_im2col =
2651 dilation_width_factor != 1 || dilation_height_factor != 1;
2652 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2653 filter_width != 1 || filter_height != 1;
2654 if (need_dilated_im2col) {
2655 TFLITE_DCHECK(im2col_data);
2656 const int input_zero_point = -input_offset;
2657 TFLITE_DCHECK_GE(input_zero_point, 0);
2658 TFLITE_DCHECK_LE(input_zero_point, 255);
2659 DilatedIm2col(params, input_zero_point, input_shape, input_data,
2660 filter_shape, output_shape, im2col_data);
2661 gemm_input_data = im2col_data;
2662 gemm_input_shape = &im2col_shape;
2663 } else if (need_im2col) {
2664 TFLITE_DCHECK(im2col_data);
2665 const int input_zero_point = -input_offset;
2666 TFLITE_DCHECK_GE(input_zero_point, 0);
2667 TFLITE_DCHECK_LE(input_zero_point, 255);
2668 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2669 input_data, im2col_shape, im2col_data);
2670 gemm_input_data = im2col_data;
2671 gemm_input_shape = &im2col_shape;
2672 } else {
2673 TFLITE_DCHECK(!im2col_data);
2674 gemm_input_data = input_data;
2675 gemm_input_shape = &input_shape;
2676 }
2677
2678 const int gemm_input_rows = gemm_input_shape->Dims(3);
2679 // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784).
2680 // The root cause has not yet been identified. The same applies below to the
2681 // other commented-out calls. This is a partial rollback of cl/196819423.
2682 // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
2683 const int gemm_input_cols = gemm_input_shape->Dims(0) *
2684 gemm_input_shape->Dims(1) *
2685 gemm_input_shape->Dims(2);
2686 const int filter_rows = filter_shape.Dims(0);
2687 // See b/79927784.
2688 // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2689 const int filter_cols =
2690 filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
2691 const int output_rows = output_shape.Dims(3);
2692 // See b/79927784.
2693 // const int output_cols = FlatSizeSkipDim(output_shape, 3);
2694 const int output_cols =
2695 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
2696 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2697 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
2698 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
2699 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2700
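// When the GEMM right-hand side has a single column (a single output position
// across the whole batch), the convolution is just a matrix * vector product,
// so the NEON build dispatches it to FullyConnectedAsGEMV below.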
2701 #ifdef USE_NEON
2702 if (gemm_input_cols == 1 && output_rows >= 4) {
2703 RuntimeShape fc_filter_shape{
2704 filter_shape.Dims(0),
2705 filter_shape.Dims(filter_shape.DimensionsCount() - 1)};
2706
2707 return FullyConnectedAsGEMV(
2708 *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
2709 filter_data, filter_offset, bias_shape, bias_data, output_offset,
2710 output_multiplier, output_shift, output_activation_min,
2711 output_activation_max, output_shape, output_data, gemmlowp_context);
2712 }
2713 #endif
2714
2715 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2716 filter_data, filter_rows, filter_cols);
2717 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2718 gemm_input_data, gemm_input_rows, gemm_input_cols);
2719 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2720 output_data, output_rows, output_cols);
2721 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2722 bias_data, output_rows, output_offset, output_multiplier, output_shift,
2723 output_activation_min, output_activation_max);
2724 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2725 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2726 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2727 filter_offset, input_offset, output_pipeline);
2728 }
2729
2730 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2731 int32 input_offset, const uint8* filter_data,
2732 const Dims<4>& filter_dims, int32 filter_offset,
2733 const int32* bias_data, const Dims<4>& bias_dims,
2734 int stride_width, int stride_height, int dilation_width_factor,
2735 int dilation_height_factor, int pad_width, int pad_height,
2736 int32 output_offset, int32 output_multiplier, int output_shift,
2737 int32 output_activation_min, int32 output_activation_max,
2738 uint8* output_data, const Dims<4>& output_dims,
2739 uint8* im2col_data, const Dims<4>& im2col_dims,
2740 gemmlowp::GemmContext* gemmlowp_context) {
2741 tflite::ConvParams op_params;
2742 // Padding type is ignored, but still set.
2743 op_params.padding_type = PaddingType::kSame;
2744 op_params.padding_values.width = pad_width;
2745 op_params.padding_values.height = pad_height;
2746 op_params.stride_width = stride_width;
2747 op_params.stride_height = stride_height;
2748 op_params.dilation_width_factor = dilation_width_factor;
2749 op_params.dilation_height_factor = dilation_height_factor;
2750 op_params.input_offset = input_offset;
2751 op_params.weights_offset = filter_offset;
2752 op_params.output_offset = output_offset;
2753 op_params.output_multiplier = output_multiplier;
2754 // Legacy ops used mixed left and right shifts. Now positive always means left.
2755 op_params.output_shift = kReverseShift * output_shift;
2756 op_params.quantized_activation_min = output_activation_min;
2757 op_params.quantized_activation_max = output_activation_max;
2758
2759 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2760 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2761 output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
2762 }
2763
2764 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2765 int32 input_offset, const uint8* filter_data,
2766 const Dims<4>& filter_dims, int32 filter_offset,
2767 const int32* bias_data, const Dims<4>& bias_dims,
2768 int stride_width, int stride_height, int pad_width,
2769 int pad_height, int32 output_offset, int32 output_multiplier,
2770 int output_shift, int32 output_activation_min,
2771 int32 output_activation_max, uint8* output_data,
2772 const Dims<4>& output_dims, uint8* im2col_data,
2773 const Dims<4>& im2col_dims,
2774 gemmlowp::GemmContext* gemmlowp_context) {
2775 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2776 filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
2777 pad_width, pad_height, output_offset, output_multiplier, output_shift,
2778 output_activation_min, output_activation_max, output_data, output_dims,
2779 im2col_data, im2col_dims, gemmlowp_context);
2780 }
2781
2782 // legacy, for compatibility with old checked-in code
2783 template <FusedActivationFunctionType Ac>
2784 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2785 int32 input_offset, const uint8* filter_data,
2786 const Dims<4>& filter_dims, int32 filter_offset,
2787 const int32* bias_data, const Dims<4>& bias_dims,
2788 int stride_width, int stride_height, int pad_width,
2789 int pad_height, int32 output_offset, int32 output_multiplier,
2790 int output_shift, int32 output_activation_min,
2791 int32 output_activation_max, uint8* output_data,
2792 const Dims<4>& output_dims, uint8* im2col_data,
2793 const Dims<4>& im2col_dims,
2794 gemmlowp::GemmContext* gemmlowp_context) {
2795 static_assert(Ac == FusedActivationFunctionType::kNone ||
2796 Ac == FusedActivationFunctionType::kRelu ||
2797 Ac == FusedActivationFunctionType::kRelu6 ||
2798 Ac == FusedActivationFunctionType::kRelu1,
2799 "");
2800 if (Ac == FusedActivationFunctionType::kNone) {
2801 TFLITE_DCHECK_EQ(output_activation_min, 0);
2802 TFLITE_DCHECK_EQ(output_activation_max, 255);
2803 }
2804 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2805 filter_offset, bias_data, bias_dims, stride_width, stride_height,
2806 pad_width, pad_height, output_offset, output_multiplier, output_shift,
2807 output_activation_min, output_activation_max, output_data, output_dims,
2808 im2col_data, im2col_dims, gemmlowp_context);
2809 }
2810
2811 // legacy, for compatibility with old checked-in code
2812 template <FusedActivationFunctionType Ac>
2813 void Conv(const uint8* input_data, const Dims<4>& input_dims,
2814 int32 input_offset, const uint8* filter_data,
2815 const Dims<4>& filter_dims, int32 filter_offset,
2816 const int32* bias_data, const Dims<4>& bias_dims, int stride,
2817 int pad_width, int pad_height, int32 output_offset,
2818 int32 output_multiplier, int output_shift,
2819 int32 output_activation_min, int32 output_activation_max,
2820 uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
2821 const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
2822 static_assert(Ac == FusedActivationFunctionType::kNone ||
2823 Ac == FusedActivationFunctionType::kRelu ||
2824 Ac == FusedActivationFunctionType::kRelu6 ||
2825 Ac == FusedActivationFunctionType::kRelu1,
2826 "");
2827 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2828 filter_offset, bias_data, bias_dims, stride, stride, pad_width,
2829 pad_height, output_offset, output_multiplier, output_shift,
2830 output_activation_min, output_activation_max, output_data, output_dims,
2831 im2col_data, im2col_dims, gemmlowp_context);
2832 }
2833
2834 // legacy, for compatibility with old checked-in code
2835 template <FusedActivationFunctionType Ac, typename T>
2836 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2837 int pad_width, int pad_height, int kheight, int kwidth,
2838 uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2839 Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2840 kwidth, zero_byte, output_data, output_dims);
2841 }
2842
2843 // legacy, for compatibility with old checked-in code
2844 template <FusedActivationFunctionType Ac>
2845 void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
2846 const float* filter_data, const Dims<4>& filter_dims,
2847 const float* bias_data, const Dims<4>& bias_dims,
2848 float* output_data, const Dims<4>& output_dims) {
2849 ruy::profiler::ScopeLabel label("ConvAsGemm");
2850
2851 const auto input_matrix_map =
2852 MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
2853 const auto filter_matrix_map =
2854 MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
2855 auto output_matrix_map =
2856 MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
2857
2858 Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
2859
2860 AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
2861 output_dims);
2862 }
2863
2864 // legacy, for compatibility with old checked-in code
2865 template <FusedActivationFunctionType Ac>
2866 void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
2867 int32 input_offset, const uint8* filter_data,
2868 const Dims<4>& filter_dims, int32 filter_offset,
2869 const int32* bias_data, const Dims<4>& bias_dims,
2870 int32 output_offset, int32 output_multiplier, int output_shift,
2871 int32 output_activation_min, int32 output_activation_max,
2872 uint8* output_data, const Dims<4>& output_dims,
2873 gemmlowp::GemmContext* gemmlowp_context) {
2874 ruy::profiler::ScopeLabel label("ConvAsGemm/8bit");
2875 static_assert(Ac == FusedActivationFunctionType::kNone ||
2876 Ac == FusedActivationFunctionType::kRelu ||
2877 Ac == FusedActivationFunctionType::kRelu6 ||
2878 Ac == FusedActivationFunctionType::kRelu1,
2879 "");
2880 const int input_rows = input_dims.sizes[0];
2881 const int input_cols = FlatSizeSkipDim(input_dims, 0);
2882 const int filter_rows = filter_dims.sizes[3];
2883 const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
2884 const int output_rows = output_dims.sizes[0];
2885 const int output_cols = FlatSizeSkipDim(output_dims, 0);
2886 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2887 TFLITE_DCHECK_EQ(output_cols, input_cols);
2888 TFLITE_DCHECK_EQ(filter_cols, input_rows);
2889 TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
2890 TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
2891 TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
2892 TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
2893 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2894 filter_data, output_rows, filter_cols, filter_cols);
2895 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2896 input_data, filter_cols, output_cols, filter_cols);
2897 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2898 output_data, output_rows, output_cols, output_rows);
2899 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2900 bias_data, output_rows, output_offset, output_multiplier, -output_shift,
2901 output_activation_min, output_activation_max);
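// Note the negated output_shift: this legacy entry point still receives a
// right-shift amount, while MakeExp expects the positive-means-left exponent
// convention, the same role kReverseShift plays in the wrappers above.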
2902 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2903 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2904 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2905 filter_offset, input_offset, output_pipeline);
2906 }
2907
2908 inline void TransposeConv(
2909 const ConvParams& params, const RuntimeShape& input_shape,
2910 const float* input_data, const RuntimeShape& filter_shape,
2911 const float* filter_data, const RuntimeShape& output_shape,
2912 float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
2913 ruy::profiler::ScopeLabel label("TransposeConv");
2914 // Note we could use transposed weights with forward conv for unstrided
2915 // cases. But we are already getting good performance with this code as-is.
2916 TFLITE_DCHECK(im2col_data);
2917 TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
2918 output_shape, im2col_data);
2919
2920 const auto im2col_matrix_map =
2921 MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
2922 const auto filter_matrix_map =
2923 MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
2924 auto output_matrix_map =
2925 MapAsMatrixWithLastDimAsRows(output_data, output_shape);
2926
2927 Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
2928 }
2929
2930 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
2931 const float* filter_data, const Dims<4>& filter_dims,
2932 int stride_width, int stride_height, int pad_width,
2933 int pad_height, float* output_data,
2934 const Dims<4>& output_dims, float* im2col_data,
2935 const Dims<4>& im2col_dims) {
2936 tflite::ConvParams op_params;
2937 // Padding type is ignored, but still set.
2938 op_params.padding_type = PaddingType::kSame;
2939 op_params.padding_values.width = pad_width;
2940 op_params.padding_values.height = pad_height;
2941 op_params.stride_width = stride_width;
2942 op_params.stride_height = stride_height;
2943
2944 TransposeConv(op_params, DimsToShape(input_dims), input_data,
2945 DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
2946 output_data, DimsToShape(im2col_dims), im2col_data);
2947 }
2948
2949 inline void TransposeConvV2(
2950 const ConvParams& params, const RuntimeShape& input_shape,
2951 const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
2952 const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
2953 float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
2954 CpuBackendContext* cpu_backend_context) {
2955 TransposeConvV2(params, input_shape, input_data, hwoi_ordered_filter_shape,
2956 hwoi_ordered_filter_data, /*bias_shape*/ RuntimeShape(),
2957 /*bias_data*/ nullptr, output_shape, output_data,
2958 col2im_shape, col2im_data, cpu_backend_context);
2959 }
2960
2961 template <typename T>
2962 void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
2963 const Dims<4>& filter_dims, int stride_width,
2964 int stride_height, int pad_width, int pad_height,
2965 const Dims<4>& output_dims, uint8 zero_byte,
2966 T* im2col_data) {
2967 tflite::ConvParams op_params;
2968 // Padding type is ignored, but still set.
2969 op_params.padding_type = PaddingType::kSame;
2970 op_params.padding_values.width = pad_width;
2971 op_params.padding_values.height = pad_height;
2972 op_params.stride_width = stride_width;
2973 op_params.stride_height = stride_height;
2974
2975 TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2976 DimsToShape(filter_dims), DimsToShape(output_dims),
2977 im2col_data);
2978 }
2979
2980 inline void LstmCell(
2981 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2982 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2983 const float* prev_activ_data, const RuntimeShape& weights_shape,
2984 const float* weights_data, const RuntimeShape& unextended_bias_shape,
2985 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2986 const float* prev_state_data,
2987 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2988 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2989 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2990 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
2991 ruy::profiler::ScopeLabel label("LstmCell");
2992 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2993 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2994 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2995 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2996 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2997 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2998 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2999 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3000 const RuntimeShape input_shape =
3001 RuntimeShape::ExtendedShape(4, unextended_input_shape);
3002 const RuntimeShape prev_activ_shape =
3003 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3004 const RuntimeShape bias_shape =
3005 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3006 const RuntimeShape prev_state_shape =
3007 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3008 const RuntimeShape output_state_shape =
3009 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3010 const RuntimeShape output_activ_shape =
3011 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3012 const RuntimeShape concat_temp_shape =
3013 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3014 const RuntimeShape activ_temp_shape =
3015 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3016 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3017
3018 const int weights_dim_count = weights_shape.DimensionsCount();
3019 MatchingDim( // batches
3020 input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
3021 output_state_shape, 0, output_activ_shape, 0);
3022 MatchingDim( // height
3023 input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
3024 output_state_shape, 1, output_activ_shape, 1);
3025 MatchingDim( // width
3026 input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
3027 output_state_shape, 2, output_activ_shape, 2);
3028 const int input_depth = input_shape.Dims(3);
3029 const int prev_activ_depth = prev_activ_shape.Dims(3);
3030 const int total_input_depth = prev_activ_depth + input_depth;
3031 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3032 total_input_depth);
3033 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3034 const int intern_activ_depth =
3035 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3036 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3037 intern_activ_depth * total_input_depth);
3038 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3039 const int output_depth =
3040 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3041 3, output_activ_shape, 3);
3042 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3043
3044 // Concatenate prev_activ and input data together
3045 std::vector<float const*> concat_input_arrays_data;
3046 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
3047 concat_input_arrays_data.push_back(input_data);
3048 concat_input_arrays_data.push_back(prev_activ_data);
3049 concat_input_arrays_shapes.push_back(&input_shape);
3050 concat_input_arrays_shapes.push_back(&prev_activ_shape);
3051 tflite::ConcatenationParams concat_params;
3052 concat_params.axis = 3;
3053 concat_params.inputs_count = concat_input_arrays_data.size();
3054 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
3055 &(concat_input_arrays_data[0]), concat_temp_shape,
3056 concat_temp_data);
3057
3058 // Fully connected
3059 tflite::FullyConnectedParams fc_params;
3060 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
3061 fc_params.float_activation_max = std::numeric_limits<float>::max();
3062 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
3063 weights_data, bias_shape, bias_data, activ_temp_shape,
3064 activ_temp_data);
3065
3066 // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
3067 // operations.
3068 ArrayMap<float> activ_temp_map =
3069 MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
3070 auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
3071 activ_temp_map.cols());
3072 auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
3073 activ_temp_map.cols());
3074 auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
3075 activ_temp_map.cols());
3076 auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
3077 activ_temp_map.cols());
3078 ArrayMap<const float> prev_state_map =
3079 MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
3080 ArrayMap<float> output_state_map =
3081 MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
3082 ArrayMap<float> output_activ_map =
3083 MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
3084
3085 // Combined memory state and final output calculation
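// In equation form:
//   state = sigmoid(input_gate) * tanh(new_input) +
//           sigmoid(forget_gate) * prev_state
//   activ = sigmoid(output_gate) * tanh(state)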
3086 ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
3087 output_state_map =
3088 input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3089 new_input_sm.tanh() +
3090 forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3091 prev_state_map;
3092 output_activ_map =
3093 output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3094 output_state_map.tanh();
3095 }
3096
3097 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
3098 const float* prev_activ_data,
3099 const Dims<4>& prev_activ_dims, const float* weights_data,
3100 const Dims<4>& weights_dims, const float* bias_data,
3101 const Dims<4>& bias_dims, const float* prev_state_data,
3102 const Dims<4>& prev_state_dims, float* output_state_data,
3103 const Dims<4>& output_state_dims, float* output_activ_data,
3104 const Dims<4>& output_activ_dims, float* concat_temp_data,
3105 const Dims<4>& concat_temp_dims, float* activ_temp_data,
3106 const Dims<4>& activ_temp_dims) {
3107 tflite::LstmCellParams op_params;
3108 // Float LSTM cell does not need parameters to be set: leave untouched.
3109
3110 LstmCell(op_params, DimsToShape(input_dims), input_data,
3111 DimsToShape(prev_activ_dims), prev_activ_data,
3112 DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
3113 bias_data, DimsToShape(prev_state_dims), prev_state_data,
3114 DimsToShape(output_state_dims), output_state_data,
3115 DimsToShape(output_activ_dims), output_activ_data,
3116 DimsToShape(concat_temp_dims), concat_temp_data,
3117 DimsToShape(activ_temp_dims), activ_temp_data);
3118 }
3119
3120 template <int StateIntegerBits>
3121 inline void LstmCell(
3122 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3123 const uint8* input_data_uint8,
3124 const RuntimeShape& unextended_prev_activ_shape,
3125 const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
3126 const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
3127 const int32* bias_data_int32,
3128 const RuntimeShape& unextended_prev_state_shape,
3129 const int16* prev_state_data_int16,
3130 const RuntimeShape& unextended_output_state_shape,
3131 int16* output_state_data_int16,
3132 const RuntimeShape& unextended_output_activ_shape,
3133 uint8* output_activ_data_uint8,
3134 const RuntimeShape& unextended_concat_temp_shape,
3135 uint8* concat_temp_data_uint8,
3136 const RuntimeShape& unextended_activ_temp_shape,
3137 int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
3138 ruy::profiler::ScopeLabel label(
3139 "LstmCell/quantized (8bit external, 16bit internal)");
3140 int32 weights_zero_point = params.weights_zero_point;
3141 int32 accum_multiplier = params.accum_multiplier;
3142 int accum_shift = params.accum_shift;
3143 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3144 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3145 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3146 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3147 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3148 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3149 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3150 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3151 const RuntimeShape input_shape =
3152 RuntimeShape::ExtendedShape(4, unextended_input_shape);
3153 const RuntimeShape prev_activ_shape =
3154 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3155 const RuntimeShape bias_shape =
3156 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3157 const RuntimeShape prev_state_shape =
3158 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3159 const RuntimeShape output_state_shape =
3160 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3161 const RuntimeShape output_activ_shape =
3162 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3163 const RuntimeShape concat_temp_shape =
3164 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3165 const RuntimeShape activ_temp_shape =
3166 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3167 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3168
3169 // Gather dimensions information, and perform consistency checks.
3170 const int weights_dim_count = weights_shape.DimensionsCount();
3171 const int outer_size = MatchingFlatSizeSkipDim(
3172 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
3173 output_activ_shape);
3174 const int input_depth = input_shape.Dims(3);
3175 const int prev_activ_depth = prev_activ_shape.Dims(3);
3176 const int total_input_depth = prev_activ_depth + input_depth;
3177 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3178 total_input_depth);
3179 const int intern_activ_depth =
3180 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3181 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3182 intern_activ_depth * total_input_depth);
3183 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3184 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3185 const int output_depth =
3186 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3187 3, output_activ_shape, 3);
3188 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3189 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
3190 const int fc_output_depth =
3191 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
3192 const int fc_accum_depth = total_input_depth;
3193 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
3194
3195 // Depth-concatenate prev_activ and input data together.
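  // E.g. (illustrative shapes): input [N, 1, 1, input_depth] and prev_activ
  // [N, 1, 1, prev_activ_depth] are concatenated along the last axis into
  // concat_temp [N, 1, 1, input_depth + prev_activ_depth].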
3196 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
3197 prev_activ_data_uint8};
3198 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
3199 &prev_activ_shape};
3200 tflite::ConcatenationParams concat_params;
3201 concat_params.axis = 3;
3202 concat_params.inputs_count = 2;
3203 Concatenation(concat_params, concat_input_arrays_shapes,
3204 concat_input_arrays_data, concat_temp_shape,
3205 concat_temp_data_uint8);
3206
3207 // Implementation of the fully connected node inside the LSTM cell.
3208 // The operands are 8-bit integers, the accumulators are internally 32bit
3209 // integers, and the output is 16-bit fixed-point with 3 integer bits so
3210 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
3211 // is explained in the function comment above.
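  // For illustration: with 3 integer bits the int16 activations are Q3.12
  // fixed-point, so a raw value of 4096 represents 1.0 and the representable
  // range is about [-8.0, +8.0).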
3212 bool gemm_already_performed = false;
3213 #ifdef GEMMLOWP_NEON
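  // With a single batch the fully-connected product degenerates to a
  // matrix*vector product; the specialized NEON GEMV kernel below is used
  // when fc_output_depth and fc_accum_depth are multiples of 4 and 8.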
3214 if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
3215 GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
3216 weights_data_uint8, weights_zero_point, bias_shape,
3217 bias_data_int32, accum_multiplier, accum_shift,
3218 activ_temp_shape, activ_temp_data_int16);
3219 gemm_already_performed = true;
3220 }
3221 #endif
3222 if (!gemm_already_performed) {
3223 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
3224 weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
3225 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
3226 concat_temp_data_uint8, fc_accum_depth, fc_batches);
3227 gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
3228 activ_temp_data_int16, fc_output_depth, fc_batches);
3229 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
3230 ColVectorMap;
3231 ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
3232 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
3233 bias_addition_stage.bias_vector = bias_vector;
3234 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
3235 scale_stage.result_offset_after_shift = 0;
3236 scale_stage.result_fixedpoint_multiplier = accum_multiplier;
3237 scale_stage.result_exponent = accum_shift;
3238 gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
3239 auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
3240 saturating_cast_int16_stage);
3241 gemmlowp::GemmWithOutputPipeline<
3242 uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
3243 gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
3244 -weights_zero_point, -128, output_pipeline);
3245 }
3246
3247 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3248 // and muls, all done in 16-bit fixed-point.
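  // In floating-point terms, the computation below is:
  //   input_gate       = logistic(a0)
  //   input_modulation = tanh(a1)
  //   forget_gate      = logistic(a2)
  //   output_gate      = logistic(a3)
  //   new_state    = forget_gate * prev_state + input_gate * input_modulation
  //   output_activ = output_gate * tanh(new_state)
  // where a0..a3 are the four output_depth-sized slices of activ_temp.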
3249 const int16* input_gate_input_ptr = activ_temp_data_int16;
3250 const int16* input_modulation_gate_input_ptr =
3251 activ_temp_data_int16 + output_depth;
3252 const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3253 const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3254 const int16* prev_state_ptr = prev_state_data_int16;
3255 int16* output_state_data_ptr = output_state_data_int16;
3256 uint8* output_activ_data_ptr = output_activ_data_uint8;
3257
3258 for (int b = 0; b < outer_size; ++b) {
3259 int c = 0;
3260 #ifdef GEMMLOWP_NEON
3261 for (; c <= output_depth - 8; c += 8) {
3262 // Define the fixed-point data types that we will use here. All use
3263 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3264 // They only differ by the number of integral vs. fractional bits,
3265 // determining the range of values that they can represent.
3266 //
3267 // F0 uses 0 integer bits, range [-1, 1].
3268 // This is the return type of math functions such as tanh, logistic,
3269 // whose range is in [-1, 1].
3270 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3271 // F3 uses 3 integer bits, range [-8, 8].
3272 // This is the range of the previous fully-connected node's output,
3273 // which is our input here.
3274 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3275 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3276 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3277 // number of integer bits is currently dictated by the model. See comment
3278 // on the StateIntegerBits template parameter above.
3279 using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3280 // Implementation of input gate, using fixed-point logistic function.
3281 F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3282 input_gate_input_ptr += 8;
3283 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3284 // Implementation of input modulation gate, using fixed-point tanh
3285 // function.
3286 F3 input_modulation_gate_input =
3287 F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3288 input_modulation_gate_input_ptr += 8;
3289 F0 input_modulation_gate_output =
3290 gemmlowp::tanh(input_modulation_gate_input);
3291 // Implementation of forget gate, using fixed-point logistic function.
3292 F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3293 forget_gate_input_ptr += 8;
3294 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3295 // Implementation of output gate, using fixed-point logistic function.
3296 F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3297 output_gate_input_ptr += 8;
3298 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3299 // Implementation of internal multiplication nodes, still in fixed-point.
3300 F0 input_times_input_modulation =
3301 input_gate_output * input_modulation_gate_output;
3302 FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3303 prev_state_ptr += 8;
3304 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3305 // Implementation of internal addition node, saturating.
3306 FS new_state = gemmlowp::SaturatingAdd(
3307 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3308 prev_state_times_forget_state);
3309 // Implementation of last internal Tanh node, still in fixed-point.
3310 // Since a Tanh fixed-point implementation is specialized for a given
3311 // number of integer bits, and each specialization can have a substantial
3312 // code size, and we already used above a Tanh on an input with 3 integer
3313 // bits, and per the table in the above function comment there is no
3314 // significant accuracy to be lost by clamping to [-8, +8] for a
3315 // 3-integer-bits representation, let us just do that. This helps people
3316 // porting this to targets where code footprint must be minimized.
3317 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3318 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3319 // Store the new internal state back to memory, as 16-bit integers.
3320 // Note: here we store the original value with StateIntegerBits, not
3321 // the rescaled 3-integer-bits value fed to tanh.
3322 vst1q_s16(output_state_data_ptr, new_state.raw());
3323 output_state_data_ptr += 8;
3324 // Down-scale the output activations to 8-bit integers, saturating,
3325 // and store back to memory.
3326 int16x8_t rescaled_output_activ =
3327 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3328 int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3329 uint8x8_t uint8_output_activ =
3330 vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3331 vst1_u8(output_activ_data_ptr, uint8_output_activ);
3332 output_activ_data_ptr += 8;
3333 }
3334 #endif
3335 for (; c < output_depth; ++c) {
3336 // Define the fixed-point data types that we will use here. All use
3337 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3338 // They only differ by the number of integral vs. fractional bits,
3339 // determining the range of values that they can represent.
3340 //
3341 // F0 uses 0 integer bits, range [-1, 1].
3342 // This is the return type of math functions such as tanh, logistic,
3343 // whose range is in [-1, 1].
3344 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3345 // F3 uses 3 integer bits, range [-8, 8].
3346 // This is the range of the previous fully-connected node's output,
3347 // which is our input here.
3348 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3349 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3350 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3351 // number of integer bits is currently dictated by the model. See comment
3352 // on the StateIntegerBits template parameter above.
3353 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
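      // For example (illustrative): with StateIntegerBits == 4, FS interprets
      // the raw int16 value 2048 as 2048 * 2^-11 = 1.0, whereas the same raw
      // value read as F0 (Q0.15) would represent 0.0625.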
3354 // Implementation of input gate, using fixed-point logistic function.
3355 F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3356 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3357 // Implementation of input modulation gate, using fixed-point tanh
3358 // function.
3359 F3 input_modulation_gate_input =
3360 F3::FromRaw(*input_modulation_gate_input_ptr++);
3361 F0 input_modulation_gate_output =
3362 gemmlowp::tanh(input_modulation_gate_input);
3363 // Implementation of forget gate, using fixed-point logistic function.
3364 F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3365 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3366 // Implementation of output gate, using fixed-point logistic function.
3367 F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3368 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3369 // Implementation of internal multiplication nodes, still in fixed-point.
3370 F0 input_times_input_modulation =
3371 input_gate_output * input_modulation_gate_output;
3372 FS prev_state = FS::FromRaw(*prev_state_ptr++);
3373 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3374 // Implementation of internal addition node, saturating.
3375 FS new_state = gemmlowp::SaturatingAdd(
3376 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3377 prev_state_times_forget_state);
3378 // Implementation of last internal Tanh node, still in fixed-point.
3379 // Since a Tanh fixed-point implementation is specialized for a given
3380 // number of integer bits, and each specialization can have a substantial
3381 // code size, and we already used above a Tanh on an input with 3 integer
3382 // bits, and per the table in the above function comment there is no
3383 // significant accuracy to be lost by clamping to [-8, +8] for a
3384 // 3-integer-bits representation, let us just do that. This helps people
3385 // porting this to targets where code footprint must be minimized.
3386 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3387 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3388 // Store the new internal state back to memory, as 16-bit integers.
3389 // Note: here we store the original value with StateIntegerBits, not
3390 // the rescaled 3-integer-bits value fed to tanh.
3391 *output_state_data_ptr++ = new_state.raw();
3392 // Down-scale the output activations to 8-bit integers, saturating,
3393 // and store back to memory.
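      // E.g. (illustrative): output_activ_int16 is Q0.15 in [-1, 1]; dividing
      // its raw value by 2^8 maps it to about [-128, 128), which after
      // clamping to [-128, 127] and adding 128 becomes a uint8 with zero
      // point 128, i.e. roughly 128 + 128 * real_value.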
3394 int16 rescaled_output_activ =
3395 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3396 int16 clamped_output_activ =
3397 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3398 *output_activ_data_ptr++ = 128 + clamped_output_activ;
3399 }
3400 input_gate_input_ptr += 3 * output_depth;
3401 input_modulation_gate_input_ptr += 3 * output_depth;
3402 forget_gate_input_ptr += 3 * output_depth;
3403 output_gate_input_ptr += 3 * output_depth;
3404 }
3405 }
3406
3407 template <int StateIntegerBits>
3408 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
3409 const uint8* prev_activ_data_uint8,
3410 const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
3411 const Dims<4>& weights_dims, const int32* bias_data_int32,
3412 const Dims<4>& bias_dims, const int16* prev_state_data_int16,
3413 const Dims<4>& prev_state_dims, int16* output_state_data_int16,
3414 const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
3415 const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
3416 const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
3417 const Dims<4>& activ_temp_dims, int32 weights_zero_point,
3418 int32 accum_multiplier, int accum_shift,
3419 gemmlowp::GemmContext* gemmlowp_context) {
3420 tflite::LstmCellParams op_params;
3421 op_params.weights_zero_point = weights_zero_point;
3422 op_params.accum_multiplier = accum_multiplier;
3423 op_params.accum_shift = accum_shift;
3424
3425 LstmCell<StateIntegerBits>(
3426 op_params, DimsToShape(input_dims), input_data_uint8,
3427 DimsToShape(prev_activ_dims), prev_activ_data_uint8,
3428 DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
3429 bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
3430 DimsToShape(output_state_dims), output_state_data_int16,
3431 DimsToShape(output_activ_dims), output_activ_data_uint8,
3432 DimsToShape(concat_temp_dims), concat_temp_data_uint8,
3433 DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
3434 }
3435
3436 template <typename T>
3437 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
3438 const T* input2_data, const Dims<4>& input2_dims,
3439 T output_activation_min, T output_activation_max,
3440 T* output_data, const Dims<4>& output_dims) {
3441 tflite::ArithmeticParams op_params;
3442 SetActivationParams(output_activation_min, output_activation_max, &op_params);
3443
3444 BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
3445 DimsToShape(input2_dims), input2_data,
3446 DimsToShape(output_dims), output_data);
3447 }
3448
3449 template <FusedActivationFunctionType Ac>
3450 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
3451 float* output_data, const RuntimeShape& output_shape) {
3452 static_assert(Ac == FusedActivationFunctionType::kNone, "");
3453 tflite::L2NormalizationParams op_params;
3454 // No params need to be set for float, but reserved in signature for future
3455 // activations.
3456
3457 L2Normalization(op_params, input_shape, input_data, output_shape,
3458 output_data);
3459 }
3460
3461 inline void L2Normalization(const uint8* input_data,
3462 const RuntimeShape& input_shape,
3463 int32 input_zero_point, uint8* output_data,
3464 const RuntimeShape& output_shape) {
3465 tflite::L2NormalizationParams op_params;
3466 op_params.input_zero_point = input_zero_point;
3467
3468 L2Normalization(op_params, input_shape, input_data, output_shape,
3469 output_data);
3470 }
3471
3472 template <FusedActivationFunctionType Ac>
3473 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
3474 float* output_data, const Dims<4>& output_dims) {
3475 L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
3476 DimsToShape(output_dims));
3477 }
3478
3479 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
3480 int32 input_zero_point, uint8* output_data,
3481 const Dims<4>& output_dims) {
3482 L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
3483 output_data, DimsToShape(output_dims));
3484 }
3485
3486 inline void Relu(const float* input_data, const Dims<4>& input_dims,
3487 float* output_data, const Dims<4>& output_dims) {
3488 Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3489 output_data);
3490 }
3491
3492 // legacy, for compatibility with old checked-in code
3493 template <FusedActivationFunctionType Ac>
3494 void Add(const float* input1_data, const Dims<4>& input1_dims,
3495 const float* input2_data, const Dims<4>& input2_dims,
3496 float* output_data, const Dims<4>& output_dims) {
3497 float output_activation_min, output_activation_max;
3498 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3499
3500 tflite::ArithmeticParams op_params;
3501 op_params.float_activation_min = output_activation_min;
3502 op_params.float_activation_max = output_activation_max;
3503 Add(op_params, DimsToShape(input1_dims), input1_data,
3504 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3505 output_data);
3506 }
3507
3508 template <FusedActivationFunctionType Ac>
3509 inline void Add(int left_shift, const uint8* input1_data,
3510 const Dims<4>& input1_dims, int32 input1_offset,
3511 int32 input1_multiplier, int input1_shift,
3512 const uint8* input2_data, const Dims<4>& input2_dims,
3513 int32 input2_offset, int32 input2_multiplier, int input2_shift,
3514 int32 output_offset, int32 output_multiplier, int output_shift,
3515 int32 output_activation_min, int32 output_activation_max,
3516 uint8* output_data, const Dims<4>& output_dims) {
3517 constexpr int kReverseShift = -1;
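  // The new-style ArithmeticParams expect signed left-shift exponents, while
  // this legacy entry point receives positive right-shift amounts;
  // kReverseShift flips the sign accordingly below.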
3518 static_assert(Ac == FusedActivationFunctionType::kNone ||
3519 Ac == FusedActivationFunctionType::kRelu ||
3520 Ac == FusedActivationFunctionType::kRelu6 ||
3521 Ac == FusedActivationFunctionType::kRelu1,
3522 "");
3523 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3524 if (Ac == FusedActivationFunctionType::kNone) {
3525 TFLITE_DCHECK_EQ(output_activation_min, 0);
3526 TFLITE_DCHECK_EQ(output_activation_max, 255);
3527 }
3528
3529 tflite::ArithmeticParams op_params;
3530 op_params.left_shift = left_shift;
3531 op_params.input1_offset = input1_offset;
3532 op_params.input1_multiplier = input1_multiplier;
3533 op_params.input1_shift = kReverseShift * input1_shift;
3534 op_params.input2_offset = input2_offset;
3535 op_params.input2_multiplier = input2_multiplier;
3536 op_params.input2_shift = kReverseShift * input2_shift;
3537 op_params.output_offset = output_offset;
3538 op_params.output_multiplier = output_multiplier;
3539 op_params.output_shift = kReverseShift * output_shift;
3540 op_params.quantized_activation_min = output_activation_min;
3541 op_params.quantized_activation_max = output_activation_max;
3542 Add(op_params, DimsToShape(input1_dims), input1_data,
3543 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3544 output_data);
3545 }
3546
3547 template <FusedActivationFunctionType Ac>
3548 void Add(const int32* input1_data, const Dims<4>& input1_dims,
3549 const int32* input2_data, const Dims<4>& input2_dims,
3550 int32* output_data, const Dims<4>& output_dims) {
3551 ruy::profiler::ScopeLabel label("Add/int32");
3552 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
3553
3554 tflite::ArithmeticParams op_params;
3555 op_params.quantized_activation_min = std::numeric_limits<int32>::min();
3556 op_params.quantized_activation_max = std::numeric_limits<int32>::max();
3557 Add(op_params, DimsToShape(input1_dims), input1_data,
3558 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3559 output_data);
3560 }
3561
3562 template <typename T>
3563 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3564 const T* input2_data, const Dims<4>& input2_dims,
3565 T output_activation_min, T output_activation_max,
3566 T* output_data, const Dims<4>& output_dims) {
3567 tflite::ArithmeticParams op_params;
3568 op_params.float_activation_min = output_activation_min;
3569 op_params.float_activation_max = output_activation_max;
3570 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3571 DimsToShape(input2_dims), input2_data,
3572 DimsToShape(output_dims), output_data);
3573 }
3574
3575 template <FusedActivationFunctionType Ac>
3576 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
3577 const Dims<4>& input1_dims, int32 input1_offset,
3578 int32 input1_multiplier, int input1_shift,
3579 const uint8* input2_data, const Dims<4>& input2_dims,
3580 int32 input2_offset, int32 input2_multiplier,
3581 int input2_shift, int32 output_offset,
3582 int32 output_multiplier, int output_shift,
3583 int32 output_activation_min,
3584 int32 output_activation_max, uint8* output_data,
3585 const Dims<4>& output_dims) {
3586 constexpr int kReverseShift = -1;
3587 static_assert(Ac == FusedActivationFunctionType::kNone ||
3588 Ac == FusedActivationFunctionType::kRelu ||
3589 Ac == FusedActivationFunctionType::kRelu6 ||
3590 Ac == FusedActivationFunctionType::kRelu1,
3591 "");
3592 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3593 if (Ac == FusedActivationFunctionType::kNone) {
3594 TFLITE_DCHECK_EQ(output_activation_min, 0);
3595 TFLITE_DCHECK_EQ(output_activation_max, 255);
3596 }
3597
3598 tflite::ArithmeticParams op_params;
3599 op_params.left_shift = left_shift;
3600 op_params.input1_offset = input1_offset;
3601 op_params.input1_multiplier = input1_multiplier;
3602 op_params.input1_shift = kReverseShift * input1_shift;
3603 op_params.input2_offset = input2_offset;
3604 op_params.input2_multiplier = input2_multiplier;
3605 op_params.input2_shift = kReverseShift * input2_shift;
3606 op_params.output_offset = output_offset;
3607 op_params.output_multiplier = output_multiplier;
3608 op_params.output_shift = kReverseShift * output_shift;
3609 op_params.quantized_activation_min = output_activation_min;
3610 op_params.quantized_activation_max = output_activation_max;
3611 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3612 DimsToShape(input2_dims), input2_data,
3613 DimsToShape(output_dims), output_data);
3614 }
3615
3616 template <FusedActivationFunctionType Ac>
3617 inline void BroadcastAddFivefold(
3618 int y0, int y1, int y2, int y3, int y4, int left_shift,
3619 const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
3620 int32 input1_multiplier, int input1_shift, const uint8* input2_data,
3621 const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
3622 int input2_shift, int32 output_offset, int32 output_multiplier,
3623 int output_shift, int32 output_activation_min, int32 output_activation_max,
3624 uint8* output_data, const Dims<4>& output_dims) {
3625 constexpr int kReverseShift = -1;
3626 static_assert(Ac == FusedActivationFunctionType::kNone ||
3627 Ac == FusedActivationFunctionType::kRelu ||
3628 Ac == FusedActivationFunctionType::kRelu6 ||
3629 Ac == FusedActivationFunctionType::kRelu1,
3630 "");
3631 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3632 if (Ac == FusedActivationFunctionType::kNone) {
3633 TFLITE_DCHECK_EQ(output_activation_min, 0);
3634 TFLITE_DCHECK_EQ(output_activation_max, 255);
3635 }
3636 tflite::ArithmeticParams op_params;
3637 op_params.broadcast_category =
3638 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
3639 op_params.left_shift = left_shift;
3640 op_params.input1_offset = input1_offset;
3641 op_params.input1_multiplier = input1_multiplier;
3642 op_params.input1_shift = kReverseShift * input1_shift;
3643 op_params.input2_offset = input2_offset;
3644 op_params.input2_multiplier = input2_multiplier;
3645 op_params.input2_shift = kReverseShift * input2_shift;
3646 op_params.output_offset = output_offset;
3647 op_params.output_multiplier = output_multiplier;
3648 op_params.output_shift = kReverseShift * output_shift;
3649 op_params.quantized_activation_min = output_activation_min;
3650 op_params.quantized_activation_max = output_activation_max;
3651 op_params.broadcast_shape[4] = y0;
3652 op_params.broadcast_shape[3] = y1;
3653 op_params.broadcast_shape[2] = y2;
3654 op_params.broadcast_shape[1] = y3;
3655 op_params.broadcast_shape[0] = y4;
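  // Note: broadcast_shape is filled in reverse order relative to the legacy
  // y0..y4 arguments, so y0 lands at index 4 and y4 at index 0.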
3656 BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
3657 DimsToShape(input2_dims), input2_data,
3658 DimsToShape(output_dims), output_data);
3659 }
3660
3661 // legacy, for compatibility with old checked-in code
3662 template <FusedActivationFunctionType Ac, typename T>
3663 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3664 const T* input2_data, const Dims<4>& input2_dims,
3665 T* output_data, const Dims<4>& output_dims) {
3666 T output_activation_min, output_activation_max;
3667 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3668
3669 BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
3670 output_activation_min, output_activation_max, output_data,
3671 output_dims);
3672 }
3673
3674 template <FusedActivationFunctionType Ac>
3675 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
3676 int input1_shift, const int16* input2_data,
3677 const Dims<4>& input2_dims, int input2_shift,
3678 int16 output_activation_min, int16 output_activation_max,
3679 int16* output_data, const Dims<4>& output_dims) {
3680 constexpr int kReverseShift = -1;
3681 static_assert(Ac == FusedActivationFunctionType::kNone ||
3682 Ac == FusedActivationFunctionType::kRelu ||
3683 Ac == FusedActivationFunctionType::kRelu6 ||
3684 Ac == FusedActivationFunctionType::kRelu1,
3685 "");
3686 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3687 if (Ac == FusedActivationFunctionType::kNone) {
3688 TFLITE_DCHECK_EQ(output_activation_min, -32768);
3689 TFLITE_DCHECK_EQ(output_activation_max, 32767);
3690 }
3691
3692 tflite::ArithmeticParams op_params;
3693 op_params.input1_shift = kReverseShift * input1_shift;
3694 op_params.input2_shift = kReverseShift * input2_shift;
3695 op_params.quantized_activation_min = output_activation_min;
3696 op_params.quantized_activation_max = output_activation_max;
3697 Add(op_params, DimsToShape(input1_dims), input1_data,
3698 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3699 output_data);
3700 }
3701
3702 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
3703 const float* input2_data, const Dims<4>& input2_dims,
3704 float* output_data, const Dims<4>& output_dims) {
3705 float output_activation_min, output_activation_max;
3706 GetActivationMinMax(FusedActivationFunctionType::kNone,
3707 &output_activation_min, &output_activation_max);
3708 tflite::ArithmeticParams op_params;
3709 op_params.float_activation_min = output_activation_min;
3710 op_params.float_activation_max = output_activation_max;
3711 Sub(op_params, DimsToShape(input1_dims), input1_data,
3712 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3713 output_data);
3714 }
3715
3716 template <typename T>
3717 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
3718 const Dims<4>& input2_dims, T* output_data,
3719 const Dims<4>& output_dims) {
3720 T output_activation_min, output_activation_max;
3721 GetActivationMinMax(FusedActivationFunctionType::kNone,
3722 &output_activation_min, &output_activation_max);
3723 tflite::ArithmeticParams op_params;
3724 op_params.quantized_activation_min = output_activation_min;
3725 op_params.quantized_activation_max = output_activation_max;
3726 Sub(op_params, DimsToShape(input1_dims), input1_data,
3727 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3728 output_data);
3729 }
3730
3731 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
3732 int32 input1_offset, const uint8* input2_data,
3733 const Dims<4>& input2_dims, int32 input2_offset,
3734 int32 output_offset, int32 output_multiplier,
3735 int output_shift, int32 output_activation_min,
3736 int32 output_activation_max, uint8* output_data,
3737 const Dims<4>& output_dims) {
3738 tflite::ArithmeticParams op_params;
3739 SetActivationParams(output_activation_min, output_activation_max, &op_params);
3740 op_params.input1_offset = input1_offset;
3741 op_params.input2_offset = input2_offset;
3742 op_params.output_offset = output_offset;
3743 op_params.output_multiplier = output_multiplier;
3744 op_params.output_shift = kReverseShift * output_shift;
3745
3746 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3747 DimsToShape(input2_dims), input2_data,
3748 DimsToShape(output_dims), output_data);
3749 }
3750
3751 // legacy, for compatibility with old checked-in code
3752 template <FusedActivationFunctionType Ac>
3753 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
3754 int32 input1_offset, const uint8* input2_data,
3755 const Dims<4>& input2_dims, int32 input2_offset,
3756 int32 output_offset, int32 output_multiplier,
3757 int output_shift, int32 output_activation_min,
3758 int32 output_activation_max, uint8* output_data,
3759 const Dims<4>& output_dims) {
3760 BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
3761 input2_dims, input2_offset, output_offset, output_multiplier,
3762 output_shift, output_activation_min, output_activation_max,
3763 output_data, output_dims);
3764 }
3765
3766 inline bool AveragePool(const float* input_data, const Dims<4>& input_dims,
3767 int stride_width, int stride_height, int pad_width,
3768 int pad_height, int kwidth, int kheight,
3769 float output_activation_min,
3770 float output_activation_max, float* output_data,
3771 const Dims<4>& output_dims) {
3772 tflite::PoolParams params;
3773 params.stride_height = stride_height;
3774 params.stride_width = stride_width;
3775 params.filter_height = kheight;
3776 params.filter_width = kwidth;
3777 params.padding_values.height = pad_height;
3778 params.padding_values.width = pad_width;
3779 params.float_activation_min = output_activation_min;
3780 params.float_activation_max = output_activation_max;
3781 return AveragePool(params, DimsToShape(input_dims), input_data,
3782 DimsToShape(output_dims), output_data);
3783 }
3784
3785 // legacy, for compatibility with old checked-in code
3786 template <FusedActivationFunctionType Ac>
3787 bool AveragePool(const float* input_data, const Dims<4>& input_dims,
3788 int stride_width, int stride_height, int pad_width,
3789 int pad_height, int kwidth, int kheight, float* output_data,
3790 const Dims<4>& output_dims) {
3791 float output_activation_min, output_activation_max;
3792 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3793
3794 return AveragePool(input_data, input_dims, stride_width, stride_height,
3795 pad_width, pad_height, kwidth, kheight,
3796 output_activation_min, output_activation_max, output_data,
3797 output_dims);
3798 }
3799
3800 // legacy, for compatibility with old checked-in code
3801 template <FusedActivationFunctionType Ac>
3802 bool AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
3803 int pad_width, int pad_height, int filter_width,
3804 int filter_height, float* output_data,
3805 const Dims<4>& output_dims) {
3806 return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
3807 pad_height, filter_width, filter_height, output_data,
3808 output_dims);
3809 }
3810
3811 inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
3812 int stride_width, int stride_height, int pad_width,
3813 int pad_height, int filter_width, int filter_height,
3814 int32 output_activation_min,
3815 int32 output_activation_max, uint8* output_data,
3816 const Dims<4>& output_dims) {
3817 tflite::PoolParams params;
3818 params.stride_height = stride_height;
3819 params.stride_width = stride_width;
3820 params.filter_height = filter_height;
3821 params.filter_width = filter_width;
3822 params.padding_values.height = pad_height;
3823 params.padding_values.width = pad_width;
3824 params.quantized_activation_min = output_activation_min;
3825 params.quantized_activation_max = output_activation_max;
3826 return AveragePool(params, DimsToShape(input_dims), input_data,
3827 DimsToShape(output_dims), output_data);
3828 }
3829
3830 // legacy, for compatibility with old checked-in code
3831 template <FusedActivationFunctionType Ac>
3832 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
3833 int stride_width, int stride_height, int pad_width,
3834 int pad_height, int filter_width, int filter_height,
3835 int32 output_activation_min, int32 output_activation_max,
3836 uint8* output_data, const Dims<4>& output_dims) {
3837 static_assert(Ac == FusedActivationFunctionType::kNone ||
3838 Ac == FusedActivationFunctionType::kRelu ||
3839 Ac == FusedActivationFunctionType::kRelu6 ||
3840 Ac == FusedActivationFunctionType::kRelu1,
3841 "");
3842 if (Ac == FusedActivationFunctionType::kNone) {
3843 TFLITE_DCHECK_EQ(output_activation_min, 0);
3844 TFLITE_DCHECK_EQ(output_activation_max, 255);
3845 }
3846 return AveragePool(input_data, input_dims, stride_width, stride_height,
3847 pad_width, pad_height, filter_width, filter_height,
3848 output_activation_min, output_activation_max, output_data,
3849 output_dims);
3850 }
3851
3852 // legacy, for compatibility with old checked-in code
3853 template <FusedActivationFunctionType Ac>
3854 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
3855 int pad_width, int pad_height, int filter_width,
3856 int filter_height, int32 output_activation_min,
3857 int32 output_activation_max, uint8* output_data,
3858 const Dims<4>& output_dims) {
3859 return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
3860 pad_height, filter_width, filter_height,
3861 output_activation_min, output_activation_max,
3862 output_data, output_dims);
3863 }
3864
3865 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
3866 int stride_width, int stride_height, int pad_width,
3867 int pad_height, int kwidth, int kheight,
3868 float output_activation_min, float output_activation_max,
3869 float* output_data, const Dims<4>& output_dims) {
3870 tflite::PoolParams params;
3871 params.stride_height = stride_height;
3872 params.stride_width = stride_width;
3873 params.filter_height = kheight;
3874 params.filter_width = kwidth;
3875 params.padding_values.height = pad_height;
3876 params.padding_values.width = pad_width;
3877 params.float_activation_min = output_activation_min;
3878 params.float_activation_max = output_activation_max;
3879 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3880 output_data);
3881 }
3882
3883 // legacy, for compatibility with old checked-in code
3884 template <FusedActivationFunctionType Ac>
3885 void MaxPool(const float* input_data, const Dims<4>& input_dims,
3886 int stride_width, int stride_height, int pad_width, int pad_height,
3887 int kwidth, int kheight, float* output_data,
3888 const Dims<4>& output_dims) {
3889 float output_activation_min, output_activation_max;
3890 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3891 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
3892 pad_height, kwidth, kheight, output_activation_min,
3893 output_activation_max, output_data, output_dims);
3894 }
3895
3896 // legacy, for compatibility with old checked-in code
3897 template <FusedActivationFunctionType Ac>
3898 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
3899 int pad_width, int pad_height, int filter_width, int filter_height,
3900 float* output_data, const Dims<4>& output_dims) {
3901 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3902 filter_width, filter_height, output_data, output_dims);
3903 }
3904
3905 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
3906 int stride_width, int stride_height, int pad_width,
3907 int pad_height, int filter_width, int filter_height,
3908 int32 output_activation_min, int32 output_activation_max,
3909 uint8* output_data, const Dims<4>& output_dims) {
3910 PoolParams params;
3911 params.stride_height = stride_height;
3912 params.stride_width = stride_width;
3913 params.filter_height = filter_height;
3914 params.filter_width = filter_width;
3915 params.padding_values.height = pad_height;
3916 params.padding_values.width = pad_width;
3917 params.quantized_activation_min = output_activation_min;
3918 params.quantized_activation_max = output_activation_max;
3919 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3920 output_data);
3921 }
3922
3923 // legacy, for compatibility with old checked-in code
3924 template <FusedActivationFunctionType Ac>
3925 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
3926 int stride_width, int stride_height, int pad_width, int pad_height,
3927 int filter_width, int filter_height, int32 output_activation_min,
3928 int32 output_activation_max, uint8* output_data,
3929 const Dims<4>& output_dims) {
3930 static_assert(Ac == FusedActivationFunctionType::kNone ||
3931 Ac == FusedActivationFunctionType::kRelu ||
3932 Ac == FusedActivationFunctionType::kRelu6 ||
3933 Ac == FusedActivationFunctionType::kRelu1,
3934 "");
3935 if (Ac == FusedActivationFunctionType::kNone) {
3936 TFLITE_DCHECK_EQ(output_activation_min, 0);
3937 TFLITE_DCHECK_EQ(output_activation_max, 255);
3938 }
3939 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
3940 pad_height, filter_width, filter_height, output_activation_min,
3941 output_activation_max, output_data, output_dims);
3942 }
3943
3944 // legacy, for compatibility with old checked-in code
3945 template <FusedActivationFunctionType Ac>
3946 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
3947 int pad_width, int pad_height, int filter_width, int filter_height,
3948 int32 output_activation_min, int32 output_activation_max,
3949 uint8* output_data, const Dims<4>& output_dims) {
3950 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3951 filter_width, filter_height, output_activation_min,
3952 output_activation_max, output_data, output_dims);
3953 }
3954
3955 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
3956 int stride_width, int stride_height, int pad_width,
3957 int pad_height, int filter_width, int filter_height,
3958 float output_activation_min, float output_activation_max,
3959 float* output_data, const Dims<4>& output_dims) {
3960 PoolParams params;
3961 params.stride_height = stride_height;
3962 params.stride_width = stride_width;
3963 params.filter_height = filter_height;
3964 params.filter_width = filter_width;
3965 params.padding_values.height = pad_height;
3966 params.padding_values.width = pad_width;
3967 params.float_activation_min = output_activation_min;
3968 params.float_activation_max = output_activation_max;
3969 L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3970 output_data);
3971 }
3972
3973 // legacy, for compatibility with old checked-in code
3974 template <FusedActivationFunctionType Ac>
3975 void L2Pool(const float* input_data, const Dims<4>& input_dims,
3976 int stride_width, int stride_height, int pad_width, int pad_height,
3977 int filter_width, int filter_height, float* output_data,
3978 const Dims<4>& output_dims) {
3979 float output_activation_min, output_activation_max;
3980 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3981 L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
3982 pad_height, filter_width, filter_height, output_activation_min,
3983 output_activation_max, output_data, output_dims);
3984 }
3985
3986 // legacy, for compatibility with old checked-in code
3987 template <FusedActivationFunctionType Ac>
3988 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
3989 int pad_width, int pad_height, int filter_width, int filter_height,
3990 float* output_data, const Dims<4>& output_dims) {
3991 L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3992 filter_width, filter_height, output_data, output_dims);
3993 }
3994
3995 inline void Softmax(const SoftmaxParams& params,
3996 const RuntimeShape& input_shape, const uint8* input_data,
3997 const RuntimeShape& output_shape, uint8* output_data) {
3998 const int32 input_beta_multiplier = params.input_multiplier;
3999 const int32 input_beta_left_shift = params.input_left_shift;
4000 const int diff_min = params.diff_min;
4001 // The representation chosen for the input to the exp() function is Q5.26.
4002 // We need to leave extra space since values that we skip might be as large as
4003 // -32 before multiplying by input_beta_multiplier, and therefore as large as
4004 // -16 afterwards. Note that exp(-8) is definitely not insignificant to
4005 // accumulation, but exp(-16) definitely is.
4006 static const int kScaledDiffIntegerBits = 5;
4007 static const int kAccumulationIntegerBits = 12;
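  // Illustrative note: with 12 integer bits the accumulator is Q12.19, giving
  // a range of about [-4096, 4096); since each exp() term is at most 1.0,
  // several thousand terms can be summed without overflow.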
4008 using FixedPointScaledDiff =
4009 gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4010 using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4011 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4012
4013 ruy::profiler::ScopeLabel label("Softmax/8bit");
4014 const int trailing_dim = input_shape.DimensionsCount() - 1;
4015 const int outer_size =
4016 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4017 const int depth =
4018 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4019
4020 for (int b = 0; b < outer_size; ++b) {
4021 const uint8* input_data_ptr = input_data + b * depth;
4022 uint8* output_data_ptr = output_data + b * depth;
4023
4024 // Determine the largest entry in the current row
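    // Subtracting this maximum keeps every exp() argument non-positive, which
    // is what the fixed-point exp_on_negative_values() used below expects.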
4025 uint8 max_in_row = 0;
4026 {
4027 int c = 0;
4028 #ifdef USE_NEON
4029 uint8x16_t max16_0 = vdupq_n_u8(0);
4030 uint8x16_t max16_1 = vdupq_n_u8(0);
4031 for (; c <= depth - 32; c += 32) {
4032 max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
4033 max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
4034 }
4035 uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
4036 if (c <= depth - 16) {
4037 max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
4038 c += 16;
4039 }
4040 uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
4041 if (c <= depth - 8) {
4042 max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
4043 c += 8;
4044 }
4045 uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
4046 uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
4047 uint8x8_t max1 = vpmax_u8(max2, max2);
4048 max_in_row = vget_lane_u8(max1, 0);
4049 #endif
4050 for (; c < depth; ++c) {
4051 max_in_row = std::max(max_in_row, input_data_ptr[c]);
4052 }
4053 }
4054
4055 #ifdef USE_NEON
4056 using FixedPointAccumInt32x4 =
4057 gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
4058 using FixedPointScaledDiffInt32x4 =
4059 gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
4060 using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
4061 FixedPoint0Int32x4 input_beta_multiplier_f0 =
4062 FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
4063 int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
4064 #endif
4065
4066 // Compute the sum of exponentials of the differences of entries in the
4067 // current row from the largest entry in the current row.
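    // That is, sum_of_exps accumulates exp(scaled_diff) over the entries whose
    // difference from the row maximum is at least diff_min, where scaled_diff
    // is (input[c] - max_in_row) scaled by the beta-derived multiplier and
    // shift; smaller differences are treated as contributing zero.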
4068 FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4069 {
4070 int c = 0;
4071 #ifdef USE_NEON
4072 int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
4073 FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
4074 FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
4075 FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
4076 for (; c <= depth - 8; c += 8) {
4077 uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4078 int16x8_t input_diff_s16 =
4079 vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4080 int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4081 int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4082 int32x4_t mask_0 =
4083 gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
4084 int32x4_t mask_1 =
4085 gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
4086 FixedPointScaledDiffInt32x4 scaled_diff_0 =
4087 input_beta_multiplier_f0 *
4088 FixedPointScaledDiffInt32x4::FromRaw(
4089 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4090 FixedPointScaledDiffInt32x4 scaled_diff_1 =
4091 input_beta_multiplier_f0 *
4092 FixedPointScaledDiffInt32x4::FromRaw(
4093 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4094 FixedPointAccumInt32x4 exps_0 =
4095 gemmlowp::Rescale<kAccumulationIntegerBits>(
4096 exp_on_negative_values(scaled_diff_0));
4097 FixedPointAccumInt32x4 exps_1 =
4098 gemmlowp::Rescale<kAccumulationIntegerBits>(
4099 exp_on_negative_values(scaled_diff_1));
4100 FixedPointAccumInt32x4 masked_exps_0 =
4101 SelectUsingMask(mask_0, exps_0, zeros);
4102 FixedPointAccumInt32x4 masked_exps_1 =
4103 SelectUsingMask(mask_1, exps_1, zeros);
4104 sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
4105 sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
4106 }
4107 int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
4108 int32x2_t sum_of_exps_reduced_2 =
4109 vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
4110 vget_high_s32(sum_of_exps_reduced_4));
4111 int32x2_t sum_of_exps_reduced_1 =
4112 vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
4113 sum_of_exps =
4114 FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
4115 #endif
4116 for (; c < depth; ++c) {
4117 int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4118 if (input_diff >= diff_min) {
4119 const int32 input_diff_rescaled =
4120 MultiplyByQuantizedMultiplierGreaterThanOne(
4121 input_diff, input_beta_multiplier, input_beta_left_shift);
4122 const FixedPointScaledDiff scaled_diff_f8 =
4123 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4124 sum_of_exps =
4125 sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4126 exp_on_negative_values(scaled_diff_f8));
4127 }
4128 }
4129 }
4130
4131 // Compute the fixed-point multiplier and shift that we need to apply to
4132 // perform a division by the above-computed sum-of-exponentials.
4133 int num_bits_over_unit = 0;
4134 FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
4135 sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
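    // Roughly, shifted_scale and num_bits_over_unit together encode
    // 1 / sum_of_exps; multiplying each exponential by shifted_scale and
    // right-shifting by (num_bits_over_unit + 31 - 8) in the loop below
    // rescales the resulting probability to the [0, 255] uint8 range.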
4136
4137 // Compute the quotients of exponentials of differences of entries in the
4138 // current row from the largest entry, over the previously-computed sum of
4139 // exponentials.
4140 {
4141 int c = 0;
4142 #ifdef USE_NEON
4143 int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
4144 for (; c <= depth - 8; c += 8) {
4145 uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4146 int16x8_t input_diff_s16 =
4147 vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4148 int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4149 int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4150 uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
4151 FixedPointScaledDiffInt32x4 scaled_diff_0 =
4152 input_beta_multiplier_f0 *
4153 FixedPointScaledDiffInt32x4::FromRaw(
4154 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4155 FixedPointScaledDiffInt32x4 scaled_diff_1 =
4156 input_beta_multiplier_f0 *
4157 FixedPointScaledDiffInt32x4::FromRaw(
4158 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4159 FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
4160 FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
4161 int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
4162 vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
4163 num_bits_over_unit + 31 - 8);
4164 int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
4165 vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
4166 num_bits_over_unit + 31 - 8);
4167 int16x8_t output_s16 =
4168 vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
4169 uint8x8_t output_u8 = vqmovun_s16(output_s16);
4170 uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
4171 vst1_u8(output_data_ptr + c, masked_output);
4172 }
4173 #endif
4174 for (; c < depth; ++c) {
4175 int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4176 if (input_diff >= diff_min) {
4177 const int32 input_diff_rescaled =
4178 MultiplyByQuantizedMultiplierGreaterThanOne(
4179 input_diff, input_beta_multiplier, input_beta_left_shift);
4180 const FixedPointScaledDiff scaled_diff_f8 =
4181 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4182
4183 FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
4184 int32 unsat_output = gemmlowp::RoundingDivideByPOT(
4185 (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
4186
4187 output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
4188
4189 } else {
4190 output_data_ptr[c] = 0;
4191 }
4192 }
4193 }
4194 }
4195 }
4196
4197 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
4198 float beta, float* output_data,
4199 const RuntimeShape& output_shape) {
4200 SoftmaxParams params;
4201 params.beta = beta;
4202 Softmax(params, input_shape, input_data, output_shape, output_data);
4203 }
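// Example (illustrative sketch only, not part of the original header): calling
// the legacy float Softmax wrapper on a hypothetical 1x1x1x4 activation. The
// shape and values below are made up for demonstration.
//
//   const RuntimeShape shape({1, 1, 1, 4});
//   const float input[4] = {1.f, 2.f, 3.f, 4.f};
//   float output[4];
//   Softmax(input, shape, /*beta=*/1.f, output, shape);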
4204
4205 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
4206 float beta, float* output_data,
4207 const Dims<4>& output_dims) {
4208 Softmax(input_data, DimsToShape(input_dims), beta, output_data,
4209 DimsToShape(output_dims));
4210 }
4211
4212 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
4213 int32 input_beta_multiplier, int32 input_beta_left_shift,
4214 int diff_min, uint8* output_data,
4215 const RuntimeShape& output_shape) {
4216 SoftmaxParams params;
4217 params.input_multiplier = input_beta_multiplier;
4218 params.input_left_shift = input_beta_left_shift;
4219 params.diff_min = diff_min;
4220 Softmax(params, input_shape, input_data, output_shape, output_data);
4221 }
4222 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
4223 int32 input_beta_multiplier, int32 input_beta_left_shift,
4224 int diff_min, uint8* output_data,
4225 const Dims<4>& output_dims) {
4226 Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
4227 input_beta_left_shift, diff_min, output_data,
4228 DimsToShape(output_dims));
4229 }
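// Note on the quantized Softmax parameters (an interpretation, not part of the
// original comments): input_beta_multiplier and input_beta_left_shift form a
// fixed-point representation of beta * input_scale, rescaled into the
// scaled-difference fixed-point format used above, and diff_min is the most
// negative (input - max) difference still allowed to contribute to the sum of
// exponentials; smaller differences are treated as exp() == 0.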
4230
4231 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
4232 float* output_data, const RuntimeShape& output_shape) {
4233 SoftmaxParams params;
4234 // No params currently used for float LogSoftmax.
4235 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
4236 }
4237
4238 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
4239 float* output_data, const Dims<4>& output_dims) {
4240 LogSoftmax(input_data, DimsToShape(input_dims), output_data,
4241 DimsToShape(output_dims));
4242 }
4243
4244 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
4245 int32 input_multiplier, int32 input_left_shift,
4246 int32 reverse_scaling_divisor,
4247 int32 reverse_scaling_right_shift, int diff_min,
4248 uint8* output_data, const RuntimeShape& output_shape) {
4249 SoftmaxParams params;
4250 params.input_multiplier = input_multiplier;
4251 params.input_left_shift = input_left_shift;
4252 params.reverse_scaling_divisor = reverse_scaling_divisor;
4253 params.reverse_scaling_right_shift = reverse_scaling_right_shift;
4254 params.diff_min = diff_min;
4255 reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4256 output_data);
4257 }
4258
4259 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
4260 int32 input_multiplier, int32 input_left_shift,
4261 int32 reverse_scaling_divisor,
4262 int32 reverse_scaling_right_shift, int diff_min,
4263 uint8* output_data, const Dims<4>& output_dims) {
4264 reference_ops::LogSoftmax(
4265 input_data, DimsToShape(input_dims), input_multiplier, input_left_shift,
4266 reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
4267 output_data, DimsToShape(output_dims));
4268 }
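// For reference, LogSoftmax computes, per inner-most slice,
//   log_softmax(x[c]) = (x[c] - max) - log(sum_j exp(x[j] - max));
// the quantized wrappers above simply pack their arguments into SoftmaxParams
// and forward to the reference kernel.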
4269
4270 inline void Logistic(const LogisticParams& params,
4271 const RuntimeShape& input_shape, const uint8* input_data,
4272 const RuntimeShape& output_shape, uint8* output_data) {
4273 ruy::profiler::ScopeLabel label("Logistic/Uint8");
4274 const int32 input_zero_point = params.input_zero_point;
4275 const int32 input_range_radius = params.input_range_radius;
4276 const int32 input_multiplier = params.input_multiplier;
4277 const int input_left_shift = params.input_left_shift;
4278 const int size = MatchingFlatSize(input_shape, output_shape);
4279
4280 int c = 0;
4281 #ifdef USE_NEON
4282 // Handle 16 values at a time
4283 for (; c <= size - 16; c += 16) {
4284 // Read input uint8 values, cast to int16 and subtract input_zero_point
4285 uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4286 int16x8_t input_val_centered_0 =
4287 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4288 vdupq_n_s16(input_zero_point));
4289 int16x8_t input_val_centered_1 =
4290 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4291 vdupq_n_s16(input_zero_point));
4292
4293 // Prepare the bit masks that we will use at the end to implement the logic
4294 // that was expressed in the scalar code with branching:
4295 // if (input_val_centered < -input_range_radius) {
4296 // output_val = 0;
4297 // } else if (input_val_centered > input_range_radius) {
4298 // output_val = 255;
4299 // } else {
4300 // ...
4301 uint16x8_t mask_rightclamp_0 =
4302 vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4303 uint16x8_t mask_rightclamp_1 =
4304 vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4305 uint16x8_t mask_leftclamp_0 =
4306 vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4307 uint16x8_t mask_leftclamp_1 =
4308 vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4309 uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4310 vshrn_n_u16(mask_rightclamp_1, 8));
4311 uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4312 vshrn_n_u16(mask_leftclamp_1, 8));
4313
4314 // This performs what is expressed in the scalar code as
4315 // const int32 input_val_rescaled =
4316 // MultiplyByQuantizedMultiplierGreaterThanOne(
4317 // input_val_centered, input_multiplier, input_left_shift);
4318 int32x4_t input_val_rescaled_0 =
4319 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4320 vdupq_n_s32(input_left_shift));
4321 int32x4_t input_val_rescaled_1 =
4322 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4323 vdupq_n_s32(input_left_shift));
4324 int32x4_t input_val_rescaled_2 =
4325 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4326 vdupq_n_s32(input_left_shift));
4327 int32x4_t input_val_rescaled_3 =
4328 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4329 vdupq_n_s32(input_left_shift));
4330 input_val_rescaled_0 =
4331 vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4332 input_val_rescaled_1 =
4333 vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4334 input_val_rescaled_2 =
4335 vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4336 input_val_rescaled_3 =
4337 vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4338
4339 // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
4340 using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4341 using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4342 const FixedPoint4 input_val_f4_0 =
4343 FixedPoint4::FromRaw(input_val_rescaled_0);
4344 const FixedPoint4 input_val_f4_1 =
4345 FixedPoint4::FromRaw(input_val_rescaled_1);
4346 const FixedPoint4 input_val_f4_2 =
4347 FixedPoint4::FromRaw(input_val_rescaled_2);
4348 const FixedPoint4 input_val_f4_3 =
4349 FixedPoint4::FromRaw(input_val_rescaled_3);
4350 const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
4351 const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
4352 const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
4353 const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
4354
4355 // Divide by 2^23 as in the scalar code
4356 using gemmlowp::RoundingDivideByPOT;
4357 int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
4358 int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
4359 int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
4360 int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
4361
4362 // Cast output values to uint8, saturating
4363 int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4364 vqmovn_s32(output_val_s32_1));
4365 int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4366 vqmovn_s32(output_val_s32_3));
4367 uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4368 vqmovun_s16(output_val_s16_1));
4369
4370 // Perform the bit-masking with the bit masks computed at the beginning,
4371 // see the comment there.
4372 output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4373 output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4374
4375 // Store back to memory
4376 vst1q_u8(output_data + c, output_val_u8);
4377 }
4378 #endif
4379 // Leftover loop: handle one value at a time with scalar code.
4380 for (; c < size; ++c) {
4381 const uint8 input_val_u8 = input_data[c];
4382 const int32 input_val_centered =
4383 static_cast<int32>(input_val_u8) - input_zero_point;
4384 uint8 output_val;
4385 if (input_val_centered < -input_range_radius) {
4386 output_val = 0;
4387 } else if (input_val_centered > input_range_radius) {
4388 output_val = 255;
4389 } else {
4390 const int32 input_val_rescaled =
4391 MultiplyByQuantizedMultiplierGreaterThanOne(
4392 input_val_centered, input_multiplier, input_left_shift);
4393 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4394 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4395 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4396 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
4397 using gemmlowp::RoundingDivideByPOT;
4398 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
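// gemmlowp::logistic returns a Q0.31 fixed-point value in [0, 1); the
// rounding shift by 23 above maps it onto [0, 256], so only the exact
// saturation case 256 needs to be folded back into the uint8 range below.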
4399 if (output_val_s32 == 256) {
4400 output_val_s32 = 255;
4401 }
4402 TFLITE_DCHECK_GE(output_val_s32, 0);
4403 TFLITE_DCHECK_LE(output_val_s32, 255);
4404 output_val = static_cast<uint8>(output_val_s32);
4405 }
4406 output_data[c] = output_val;
4407 }
4408 }
4409
4410 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
4411 int32 input_zero_point, int32 input_range_radius,
4412 int32 input_multiplier, int input_left_shift,
4413 uint8* output_data, const RuntimeShape& output_shape) {
4414 LogisticParams params;
4415 params.input_zero_point = input_zero_point;
4416 params.input_range_radius = input_range_radius;
4417 params.input_multiplier = input_multiplier;
4418 params.input_left_shift = input_left_shift;
4419 Logistic(params, input_shape, input_data, output_shape, output_data);
4420 }
4421
4422 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
4423 float* output_data, const Dims<4>& output_dims) {
4424 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4425 output_data);
4426 }
4427
4428 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
4429 int32 input_zero_point, int32 input_range_radius,
4430 int32 input_multiplier, int input_left_shift,
4431 uint8* output_data, const Dims<4>& output_dims) {
4432 Logistic(input_data, DimsToShape(input_dims), input_zero_point,
4433 input_range_radius, input_multiplier, input_left_shift, output_data,
4434 DimsToShape(output_dims));
4435 }
4436
4437 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
4438 const RuntimeShape& output_shape, int16* output_data) {
4439 LogisticParams params;
4440 // No params currently needed by int16 Logistic.
4441 Logistic(params, input_shape, input_data, output_shape, output_data);
4442 }
4443
4444 inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
4445 int16* output_data, const RuntimeShape& output_shape) {
4446 LogisticParams params;
4447 // No params currently needed by int16 Logistic.
4448 Logistic(params, input_shape, input_data, output_shape, output_data);
4449 }
4450
4451 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
4452 int16* output_data, const Dims<4>& output_dims) {
4453 Logistic(input_data, DimsToShape(input_dims), output_data,
4454 DimsToShape(output_dims));
4455 }
4456
4457 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
4458 float* output_data, const Dims<4>& output_dims) {
4459 Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4460 output_data);
4461 }
4462
4463 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4464 const uint8* input_data, const RuntimeShape& output_shape,
4465 uint8* output_data) {
4466 // Note that this is almost the exact same code as in Logistic().
4467 ruy::profiler::ScopeLabel label("Tanh");
4468 const int32 input_zero_point = params.input_zero_point;
4469 const int32 input_range_radius = params.input_range_radius;
4470 const int32 input_multiplier = params.input_multiplier;
4471 const int input_left_shift = params.input_left_shift;
4472 const int size = MatchingFlatSize(input_shape, output_shape);
4473
4474 int c = 0;
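// The uint8 Tanh output is assumed to be quantized symmetrically around a
// zero point of 128 (tanh's [-1, 1) range mapped onto [0, 255]), so the
// output zero point is hard-coded here rather than taken from params.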
4475 int32_t output_zero_point = 128;
4476 #ifdef USE_NEON
4477 // Handle 16 values at a time
4478 for (; c <= size - 16; c += 16) {
4479 // Read input uint8 values, cast to int16 and subtract input_zero_point
4480 uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4481 int16x8_t input_val_centered_0 =
4482 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4483 vdupq_n_s16(input_zero_point));
4484 int16x8_t input_val_centered_1 =
4485 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4486 vdupq_n_s16(input_zero_point));
4487
4488 // Prepare the bit masks that we will use at the end to implement the logic
4489 // that was expressed in the scalar code with branching:
4490 // if (input_val_centered < -input_range_radius) {
4491 // output_val = 0;
4492 // } else if (input_val_centered > input_range_radius) {
4493 // output_val = 255;
4494 // } else {
4495 // ...
4496 uint16x8_t mask_rightclamp_0 =
4497 vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4498 uint16x8_t mask_rightclamp_1 =
4499 vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4500 uint16x8_t mask_leftclamp_0 =
4501 vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4502 uint16x8_t mask_leftclamp_1 =
4503 vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4504 uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4505 vshrn_n_u16(mask_rightclamp_1, 8));
4506 uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4507 vshrn_n_u16(mask_leftclamp_1, 8));
4508
4509 // This performs what is expressed in the scalar code as
4510 // const int32 input_val_rescaled =
4511 // MultiplyByQuantizedMultiplierGreaterThanOne(
4512 // input_val_centered, input_multiplier, input_left_shift);
4513 int32x4_t input_val_rescaled_0 =
4514 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4515 vdupq_n_s32(input_left_shift));
4516 int32x4_t input_val_rescaled_1 =
4517 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4518 vdupq_n_s32(input_left_shift));
4519 int32x4_t input_val_rescaled_2 =
4520 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4521 vdupq_n_s32(input_left_shift));
4522 int32x4_t input_val_rescaled_3 =
4523 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4524 vdupq_n_s32(input_left_shift));
4525 input_val_rescaled_0 =
4526 vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4527 input_val_rescaled_1 =
4528 vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4529 input_val_rescaled_2 =
4530 vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4531 input_val_rescaled_3 =
4532 vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4533
4534 // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
4535 using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4536 using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4537 const FixedPoint4 input_val_f4_0 =
4538 FixedPoint4::FromRaw(input_val_rescaled_0);
4539 const FixedPoint4 input_val_f4_1 =
4540 FixedPoint4::FromRaw(input_val_rescaled_1);
4541 const FixedPoint4 input_val_f4_2 =
4542 FixedPoint4::FromRaw(input_val_rescaled_2);
4543 const FixedPoint4 input_val_f4_3 =
4544 FixedPoint4::FromRaw(input_val_rescaled_3);
4545 const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
4546 const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
4547 const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
4548 const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
4549
4550 // Divide by 2^24 as in the scalar code
4551 using gemmlowp::RoundingDivideByPOT;
4552 int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
4553 int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
4554 int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
4555 int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
4556
4557 // Add the output zero point
4558 int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
4559 output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
4560 output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
4561 output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
4562 output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
4563
4564 // Cast output values to uint8, saturating
4565 int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4566 vqmovn_s32(output_val_s32_1));
4567 int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4568 vqmovn_s32(output_val_s32_3));
4569 uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4570 vqmovun_s16(output_val_s16_1));
4571
4572 // Perform the bit-masking with the bit masks computed at the beginning,
4573 // see the comment there.
4574 output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4575 output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4576
4577 // Store back to memory
4578 vst1q_u8(output_data + c, output_val_u8);
4579 }
4580 #endif
4581 // Leftover loop: handle one value at a time with scalar code.
4582 for (; c < size; ++c) {
4583 const uint8 input_val_u8 = input_data[c];
4584 const int32 input_val_centered =
4585 static_cast<int32>(input_val_u8) - input_zero_point;
4586 uint8 output_val;
4587 if (input_val_centered < -input_range_radius) {
4588 output_val = 0;
4589 } else if (input_val_centered > input_range_radius) {
4590 output_val = 255;
4591 } else {
4592 const int32 input_val_rescaled =
4593 MultiplyByQuantizedMultiplierGreaterThanOne(
4594 input_val_centered, input_multiplier, input_left_shift);
4595 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4596 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4597 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4598 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
4599 using gemmlowp::RoundingDivideByPOT;
4600 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
4601 output_val_s32 += output_zero_point;
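// gemmlowp::tanh returns a Q0.31 value in [-1, 1); shifting right by 24 maps
// it onto [-128, 128], and adding the output zero point of 128 yields
// [0, 256], with 256 occurring only at full positive saturation.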
4602 if (output_val_s32 == 256) {
4603 output_val_s32 = 255;
4604 }
4605 TFLITE_DCHECK_GE(output_val_s32, 0);
4606 TFLITE_DCHECK_LE(output_val_s32, 255);
4607 output_val = static_cast<uint8>(output_val_s32);
4608 }
4609 output_data[c] = output_val;
4610 }
4611 }
4612
4613 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
4614 int32 input_zero_point, int32 input_range_radius,
4615 int32 input_multiplier, int input_left_shift,
4616 uint8* output_data, const RuntimeShape& output_shape) {
4617 TanhParams params;
4618 params.input_zero_point = input_zero_point;
4619 params.input_range_radius = input_range_radius;
4620 params.input_multiplier = input_multiplier;
4621 params.input_left_shift = input_left_shift;
4622 Tanh(params, input_shape, input_data, output_shape, output_data);
4623 }
4624
4625 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
4626 int32 input_zero_point, int32 input_range_radius,
4627 int32 input_multiplier, int input_left_shift,
4628 uint8* output_data, const Dims<4>& output_dims) {
4629 Tanh(input_data, DimsToShape(input_dims), input_zero_point,
4630 input_range_radius, input_multiplier, input_left_shift, output_data,
4631 DimsToShape(output_dims));
4632 }
4633
4634 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
4635 int input_left_shift, int16* output_data,
4636 const RuntimeShape& output_shape) {
4637 TanhParams params;
4638 params.input_left_shift = input_left_shift;
4639 Tanh(params, input_shape, input_data, output_shape, output_data);
4640 }
4641
4642 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
4643 int input_left_shift, int16* output_data,
4644 const Dims<4>& output_dims) {
4645 Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
4646 DimsToShape(output_dims));
4647 }
4648
4649 template <typename T>
4650 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
4651 int block_size, T* output_data,
4652 const Dims<4>& output_dims) {
4653 tflite::DepthToSpaceParams op_params;
4654 op_params.block_size = block_size;
4655
4656 DepthToSpace(op_params, DimsToShape(input_dims), input_data,
4657 DimsToShape(output_dims), output_data);
4658 }
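// Illustrative note (not from the original header): with block_size == 2,
// DepthToSpace rearranges each group of 4 input channels into a 2x2 spatial
// block, e.g. an input of shape 1x2x2x4 becomes an output of shape 1x4x4x1.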
4659
4660 template <typename T>
4661 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
4662 int block_size, T* output_data,
4663 const Dims<4>& output_dims) {
4664 tflite::SpaceToDepthParams op_params;
4665 op_params.block_size = block_size;
4666
4667 SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
4668 DimsToShape(output_dims), output_data);
4669 }
4670
4671 inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
4672 const float* input2_data, const Dims<4>& input2_dims,
4673 float output_activation_min, float output_activation_max,
4674 float* output_data, const Dims<4>& output_dims) {
4675 tflite::ArithmeticParams op_params;
4676 op_params.float_activation_min = output_activation_min;
4677 op_params.float_activation_max = output_activation_max;
4678
4679 Mul(op_params, DimsToShape(input1_dims), input1_data,
4680 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4681 output_data);
4682 }
4683
4684 template <FusedActivationFunctionType Ac>
4685 void Mul(const float* input1_data, const Dims<4>& input1_dims,
4686 const float* input2_data, const Dims<4>& input2_dims,
4687 float* output_data, const Dims<4>& output_dims) {
4688 float output_activation_min, output_activation_max;
4689 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
4690
4691 Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
4692 output_activation_max, output_data, output_dims);
4693 }
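// Example (hypothetical sketch, not from the original header): elementwise
// multiply of two same-shaped float tensors through the legacy Dims<4>
// interface; the activation min/max are derived from the template argument.
//
//   Dims<4> dims;  // assume a matching legacy 4-D dims object is available
//   Mul<FusedActivationFunctionType::kNone>(a, dims, b, dims, out, dims);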
4694
4695 inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
4696 const int32* input2_data, const Dims<4>& input2_dims,
4697 int32 output_activation_min, int32 output_activation_max,
4698 int32* output_data, const Dims<4>& output_dims) {
4699 tflite::ArithmeticParams op_params;
4700 op_params.quantized_activation_min = output_activation_min;
4701 op_params.quantized_activation_max = output_activation_max;
4702
4703 Mul(op_params, DimsToShape(input1_dims), input1_data,
4704 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4705 output_data);
4706 }
4707
4708 template <FusedActivationFunctionType Ac>
4709 void Mul(const int32* input1_data, const Dims<4>& input1_dims,
4710 const int32* input2_data, const Dims<4>& input2_dims,
4711 int32* output_data, const Dims<4>& output_dims) {
4712 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
4713 tflite::ArithmeticParams op_params;
4714 // No parameters needed.
4715
4716 MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
4717 DimsToShape(input2_dims), input2_data,
4718 DimsToShape(output_dims), output_data);
4719 }
4720
4721 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
4722 const int16* input2_data, const Dims<4>& input2_dims,
4723 int16* output_data, const Dims<4>& output_dims) {
4724 tflite::ArithmeticParams op_params;
4725 // No parameters needed.
4726
4727 Mul(op_params, DimsToShape(input1_dims), input1_data,
4728 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4729 output_data);
4730 }
4731
4732 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
4733 const int16* input2_data, const Dims<4>& input2_dims,
4734 int32 output_offset, int32 output_activation_min,
4735 int32 output_activation_max, uint8* output_data,
4736 const Dims<4>& output_dims) {
4737 tflite::ArithmeticParams op_params;
4738 op_params.output_offset = output_offset;
4739 op_params.quantized_activation_min = output_activation_min;
4740 op_params.quantized_activation_max = output_activation_max;
4741
4742 Mul(op_params, DimsToShape(input1_dims), input1_data,
4743 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4744 output_data);
4745 }
4746
4747 template <typename T>
4748 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
4749 const T* input2_data, const Dims<4>& input2_dims,
4750 T output_activation_min, T output_activation_max,
4751 T* output_data, const Dims<4>& output_dims) {
4752 tflite::ArithmeticParams op_params;
4753 SetActivationParams(output_activation_min, output_activation_max, &op_params);
4754
4755 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
4756 DimsToShape(input2_dims), input2_data,
4757 DimsToShape(output_dims), output_data);
4758 }
4759
4760 // For compatibility with old checked-in code
4761 template <FusedActivationFunctionType Ac>
4762 inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
4763 const float* input2_data, const Dims<4>& input2_dims,
4764 float* output_data, const Dims<4>& output_dims) {
4765 tflite::ArithmeticParams op_params;
4766 float float_activation_min;
4767 float float_activation_max;
4768 GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
4769 SetActivationParams(float_activation_min, float_activation_max, &op_params);
4770
4771 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
4772 DimsToShape(input2_dims), input2_data,
4773 DimsToShape(output_dims), output_data);
4774 }
4775
4776 inline void LocalResponseNormalization(const float* input_data,
4777 const Dims<4>& input_dims, int range,
4778 float bias, float alpha, float beta,
4779 float* output_data,
4780 const Dims<4>& output_dims) {
4781 tflite::LocalResponseNormalizationParams op_params;
4782 op_params.range = range;
4783 op_params.bias = bias;
4784 op_params.alpha = alpha;
4785 op_params.beta = beta;
4786
4787 LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
4788 DimsToShape(output_dims), output_data);
4789 }
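// For reference, LocalResponseNormalization computes, roughly,
//   output = input / pow(bias + alpha * sum_of_squares(neighboring channels
//            within `range` of the current channel), beta),
// which is what the parameters packed above control.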
4790
4791 template <typename SrcT, typename DstT>
4792 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
4793 const Dims<4>& output_dims) {
4794 Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4795 output_data);
4796 }
4797
4798 inline void Floor(const float* input_data, const Dims<4>& input_dims,
4799 float* output_data, const Dims<4>& output_dims) {
4800 Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4801 output_data);
4802 }
4803
4804 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
4805 const int32* output_size_data,
4806 const Dims<4>& output_size_dims, float* output_data,
4807 const Dims<4>& output_dims, bool align_corners) {
4808 tflite::ResizeBilinearParams op_params;
4809 op_params.align_corners = align_corners;
4810 op_params.half_pixel_centers = false;
4811 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
4812 DimsToShape(output_size_dims), output_size_data,
4813 DimsToShape(output_dims), output_data);
4814 }
4815
4816 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
4817 const int32* output_size_data,
4818 const Dims<4>& output_size_dims, uint8* output_data,
4819 const Dims<4>& output_dims, bool align_corners) {
4820 tflite::ResizeBilinearParams op_params;
4821 op_params.align_corners = align_corners;
4822 op_params.half_pixel_centers = false;
4823 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
4824 DimsToShape(output_size_dims), output_size_data,
4825 DimsToShape(output_dims), output_data);
4826 }
4827
4828 // legacy, for compatibility with old checked-in code
4829 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
4830 const int32* output_size_data,
4831 const Dims<4>& output_size_dims, float* output_data,
4832 const Dims<4>& output_dims) {
4833 ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
4834 output_data, output_dims, /*align_corners=*/false);
4835 }
4836
4837 // legacy, for compatibility with old checked-in code
4838 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
4839 const int32* output_size_data,
4840 const Dims<4>& output_size_dims, uint8* output_data,
4841 const Dims<4>& output_dims) {
4842 ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
4843 output_data, output_dims, /*align_corners=*/false);
4844 }
4845
4846 template <typename T>
4847 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
4848 const int32* block_shape_data,
4849 const Dims<4>& block_shape_dims,
4850 const int32* crops_data, const Dims<4>& crops_dims,
4851 T* output_data, const Dims<4>& output_dims) {
4852 BatchToSpaceND(DimsToShape(input_dims), input_data,
4853 DimsToShape(block_shape_dims), block_shape_data,
4854 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
4855 output_data);
4856 }
4857
4858 // Legacy signature; this function covers both Pad and PadV2.
4859 template <typename T>
4860 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
4861 const std::vector<int>& left_paddings,
4862 const std::vector<int>& right_paddings, T* output_data,
4863 const Dims<4>& output_dims, const T pad_value) {
4864 TFLITE_DCHECK_EQ(left_paddings.size(), 4);
4865 TFLITE_DCHECK_EQ(right_paddings.size(), 4);
4866 tflite::PadParams op_params;
4867 op_params.left_padding_count = 4;
4868 op_params.right_padding_count = 4;
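// Legacy Dims<4> lists dimensions in reverse order relative to RuntimeShape,
// so the padding vectors are reversed while being copied into PadParams.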
4869 for (int i = 0; i < 4; ++i) {
4870 op_params.left_padding[i] = left_paddings[3 - i];
4871 op_params.right_padding[i] = right_paddings[3 - i];
4872 }
4873 const T pad_value_copy = pad_value;
4874
4875 Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
4876 DimsToShape(output_dims), output_data);
4877 }
4878
4879 // Old Pad that calls legacy PadV2.
4880 template <typename T>
4881 inline void Pad(const T* input_data, const Dims<4>& input_dims,
4882 const std::vector<int>& left_paddings,
4883 const std::vector<int>& right_paddings, T* output_data,
4884 const Dims<4>& output_dims, const int32_t pad_value) {
4885 const T converted_pad_value = static_cast<T>(pad_value);
4886 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
4887 output_dims, converted_pad_value);
4888 }
4889
4890 // Old Pad that only padded with 0.
4891 template <typename T>
4892 inline void Pad(const T* input_data, const Dims<4>& input_dims,
4893 const std::vector<int>& left_paddings,
4894 const std::vector<int>& right_paddings, T* output_data,
4895 const Dims<4>& output_dims) {
4896 const T pad_value = static_cast<T>(0);
4897 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
4898 output_dims, pad_value);
4899 }
4900
4901 template <typename T>
4902 inline void Slice(const T* input_data, const Dims<4>& input_dims,
4903 const std::vector<int>& begin, const std::vector<int>& size,
4904 T* output_data, const Dims<4>& output_dims) {
4905 tflite::SliceParams op_params;
4906 op_params.begin_count = 4;
4907 op_params.size_count = 4;
4908 for (int i = 0; i < 4; ++i) {
4909 op_params.begin[i] = begin[3 - i];
4910 op_params.size[i] = size[3 - i];
4911 }
4912
4913 Slice(op_params, DimsToShape(input_dims), input_data,
4914 DimsToShape(output_dims), output_data);
4915 }
4916
4917 template <typename T>
4918 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
4919 const T* input2_data, T* output_data,
4920 const Dims<4>& output_dims) {
4921 Minimum(DimsToShape(input1_dims), input1_data, input2_data,
4922 DimsToShape(output_dims), output_data);
4923 }
4924
4925 template <typename T>
4926 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
4927 const T* input2_data, T* output_data,
4928 const Dims<4>& output_dims) {
4929 Maximum(DimsToShape(input1_dims), input1_data, input2_data,
4930 DimsToShape(output_dims), output_data);
4931 }
4932
4933 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
4934 int32 zero_point, double scale, float* output_data,
4935 const Dims<4>& output_dims) {
4936 tflite::DequantizationParams op_params;
4937 op_params.zero_point = zero_point;
4938 op_params.scale = scale;
4939
4940 Dequantize(op_params, DimsToShape(input_dims), input_data,
4941 DimsToShape(output_dims), output_data);
4942 }
4943
4944 template <typename T>
4945 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
4946 const Dims<4>& output_dims, const int* permuted_axes) {
4947 TransposeParams params;
4948 params.perm_count = 4;
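// Both the positions and the values of the permutation are flipped here:
// legacy Dims<4> axes are numbered in reverse relative to RuntimeShape, so a
// legacy permutation entry p at axis i becomes (3 - p) at axis (3 - i).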
4949 for (int i = 0; i < 4; ++i) {
4950 params.perm[i] = 3 - permuted_axes[3 - i];
4951 }
4952 Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
4953 output);
4954 }
4955
4956 template <typename T>
4957 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
4958 int begin_mask, int end_mask, int shrink_axis_mask,
4959 const std::vector<int>& start_indices,
4960 const std::vector<int>& stop_indices,
4961 const std::vector<int>& strides, T* output_data,
4962 const Dims<4>& output_dims) {
4963 TFLITE_DCHECK_EQ(start_indices.size(), 4);
4964 auto op_params = strided_slice::BuildStridedSliceParams(
4965 begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
4966 strides);
4967 reference_ops::StridedSliceReverseIndices(&op_params);
4968
4969 StridedSlice(op_params, DimsToShape(input_dims), input_data,
4970 DimsToShape(output_dims), output_data);
4971 }
4972
4973 template <typename T1, typename T2, typename T3>
4974 void ArgMax(const T3* axis, const T1* input_data,
4975 const tflite::Dims<4>& input_dims, T2* output_data,
4976 const tflite::Dims<4>& output_dims) {
4977 // Assumes the input always has 4 dimensions, and therefore the output
4978 // always has three dimensions.
4979 auto output_shape = RuntimeShape(
4980 {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
4981 // Another way to interpret this is that output_dims.sizes[3] is always 1.
4982 TFLITE_DCHECK_EQ(output_shape.FlatSize(),
4983 DimsToShape(output_dims).FlatSize());
4984 // Legacy path only supported this.
4985 TFLITE_DCHECK_EQ(axis[0], 3);
4986 ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
4987 output_data, /*is_arg_max=*/true);
4988 }
4989
4990 template <typename T1, typename T2, typename T3>
4991 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
4992 T2* output_data, const Dims<4>& output_dims,
4993 const bool is_arg_max) {
4994 ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4995 output_data, is_arg_max);
4996 }
4997
4998 } // namespace optimized_ops
4999 } // namespace tflite
5000 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
5001