1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21
22 #include <algorithm>
23 #include <cmath>
24 #include <cstdint>
25 #include <limits>
26 #include <memory>
27 #include <tuple>
28 #include <type_traits>
29
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/reference/add.h"
33 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
34
35 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
36 #include <Accelerate/Accelerate.h>
37 #endif
38
39 #include "Eigen/Core"
40 #include "unsupported/Eigen/CXX11/Tensor"
41 #include "fixedpoint/fixedpoint.h"
42 #include "ruy/profiler/instrumentation.h" // from @ruy
43 #include "tensorflow/lite/c/common.h"
44 #include "tensorflow/lite/kernels/cpu_backend_context.h"
45 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
46 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
47 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
48 #include "tensorflow/lite/kernels/internal/cppmath.h"
49 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
50 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
51 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops_utils.h"
52 #include "tensorflow/lite/kernels/internal/quantization_util.h"
53 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
54 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
55 #include "tensorflow/lite/kernels/internal/tensor.h"
56 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
57 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
58 #include "tensorflow/lite/kernels/internal/types.h"
59
60 #if __aarch64__ && __clang__
61 #define TFLITE_SOFTMAX_USE_UINT16_LUT
62 #endif
63
64 namespace tflite {
65 namespace optimized_ops {
66
67 // Unoptimized reference ops:
68 using reference_ops::Broadcast4DSlowGreater;
69 using reference_ops::Broadcast4DSlowGreaterEqual;
70 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
71 using reference_ops::Broadcast4DSlowGreaterWithScaling;
72 using reference_ops::Broadcast4DSlowLess;
73 using reference_ops::Broadcast4DSlowLessEqual;
74 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
75 using reference_ops::Broadcast4DSlowLessWithScaling;
76 using reference_ops::BroadcastAdd4DSlow;
77 using reference_ops::BroadcastMul4DSlow;
78 using reference_ops::BroadcastSub16POTSlow;
79 using reference_ops::BroadcastSubSlow;
80 using reference_ops::Concatenation;
81 using reference_ops::ConcatenationWithScaling;
82 using reference_ops::DepthConcatenation;
83 using reference_ops::Div;
84 using reference_ops::Elu;
85 using reference_ops::FakeQuant;
86 using reference_ops::Fill;
87 using reference_ops::Gather;
88 using reference_ops::Greater;
89 using reference_ops::GreaterEqual;
90 using reference_ops::GreaterEqualWithScaling;
91 using reference_ops::GreaterWithScaling;
92 using reference_ops::LeakyRelu;
93 using reference_ops::Less;
94 using reference_ops::LessEqual;
95 using reference_ops::LessEqualWithScaling;
96 using reference_ops::LessWithScaling;
97 using reference_ops::ProcessBroadcastShapes;
98 using reference_ops::RankOneSelect;
99 using reference_ops::Relu0To1; // NOLINT
100 using reference_ops::Relu1;
101 using reference_ops::Relu6;
102 using reference_ops::ReluX;
103 using reference_ops::Round;
104 using reference_ops::Select;
105 using reference_ops::SpaceToBatchND;
106 using reference_ops::Split;
107 using reference_ops::Sub16;
108
109 // TODO(b/80247582) Remove this constant.
110 // This will be phased out as the shifts are revised with more thought. Use of a
111 // constant enables us to track progress on this work.
112 //
113 // Used to convert from old-style shifts (right) to new-style (left).
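// For example (illustrative): a legacy right shift amount of 3 maps to
// kReverseShift * 3 = -3 in the new convention, where a negative left shift
// denotes a right shift.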
114 static constexpr int kReverseShift = -1;
115
116 // Copied from tensorflow/core/framework/tensor_types.h
117 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
118 struct TTypes {
119 // Rank-1 tensor (vector) of scalar type T.
120 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
121 Eigen::Aligned>
122 Flat;
123 typedef Eigen::TensorMap<
124 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
125 UnalignedConstMatrix;
126 };
127
128 // TODO(b/62193649): this function is only needed as long
129 // as we have the --variable_batch hack.
130 template <typename Scalar>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
                                                   const RuntimeShape& shape,
                                                   int rows) {
134 const int flatsize = shape.FlatSize();
135 TFLITE_DCHECK_EQ(flatsize % rows, 0);
136 const int cols = flatsize / rows;
137 return MatrixMap<Scalar>(data, rows, cols);
138 }
139
140 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
                                    const RuntimeShape& unswitched_input1_shape,
                                    const T* unswitched_input1_data,
                                    const RuntimeShape& unswitched_input2_shape,
                                    const T* unswitched_input2_data,
                                    const RuntimeShape& output_shape,
                                    T* output_data, ElementwiseF elementwise_f,
                                    ScalarBroadcastF scalar_broadcast_f) {
149 ArithmeticParams switched_params = unswitched_params;
150 switched_params.input1_offset = unswitched_params.input2_offset;
151 switched_params.input1_multiplier = unswitched_params.input2_multiplier;
152 switched_params.input1_shift = unswitched_params.input2_shift;
153 switched_params.input2_offset = unswitched_params.input1_offset;
154 switched_params.input2_multiplier = unswitched_params.input1_multiplier;
155 switched_params.input2_shift = unswitched_params.input1_shift;
156
157 const bool use_unswitched =
158 unswitched_params.broadcast_category ==
159 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
160
161 const ArithmeticParams& params =
162 use_unswitched ? unswitched_params : switched_params;
163 const T* input1_data =
164 use_unswitched ? unswitched_input1_data : unswitched_input2_data;
165 const T* input2_data =
166 use_unswitched ? unswitched_input2_data : unswitched_input1_data;
167
168 // Fivefold nested loops. The second input resets its position for each
169 // iteration of the second loop. The first input resets its position at the
170 // beginning of the fourth loop. The innermost loop is an elementwise add of
171 // sections of the arrays.
172 T* output_data_ptr = output_data;
173 const T* input1_data_ptr = input1_data;
174 const T* input2_data_reset = input2_data;
175 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
176 // between input shapes. y3 for input 1 is always broadcast, and so the
177 // dimension there is 1, whereas optionally y1 might be broadcast for
178 // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
179 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
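// Illustrative example (not taken from any particular model): with
// broadcast_shape {y0, y1, y2, y3, y4} = {2, 3, 4, 5, 8}, input1 holds
// 2*3*4*8 = 192 elements, input2 holds 2*4*5*8 = 320 elements, and the
// output holds 2*3*4*5*8 = 960 elements.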
180 int y0 = params.broadcast_shape[0];
181 int y1 = params.broadcast_shape[1];
182 int y2 = params.broadcast_shape[2];
183 int y3 = params.broadcast_shape[3];
184 int y4 = params.broadcast_shape[4];
185 if (y4 > 1) {
186 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
187 // dimension.
188 for (int i0 = 0; i0 < y0; ++i0) {
189 const T* input2_data_ptr = nullptr;
190 for (int i1 = 0; i1 < y1; ++i1) {
191 input2_data_ptr = input2_data_reset;
192 for (int i2 = 0; i2 < y2; ++i2) {
193 for (int i3 = 0; i3 < y3; ++i3) {
194 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
195 output_data_ptr);
196 input2_data_ptr += y4;
197 output_data_ptr += y4;
198 }
199 // We have broadcast y4 of input1 data y3 times, and now move on.
200 input1_data_ptr += y4;
201 }
202 }
203 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
204 input2_data_reset = input2_data_ptr;
205 }
206 } else if (input1_data_ptr != nullptr) {
207 // Special case of y4 == 1, in which the innermost loop is a single
208 // element and can be combined with the next (y3) as an inner broadcast.
209 //
210 // Note that this handles the case of pure scalar broadcast when
211 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
212 // broadcast with batch (as y2 > 1).
213 //
214 // NOTE The process is the same as the above general case except
215 // simplified for y4 == 1 and the loop over y3 is contained within the
216 // AddScalarBroadcast function.
217 for (int i0 = 0; i0 < y0; ++i0) {
218 const T* input2_data_ptr = nullptr;
219 for (int i1 = 0; i1 < y1; ++i1) {
220 input2_data_ptr = input2_data_reset;
221 for (int i2 = 0; i2 < y2; ++i2) {
222 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
223 output_data_ptr);
224 input2_data_ptr += y3;
225 output_data_ptr += y3;
226 input1_data_ptr += 1;
227 }
228 }
229 input2_data_reset = input2_data_ptr;
230 }
231 }
232 }
233
234 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
235
236 // Looks up each element of <indices> in <table>, returns them in a vector.
inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
                                        uint8x16_t indices) {
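  // vqtbl4q_u8 yields 0 for any index outside [0, 63]. XOR-ing the indices
  // with 0x40 / 0x80 / 0xc0 re-bases one quarter of the index space into that
  // range while pushing every other lane out of range (producing 0), so the
  // four partial lookups below can simply be OR-ed together.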
239 // Look up in 1st quarter of the table: top 2 bits of indices == 00
240 uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
241 // Look up in 2nd quarter of the table: top 2 bits of indices == 01
242 uint8x16_t output2 =
243 vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
244 // Look up in 3rd quarter of the table: top 2 bits of indices == 10
245 uint8x16_t output3 =
246 vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
247 // Look up in 4th quarter of the table: top 2 bits of indices == 11
248 uint8x16_t output4 =
249 vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
250
251 // Combine result of the 4 lookups.
252 return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
253 }
254
255 #endif
256
inline void AddBiasAndEvalActivationFunction(float output_activation_min,
                                             float output_activation_max,
                                             const RuntimeShape& bias_shape,
                                             const float* bias_data,
                                             const RuntimeShape& array_shape,
                                             float* array_data) {
263 BiasAndClamp(output_activation_min, output_activation_max,
264 bias_shape.FlatSize(), bias_data, array_shape.FlatSize(),
265 array_data);
266 }
267
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& bias_shape,
    const float* optional_bias_data, const RuntimeShape& output_shape,
    float* output_data, CpuBackendContext* cpu_backend_context) {
274 ruy::profiler::ScopeLabel label("FullyConnected");
275 const int dims_count = weights_shape.DimensionsCount();
276 const int input_rows = weights_shape.Dims(dims_count - 1);
277 cpu_backend_gemm::MatrixParams<float> rhs_params;
278 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
279 rhs_params.rows = input_rows;
280 rhs_params.cols = input_shape.FlatSize() / input_rows;
281 rhs_params.cache_policy =
282 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
283 TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
284 cpu_backend_gemm::MatrixParams<float> lhs_params;
285 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
286 lhs_params.cols = weights_shape.Dims(dims_count - 1);
287 lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
288 lhs_params.cache_policy =
289 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
290 cpu_backend_gemm::MatrixParams<float> dst_params;
291 dst_params.order = cpu_backend_gemm::Order::kColMajor;
292 dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
293 dst_params.cols =
294 FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
295 cpu_backend_gemm::GemmParams<float, float> gemm_params;
296 gemm_params.bias = optional_bias_data;
297 gemm_params.clamp_min = params.float_activation_min;
298 gemm_params.clamp_max = params.float_activation_max;
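  // With the matrix layouts set up above, the GEMM below computes, for each
  // batch b and output channel o (sketch):
  //   output[b, o] = clamp(dot(weights[o, :], input[b, :]) + bias[o],
  //                        float_activation_min, float_activation_max),
  // with the bias term skipped when optional_bias_data is null.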
299 cpu_backend_gemm::Gemm(lhs_params, weights_data, rhs_params, input_data,
300 dst_params, output_data, gemm_params,
301 cpu_backend_context);
302 }
303
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, CpuBackendContext* cpu_backend_context) {
310 ruy::profiler::ScopeLabel label("FullyConnected/8bit");
311 const int32 input_offset = params.input_offset;
312 const int32 filter_offset = params.weights_offset;
313 const int32 output_offset = params.output_offset;
314 const int32 output_multiplier = params.output_multiplier;
315 const int output_shift = params.output_shift;
316 const int32 output_activation_min = params.quantized_activation_min;
317 const int32 output_activation_max = params.quantized_activation_max;
318 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
319 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
320 // TODO(b/62193649): This really should be:
321 // const int batches = ArraySize(output_dims, 1);
322 // but the current --variable_batch hack consists in overwriting the 3rd
323 // dimension with the runtime batch size, as we don't keep track for each
324 // array of which dimension is the batch dimension in it.
325 const int output_dim_count = output_shape.DimensionsCount();
326 const int filter_dim_count = filter_shape.DimensionsCount();
327 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
328 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
329 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
330 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
331 const int output_rows = output_shape.Dims(output_dim_count - 1);
332 TFLITE_DCHECK_EQ(output_rows, filter_rows);
333 if (bias_data) {
334 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
335 }
336
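  // Sign-convention note (an assumption about the calling kernels, consistent
  // with the rest of this file): params.weights_offset and params.input_offset
  // hold the *negated* zero points, so negating them again below hands
  // cpu_backend_gemm the actual zero points it expects.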
337 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
338 lhs_params.rows = filter_rows;
339 lhs_params.cols = filter_cols;
340 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
341 lhs_params.zero_point = -filter_offset;
342 lhs_params.cache_policy =
343 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
344 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
345 rhs_params.rows = filter_cols;
346 rhs_params.cols = batches;
347 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
348 rhs_params.zero_point = -input_offset;
349 rhs_params.cache_policy =
350 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
351 cpu_backend_gemm::MatrixParams<uint8> dst_params;
352 dst_params.rows = filter_rows;
353 dst_params.cols = batches;
354 dst_params.order = cpu_backend_gemm::Order::kColMajor;
355 dst_params.zero_point = output_offset;
356 cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
357 gemm_params.bias = bias_data;
358 gemm_params.clamp_min = output_activation_min;
359 gemm_params.clamp_max = output_activation_max;
360 gemm_params.multiplier_fixedpoint = output_multiplier;
361 gemm_params.multiplier_exponent = output_shift;
362 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
363 dst_params, output_data, gemm_params,
364 cpu_backend_context);
365 }
366
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data_int32, const RuntimeShape& output_shape,
    int16* output_data, CpuBackendContext* cpu_backend_context) {
373 ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
374 const int32 input_offset = params.input_offset;
375 const int32 filter_offset = params.weights_offset;
376 const int32 output_offset = params.output_offset;
377 const int32 output_multiplier = params.output_multiplier;
378 const int output_shift = params.output_shift;
379 const int32 output_activation_min = params.quantized_activation_min;
380 const int32 output_activation_max = params.quantized_activation_max;
381 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
382 TFLITE_DCHECK_EQ(output_offset, 0);
383 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
384 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
385
386 // TODO(b/62193649): This really should be:
387 // const int batches = ArraySize(output_dims, 1);
388 // but the current --variable_batch hack consists in overwriting the 3rd
389 // dimension with the runtime batch size, as we don't keep track for each
390 // array of which dimension is the batch dimension in it.
391 const int output_dim_count = output_shape.DimensionsCount();
392 const int filter_dim_count = filter_shape.DimensionsCount();
393 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
394 const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
395 output_shape, output_dim_count - 1);
396 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
397
398 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
399 lhs_params.rows = output_depth;
400 lhs_params.cols = accum_depth;
401 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
402 lhs_params.zero_point = -filter_offset;
403 lhs_params.cache_policy =
404 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
405 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
406 rhs_params.rows = accum_depth;
407 rhs_params.cols = batches;
408 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
409 rhs_params.zero_point = -input_offset;
410 rhs_params.cache_policy =
411 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
412 cpu_backend_gemm::MatrixParams<int16> dst_params;
413 dst_params.rows = output_depth;
414 dst_params.cols = batches;
415 dst_params.order = cpu_backend_gemm::Order::kColMajor;
416 dst_params.zero_point = 0;
417 cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
418 gemm_params.bias = bias_data_int32;
419 gemm_params.clamp_min = output_activation_min;
420 gemm_params.clamp_max = output_activation_max;
421 gemm_params.multiplier_fixedpoint = output_multiplier;
422 gemm_params.multiplier_exponent = output_shift;
423 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
424 dst_params, output_data, gemm_params,
425 cpu_backend_context);
426 }
427
428 // Internal function doing the actual arithmetic work for
429 // ShuffledFullyConnected.
430 // May be called either directly by it (single-threaded case) or may be used
431 // as the 'task' for worker threads to run (multi-threaded case, see
432 // ShuffledFullyConnectedWorkerTask below).
inline void ShuffledFullyConnectedWorkerImpl(
    const uint8* shuffled_input_workspace_data,
    const int8* shuffled_weights_data, int batches, int output_depth,
    int output_stride, int accum_depth, const int32* bias_data,
    int32 output_multiplier, int output_shift, int16* output_data) {
438 #if defined USE_NEON
439 const int8* shuffled_weights_ptr = shuffled_weights_data;
440 if (batches == 1) {
441 const int right_shift = output_shift > 0 ? 0 : -output_shift;
442 const int left_shift = output_shift > 0 ? output_shift : 0;
443 for (int c = 0; c < output_depth; c += 4) {
444 // Accumulation loop.
445 int32x4_t row_accum0 = vdupq_n_s32(0);
446 int32x4_t row_accum1 = vdupq_n_s32(0);
447 int32x4_t row_accum2 = vdupq_n_s32(0);
448 int32x4_t row_accum3 = vdupq_n_s32(0);
449 for (int d = 0; d < accum_depth; d += 16) {
450 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
451 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
452 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
453 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
454 shuffled_weights_ptr += 64;
455 int8x16_t input =
456 vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
457 int16x8_t local_accum0 =
458 vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
459 int16x8_t local_accum1 =
460 vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
461 int16x8_t local_accum2 =
462 vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
463 int16x8_t local_accum3 =
464 vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
465 local_accum0 =
466 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
467 local_accum1 =
468 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
469 local_accum2 =
470 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
471 local_accum3 =
472 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
473 row_accum0 = vpadalq_s16(row_accum0, local_accum0);
474 row_accum1 = vpadalq_s16(row_accum1, local_accum1);
475 row_accum2 = vpadalq_s16(row_accum2, local_accum2);
476 row_accum3 = vpadalq_s16(row_accum3, local_accum3);
477 }
478 // Horizontally reduce accumulators
479 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
480 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
481 pairwise_reduced_acc_0 =
482 vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
483 pairwise_reduced_acc_1 =
484 vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
485 pairwise_reduced_acc_2 =
486 vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
487 pairwise_reduced_acc_3 =
488 vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
489 const int32x2_t reduced_lo =
490 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
491 const int32x2_t reduced_hi =
492 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
493 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
494 // Add bias values.
495 int32x4_t bias_vec = vld1q_s32(bias_data + c);
496 reduced = vaddq_s32(reduced, bias_vec);
497 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
498 // Multiply by the fixed-point multiplier.
499 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
500 // Rounding-shift-right.
501 using gemmlowp::RoundingDivideByPOT;
502 reduced = RoundingDivideByPOT(reduced, right_shift);
503 // Narrow values down to 16 bit signed.
504 const int16x4_t res16 = vqmovn_s32(reduced);
505 vst1_s16(output_data + c, res16);
506 }
507 } else if (batches == 4) {
508 const int right_shift = output_shift > 0 ? 0 : -output_shift;
509 const int left_shift = output_shift > 0 ? output_shift : 0;
510 for (int c = 0; c < output_depth; c += 4) {
511 const int8* shuffled_input_ptr =
512 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
513 // Accumulation loop.
514 int32x4_t row_accum00 = vdupq_n_s32(0);
515 int32x4_t row_accum10 = vdupq_n_s32(0);
516 int32x4_t row_accum20 = vdupq_n_s32(0);
517 int32x4_t row_accum30 = vdupq_n_s32(0);
518 int32x4_t row_accum01 = vdupq_n_s32(0);
519 int32x4_t row_accum11 = vdupq_n_s32(0);
520 int32x4_t row_accum21 = vdupq_n_s32(0);
521 int32x4_t row_accum31 = vdupq_n_s32(0);
522 int32x4_t row_accum02 = vdupq_n_s32(0);
523 int32x4_t row_accum12 = vdupq_n_s32(0);
524 int32x4_t row_accum22 = vdupq_n_s32(0);
525 int32x4_t row_accum32 = vdupq_n_s32(0);
526 int32x4_t row_accum03 = vdupq_n_s32(0);
527 int32x4_t row_accum13 = vdupq_n_s32(0);
528 int32x4_t row_accum23 = vdupq_n_s32(0);
529 int32x4_t row_accum33 = vdupq_n_s32(0);
530 for (int d = 0; d < accum_depth; d += 16) {
531 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
532 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
533 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
534 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
535 shuffled_weights_ptr += 64;
536 int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
537 int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
538 int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
539 int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
540 shuffled_input_ptr += 64;
541 int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
542 #define TFLITE_SHUFFLED_FC_ACCUM(B) \
543 local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B)); \
544 local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B)); \
545 local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B)); \
546 local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B)); \
547 local_accum0 = \
548 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
549 local_accum1 = \
550 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
551 local_accum2 = \
552 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
553 local_accum3 = \
554 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
555 row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0); \
556 row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1); \
557 row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2); \
558 row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
559
560 TFLITE_SHUFFLED_FC_ACCUM(0)
561 TFLITE_SHUFFLED_FC_ACCUM(1)
562 TFLITE_SHUFFLED_FC_ACCUM(2)
563 TFLITE_SHUFFLED_FC_ACCUM(3)
564
565 #undef TFLITE_SHUFFLED_FC_ACCUM
566 }
567 // Horizontally reduce accumulators
568
569 #define TFLITE_SHUFFLED_FC_STORE(B) \
570 { \
571 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, \
572 pairwise_reduced_acc_2, pairwise_reduced_acc_3; \
573 pairwise_reduced_acc_0 = \
574 vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
575 pairwise_reduced_acc_1 = \
576 vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
577 pairwise_reduced_acc_2 = \
578 vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
579 pairwise_reduced_acc_3 = \
580 vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
581 const int32x2_t reduced_lo = \
582 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); \
583 const int32x2_t reduced_hi = \
584 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); \
585 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); \
586 int32x4_t bias_vec = vld1q_s32(bias_data + c); \
587 reduced = vaddq_s32(reduced, bias_vec); \
588 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); \
589 reduced = vqrdmulhq_n_s32(reduced, output_multiplier); \
590 using gemmlowp::RoundingDivideByPOT; \
591 reduced = RoundingDivideByPOT(reduced, right_shift); \
592 const int16x4_t res16 = vqmovn_s32(reduced); \
593 vst1_s16(output_data + c + B * output_stride, res16); \
594 }
595
596 TFLITE_SHUFFLED_FC_STORE(0);
597 TFLITE_SHUFFLED_FC_STORE(1);
598 TFLITE_SHUFFLED_FC_STORE(2);
599 TFLITE_SHUFFLED_FC_STORE(3);
600
601 #undef TFLITE_SHUFFLED_FC_STORE
602 }
603 } else {
604 TFLITE_DCHECK(false);
605 return;
606 }
607 #else
608 if (batches == 1) {
609 int16* output_ptr = output_data;
610 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
611 // so that just reinterpreting them as int8 values is equivalent to
612 // subtracting 128 from them, thus implementing for free the subtraction of
613 // the zero_point value 128.
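    // Worked example: the uint8 value 200 (0xC8) XOR 0x80 is 0x48; read as
    // int8 that is 72 = 200 - 128, i.e. the zero_point subtraction comes for
    // free.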
614 const int8* shuffled_weights_ptr =
615 reinterpret_cast<const int8*>(shuffled_weights_data);
616 // Likewise, we preshuffled and pre-xored the input data above.
617 const int8* shuffled_input_data =
618 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
619 for (int c = 0; c < output_depth; c += 4) {
620 // Internal accumulation.
621 // Initialize accumulator with the bias-value.
622 int32 accum[4] = {0};
623 // Accumulation loop.
624 for (int d = 0; d < accum_depth; d += 16) {
625 for (int i = 0; i < 4; i++) {
626 for (int j = 0; j < 16; j++) {
627 int8 input_val = shuffled_input_data[d + j];
628 int8 weights_val = *shuffled_weights_ptr++;
629 accum[i] += weights_val * input_val;
630 }
631 }
632 }
633 for (int i = 0; i < 4; i++) {
634 // Add bias value
635 int acc = accum[i] + bias_data[c + i];
636 // Down-scale the final int32 accumulator to the scale used by our
637 // (16-bit, typically 3 integer bits) fixed-point format. The quantized
638 // multiplier and shift here have been pre-computed offline
639 // (e.g. by toco).
640 acc =
641 MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
642 // Saturate, cast to int16, and store to output array.
643 acc = std::max(acc, -32768);
644 acc = std::min(acc, 32767);
645 output_ptr[c + i] = acc;
646 }
647 }
648 } else if (batches == 4) {
649 int16* output_ptr = output_data;
650 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
651 // so that just reinterpreting them as int8 values is equivalent to
652 // subtracting 128 from them, thus implementing for free the subtraction of
653 // the zero_point value 128.
654 const int8* shuffled_weights_ptr =
655 reinterpret_cast<const int8*>(shuffled_weights_data);
656 // Likewise, we preshuffled and pre-xored the input data above.
657 const int8* shuffled_input_data =
658 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
659 for (int c = 0; c < output_depth; c += 4) {
660 const int8* shuffled_input_ptr = shuffled_input_data;
661 // Accumulation loop.
662 // Internal accumulation.
663 // Initialize accumulator with the bias-value.
664 int32 accum[4][4];
665 for (int i = 0; i < 4; i++) {
666 for (int b = 0; b < 4; b++) {
667 accum[i][b] = 0;
668 }
669 }
670 for (int d = 0; d < accum_depth; d += 16) {
671 for (int i = 0; i < 4; i++) {
672 for (int b = 0; b < 4; b++) {
673 for (int j = 0; j < 16; j++) {
674 int8 input_val = shuffled_input_ptr[16 * b + j];
675 int8 weights_val = shuffled_weights_ptr[16 * i + j];
676 accum[i][b] += weights_val * input_val;
677 }
678 }
679 }
680 shuffled_input_ptr += 64;
681 shuffled_weights_ptr += 64;
682 }
683 for (int i = 0; i < 4; i++) {
684 for (int b = 0; b < 4; b++) {
685 // Add bias value
686 int acc = accum[i][b] + bias_data[c + i];
687 // Down-scale the final int32 accumulator to the scale used by our
688 // (16-bit, typically 3 integer bits) fixed-point format. The
689 // quantized multiplier and shift here have been pre-computed offline
690 // (e.g. by toco).
691 acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
692 output_shift);
693 // Saturate, cast to int16, and store to output array.
694 acc = std::max(acc, -32768);
695 acc = std::min(acc, 32767);
696 output_ptr[b * output_stride + c + i] = acc;
697 }
698 }
699 }
700 } else {
701 TFLITE_DCHECK(false);
702 return;
703 }
704 #endif
705 }
706
707 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
708 // to allow using gemmlowp's threadpool.
709 struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task {
  ShuffledFullyConnectedWorkerTask(const uint8* input_data,
                                   const int8* shuffled_weights_data,
                                   int batches, int output_depth,
                                   int output_stride, int accum_depth,
                                   const int32* bias_data,
                                   int32 output_multiplier, int output_shift,
                                   int16* output_data)
717 : input_data_(input_data),
718 shuffled_weights_data_(shuffled_weights_data),
719 batches_(batches),
720 output_depth_(output_depth),
721 output_stride_(output_stride),
722 accum_depth_(accum_depth),
723 bias_data_(bias_data),
724 output_multiplier_(output_multiplier),
725 output_shift_(output_shift),
726 output_data_(output_data) {}
727
  void Run() override {
729 ShuffledFullyConnectedWorkerImpl(
730 input_data_, shuffled_weights_data_, batches_, output_depth_,
731 output_stride_, accum_depth_, bias_data_, output_multiplier_,
732 output_shift_, output_data_);
733 }
734
735 const uint8* input_data_;
736 const int8* shuffled_weights_data_;
737 int batches_;
738 int output_depth_;
739 int output_stride_;
740 int accum_depth_;
741 const int32* bias_data_;
742 int32 output_multiplier_;
743 int output_shift_;
744 int16* output_data_;
745 };
746
inline void ShuffledFullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& weights_shape,
    const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    int16* output_data, uint8* shuffled_input_workspace_data,
    CpuBackendContext* cpu_backend_context) {
754 ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
755 const int32 output_multiplier = params.output_multiplier;
756 const int output_shift = params.output_shift;
757 const int32 output_activation_min = params.quantized_activation_min;
758 const int32 output_activation_max = params.quantized_activation_max;
759 TFLITE_DCHECK_EQ(output_activation_min, -32768);
760 TFLITE_DCHECK_EQ(output_activation_max, 32767);
761 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
762 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
763 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
764 // TODO(b/62193649): This really should be:
765 // const int batches = ArraySize(output_dims, 1);
766 // but the current --variable_batch hack consists in overwriting the 3rd
767 // dimension with the runtime batch size, as we don't keep track for each
768 // array of which dimension is the batch dimension in it.
769 const int output_dim_count = output_shape.DimensionsCount();
770 const int weights_dim_count = weights_shape.DimensionsCount();
771 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
772 const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
773 output_shape, output_dim_count - 1);
774 const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
775 TFLITE_DCHECK((accum_depth % 16) == 0);
776 TFLITE_DCHECK((output_depth % 4) == 0);
777 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
778 // so that just reinterpreting them as int8 values is equivalent to
779 // subtracting 128 from them, thus implementing for free the subtraction of
780 // the zero_point value 128.
781 const int8* int8_shuffled_weights_data =
782 reinterpret_cast<const int8*>(shuffled_weights_data);
783
784 // Shuffling and xoring of input activations into the workspace buffer
785 if (batches == 1) {
786 #ifdef USE_NEON
787 const uint8x16_t signbit = vdupq_n_u8(0x80);
788 for (int i = 0; i < accum_depth; i += 16) {
789 uint8x16_t val = vld1q_u8(input_data + i);
790 val = veorq_u8(val, signbit);
791 vst1q_u8(shuffled_input_workspace_data + i, val);
792 }
793 #else
794 for (int i = 0; i < accum_depth; i++) {
795 shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
796 }
797 #endif
798 } else if (batches == 4) {
799 uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
800 int c = 0;
801 #ifdef USE_NEON
802 const uint8x16_t signbit = vdupq_n_u8(0x80);
803 for (c = 0; c < accum_depth; c += 16) {
804 const uint8* src_data_ptr = input_data + c;
805 uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
806 uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
807 uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
808 uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
809 val0 = veorq_u8(val0, signbit);
810 val1 = veorq_u8(val1, signbit);
811 val2 = veorq_u8(val2, signbit);
812 val3 = veorq_u8(val3, signbit);
813 vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
814 vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
815 vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
816 vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
817 shuffled_input_workspace_ptr += 64;
818 }
819 #else
820 for (c = 0; c < accum_depth; c += 16) {
821 for (int b = 0; b < 4; b++) {
822 const uint8* src_data_ptr = input_data + b * accum_depth + c;
823 for (int j = 0; j < 16; j++) {
824 uint8 src_val = *src_data_ptr++;
825 // Flip the sign bit, so that the kernel will only need to
826 // reinterpret these uint8 values as int8, getting for free the
827 // subtraction of the zero_point value 128.
828 uint8 dst_val = src_val ^ 0x80;
829 *shuffled_input_workspace_ptr++ = dst_val;
830 }
831 }
832 }
833 #endif
834 } else {
835 TFLITE_DCHECK(false);
836 return;
837 }
838
839 static constexpr int kKernelRows = 4;
840 const int thread_count =
841 LegacyHowManyThreads<kKernelRows>(cpu_backend_context->max_num_threads(),
842 output_depth, batches, accum_depth);
843 if (thread_count == 1) {
844 // Single-thread case: do the computation on the current thread, don't
845 // use a threadpool
846 ShuffledFullyConnectedWorkerImpl(
847 shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
848 output_depth, output_depth, accum_depth, bias_data, output_multiplier,
849 output_shift, output_data);
850 return;
851 }
852
853 // Multi-threaded case: use the gemmlowp context's threadpool.
854 TFLITE_DCHECK_GT(thread_count, 1);
855 std::vector<ShuffledFullyConnectedWorkerTask> tasks;
856 // TODO(b/131746020) don't create new heap allocations every time.
857 // At least we make it a single heap allocation by using reserve().
858 tasks.reserve(thread_count);
859 const int kRowsPerWorker =
860 RoundUp<kKernelRows>(CeilQuotient(output_depth, thread_count));
861 int row_start = 0;
862 for (int i = 0; i < thread_count; i++) {
863 int row_end = std::min(output_depth, row_start + kRowsPerWorker);
864 tasks.emplace_back(shuffled_input_workspace_data,
865 int8_shuffled_weights_data + row_start * accum_depth,
866 batches, row_end - row_start, output_depth, accum_depth,
867 bias_data + row_start, output_multiplier, output_shift,
868 output_data + row_start);
869 row_start = row_end;
870 }
871 TFLITE_DCHECK_EQ(row_start, output_depth);
872 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
873 cpu_backend_context);
874 }
875
876 #ifdef USE_NEON
877
inline int32x4_t RoundToNearest(const float32x4_t input) {
879 #if defined(__aarch64__) || defined(__SSSE3__)
880 // Note: vcvtnq_s32_f32 is not available in ARMv7
881 return vcvtnq_s32_f32(input);
882 #else
883 static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
884 static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
885 static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
886
887 const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
888 const float32x4_t round =
889 vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
890 return vcvtq_s32_f32(vaddq_f32(input, round));
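  // Note: this fallback rounds halfway cases away from zero (e.g. 2.5 -> 3,
  // -2.5 -> -3), whereas vcvtnq_s32_f32 rounds them to even (2.5 -> 2); the
  // two only differ on exact .5 inputs.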
891 #endif // defined(__aarch64__) || defined(__SSSE3__)
892 }
893
inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
895 #if defined(__aarch64__)
896 // Note that vcvtnq_u32_f32 is not available in ARMv7 or in arm_neon_sse.h.
897 return vcvtnq_u32_f32(input);
898 #else
899 static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
900
901 return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
902 #endif // defined(__aarch64__)
903 }
904
905 #endif // USE_NEON
906
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                 const float* input_data, const RuntimeShape& filter_shape,
                 const float* filter_data, const RuntimeShape& bias_shape,
                 const float* bias_data, const RuntimeShape& output_shape,
                 float* output_data, const RuntimeShape& im2col_shape,
                 float* im2col_data, CpuBackendContext* cpu_backend_context) {
913 const int stride_width = params.stride_width;
914 const int stride_height = params.stride_height;
915 const int dilation_width_factor = params.dilation_width_factor;
916 const int dilation_height_factor = params.dilation_height_factor;
917 const float output_activation_min = params.float_activation_min;
918 const float output_activation_max = params.float_activation_max;
919 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
920 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
921 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
922
923 ruy::profiler::ScopeLabel label("Conv");
924
925 // NB: the float 0.0f value is represented by all zero bytes.
926 const uint8 float_zero_byte = 0x00;
927 const float* gemm_input_data = nullptr;
928 const RuntimeShape* gemm_input_shape = nullptr;
929 const int filter_width = filter_shape.Dims(2);
930 const int filter_height = filter_shape.Dims(1);
931 const bool need_dilated_im2col =
932 dilation_width_factor != 1 || dilation_height_factor != 1;
933 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
934 filter_width != 1 || filter_height != 1;
935 if (need_dilated_im2col) {
936 DilatedIm2col(params, float_zero_byte, input_shape, input_data,
937 filter_shape, output_shape, im2col_data);
938 gemm_input_data = im2col_data;
939 gemm_input_shape = &im2col_shape;
940 } else if (need_im2col) {
941 TFLITE_DCHECK(im2col_data);
942 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
943 input_data, im2col_shape, im2col_data);
944 gemm_input_data = im2col_data;
945 gemm_input_shape = &im2col_shape;
946 } else {
947 TFLITE_DCHECK(!im2col_data);
948 gemm_input_data = input_data;
949 gemm_input_shape = &input_shape;
950 }
951
952 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
953 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
954 int n = output_shape.Dims(3);
955 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
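  // With the NHWC shapes checked above: m is batches * output_height *
  // output_width (one GEMM row per output position), n is the number of
  // output channels, and k is the length of one input patch, i.e.
  // filter_height * filter_width * input_depth when im2col is used, or just
  // input_depth in the 1x1, stride-1 case.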
956
957 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
// The following code computes the matrix multiplication c = a * transpose(b)
959 // with CBLAS, where:
960 // * `a` is a matrix with dimensions (m, k).
961 // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
962 // * `c` is a matrix with dimensions (m, n).
// The naming of the variables is aligned with the CBLAS specification here.
964 const float* a = gemm_input_data;
965 const float* b = filter_data;
966 float* c = output_data;
967 // The stride of matrix a, b and c respectively.
968 int stride_a = k;
969 int stride_b = k;
970 int stride_c = n;
971
972 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
973 stride_a, b, stride_b, 0.0f, c, stride_c);
974 optimized_ops::AddBiasAndEvalActivationFunction(
975 output_activation_min, output_activation_max, bias_shape, bias_data,
976 output_shape, output_data);
977 #else
978 // When an optimized CBLAS implementation is not available, fall back
979 // to using cpu_backend_gemm.
980 cpu_backend_gemm::MatrixParams<float> lhs_params;
981 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
982 lhs_params.rows = n;
983 lhs_params.cols = k;
984 cpu_backend_gemm::MatrixParams<float> rhs_params;
985 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
986 rhs_params.rows = k;
987 rhs_params.cols = m;
988 cpu_backend_gemm::MatrixParams<float> dst_params;
989 dst_params.order = cpu_backend_gemm::Order::kColMajor;
990 dst_params.rows = n;
991 dst_params.cols = m;
992 cpu_backend_gemm::GemmParams<float, float> gemm_params;
993 gemm_params.bias = bias_data;
994 gemm_params.clamp_min = output_activation_min;
995 gemm_params.clamp_max = output_activation_max;
996 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
997 dst_params, output_data, gemm_params,
998 cpu_backend_context);
999 #endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1000 }
1001
inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
                       const RuntimeShape& input_shape,
                       const int8_t* input_data,
                       const RuntimeShape& filter_shape,
                       const int8_t* filter_data,
                       const RuntimeShape& bias_shape, const float* bias_data,
                       const RuntimeShape& accum_scratch_shape,
                       int32_t* accum_scratch, const RuntimeShape& output_shape,
                       float* output_data, const RuntimeShape& im2col_shape,
                       int8_t* im2col_data, CpuBackendContext* context) {
1012 const int stride_width = params.stride_width;
1013 const int stride_height = params.stride_height;
1014 const int dilation_width_factor = params.dilation_width_factor;
1015 const int dilation_height_factor = params.dilation_height_factor;
1016 const float output_activation_min = params.float_activation_min;
1017 const float output_activation_max = params.float_activation_max;
1018 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1019 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1020 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1021
1022 const int batch_size = input_shape.Dims(0);
1023 const int filter_width = filter_shape.Dims(2);
1024 const int filter_height = filter_shape.Dims(1);
1025
1026 const int input_zero_point = 0;
1027 const int8_t* gemm_input_data = nullptr;
1028 int num_input;
1029 const bool need_dilated_im2col =
1030 dilation_width_factor != 1 || dilation_height_factor != 1;
1031 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1032 filter_width != 1 || filter_height != 1;
1033
1034 if (need_dilated_im2col) {
1035 DilatedIm2col(params, input_zero_point, input_shape, input_data,
1036 filter_shape, output_shape, im2col_data);
1037 gemm_input_data = im2col_data;
1038 num_input = im2col_shape.FlatSize();
1039 } else if (need_im2col) {
1040 TFLITE_DCHECK(im2col_data);
1041 // symmetric quantization assumes zero point of 0.
1042
1043 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1044 input_data, im2col_shape, im2col_data);
1045 gemm_input_data = im2col_data;
1046 num_input = im2col_shape.FlatSize();
1047 } else {
1048 TFLITE_DCHECK(!im2col_data);
1049 gemm_input_data = input_data;
1050 num_input = input_shape.FlatSize();
1051 }
1052
1053 // Flatten 4D matrices into 2D matrices for matrix multiplication.
1054
1055 // Flatten so that each filter has its own row.
1056 const int filter_rows = filter_shape.Dims(0);
1057 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1058
1059 // In MatrixBatchVectorMultiplyAccumulate, each output value is the
1060 // dot product of one row of the first matrix with one row of the second
// matrix. Therefore, the number of columns in each matrix is the same.
1062 //
1063 // After Im2Col, each input patch becomes a row.
1064 const int gemm_input_cols = filter_cols;
1065 const int gemm_input_rows = num_input / gemm_input_cols;
1066
1067 const int output_cols = output_shape.Dims(3);
1068 const int output_rows = FlatSizeSkipDim(output_shape, 3);
1069 TFLITE_DCHECK_EQ(output_cols, filter_rows);
1070 TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
1071 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
1072
1073 // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
1074 // input matrix has its own scale factor. This code duplicates the scale
1075 // factors for each row in the same batch.
1076 const int rows_per_batch = gemm_input_rows / batch_size;
1077 for (int i = gemm_input_rows - 1; i >= 0; --i) {
1078 scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
1079 }
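  // Illustrative example: with batch_size = 2 and rows_per_batch = 3, the
  // per-batch factors {s0, s1} expand in place to {s0, s0, s0, s1, s1, s1};
  // the loop runs backwards so each original entry is read before it can be
  // overwritten.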
1080
1081 std::fill_n(output_data, output_rows * output_cols, 0.0f);
1082
1083 // The scratch buffer must have the same size as the output.
1084 TFLITE_DCHECK_EQ(accum_scratch_shape.FlatSize(), output_shape.FlatSize());
1085 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
1086 filter_data, filter_rows, filter_cols, gemm_input_data,
1087 scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
1088 output_data, context);
1089 AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
1090 bias_shape, bias_data, output_shape,
1091 output_data);
1092 }
1093
inline void HybridConvPerChannel(
    const ConvParams& params, float* scaling_factors_ptr,
    const RuntimeShape& input_shape, const int8_t* input_data,
    const RuntimeShape& filter_shape, const int8_t* filter_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data,
    const RuntimeShape& im2col_shape, int8_t* im2col_data,
    const float* per_channel_scale, int32_t* input_offset,
    const RuntimeShape& scratch_shape, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* cpu_backend_context) {
1104 ruy::profiler::ScopeLabel label("ConvHybridPerChannel");
1105 const int stride_width = params.stride_width;
1106 const int stride_height = params.stride_height;
1107 const int dilation_width_factor = params.dilation_width_factor;
1108 const int dilation_height_factor = params.dilation_height_factor;
1109 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1110 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1111 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1112
1113 const int8* gemm_input_data = nullptr;
1114 const RuntimeShape* gemm_input_shape = nullptr;
1115 const int filter_width = filter_shape.Dims(2);
1116 const int filter_height = filter_shape.Dims(1);
1117 const bool need_dilated_im2col =
1118 dilation_width_factor != 1 || dilation_height_factor != 1;
1119 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1120 filter_width != 1 || filter_height != 1;
1121
1122 const int batch_size = input_shape.Dims(0);
1123
1124 if (need_dilated_im2col) {
1125 TFLITE_DCHECK(im2col_data);
1126 optimized_ops::DilatedIm2col(params, input_shape, input_data, filter_shape,
1127 output_shape, im2col_data, input_offset,
1128 batch_size);
1129 gemm_input_data = im2col_data;
1130 gemm_input_shape = &im2col_shape;
1131 } else if (need_im2col) {
1132 Im2col(params, filter_height, filter_width, input_offset, batch_size,
1133 input_shape, input_data, im2col_shape, im2col_data);
1134 gemm_input_data = im2col_data;
1135 gemm_input_shape = &im2col_shape;
1136 } else {
1137 TFLITE_DCHECK(!im2col_data);
1138 gemm_input_data = input_data;
1139 gemm_input_shape = &input_shape;
1140 }
1141
1142 const int filter_rows = filter_shape.Dims(0);
1143 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1144
1145 const int gemm_input_rows = gemm_input_shape->Dims(3);
1146 const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1147 const int output_rows = output_shape.Dims(3);
1148 const int output_cols =
1149 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1150
1151 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1152 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1153 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1154 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1155 TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
1156 if (!compute_row_sums || *compute_row_sums) {
1157 tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
1158 filter_cols);
1159 if (compute_row_sums) {
1160 *compute_row_sums = false;
1161 }
1162 }
1163
1164 cpu_backend_gemm::MatrixParams<int8> lhs_params;
1165 lhs_params.rows = filter_rows;
1166 lhs_params.cols = filter_cols;
1167 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1168
1169 cpu_backend_gemm::MatrixParams<int8> rhs_params;
1170 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1171 rhs_params.rows = gemm_input_rows;
1172 rhs_params.cols = gemm_input_cols;
1173
1174 cpu_backend_gemm::MatrixParams<int32> dst_params;
1175 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1176 dst_params.rows = output_rows;
1177 dst_params.cols = output_cols;
1178
1179 // TODO(b/149003801): Use hybrid gemm once supported in Ruy.
1180 cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
1181 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1182 dst_params, scratch, gemm_params, cpu_backend_context);
1183
1184 MatrixMap<float> out_mat(output_data, filter_rows, output_cols);
1185 MatrixMap<int32_t> in_mat(scratch, filter_rows, output_cols);
1186 VectorMap<const float> bias_data_vec(bias_data, filter_rows, 1);
1187 VectorMap<int32_t> row_sums_vec(row_sums, filter_rows, 1);
1188 VectorMap<const float> per_channel_scale_vec(per_channel_scale, filter_rows,
1189 1);
1190 const int cols_per_batch = output_cols / batch_size;
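  // For each output column c belonging to batch b, the loop below applies
  // (sketch of the math):
  //   out[r, c] = clamp((acc[r, c] - row_sums[r] * input_offset[b]) *
  //                         per_channel_scale[r] * scaling_factors_ptr[b] +
  //                         bias_data[r],
  //                     float_activation_min, float_activation_max).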
1191 for (int c = 0; c < output_cols; c++) {
1192 const int b = c / cols_per_batch;
1193 const float input_scale = scaling_factors_ptr[b];
1194 const int32_t zero_point = input_offset[b];
1195 out_mat.col(c) =
1196 (((in_mat.col(c) - (row_sums_vec * zero_point))
1197 .cast<float>()
1198 .cwiseProduct((per_channel_scale_vec * input_scale))) +
1199 bias_data_vec)
1200 .cwiseMin(params.float_activation_max)
1201 .cwiseMax(params.float_activation_min);
1202 }
1203 }
1204
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
                 const uint8* input_data, const RuntimeShape& filter_shape,
                 const uint8* filter_data, const RuntimeShape& bias_shape,
                 const int32* bias_data, const RuntimeShape& output_shape,
                 uint8* output_data, const RuntimeShape& im2col_shape,
                 uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
1211 ruy::profiler::ScopeLabel label("Conv/8bit");
1212
1213 const int stride_width = params.stride_width;
1214 const int stride_height = params.stride_height;
1215 const int dilation_width_factor = params.dilation_width_factor;
1216 const int dilation_height_factor = params.dilation_height_factor;
1217 const int32 input_offset = params.input_offset;
1218 const int32 filter_offset = params.weights_offset;
1219 const int32 output_offset = params.output_offset;
1220 const int32 output_multiplier = params.output_multiplier;
1221 const int output_shift = params.output_shift;
1222 const int32 output_activation_min = params.quantized_activation_min;
1223 const int32 output_activation_max = params.quantized_activation_max;
1224 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1225 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1226 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1227
1228 const uint8* gemm_input_data = nullptr;
1229 const RuntimeShape* gemm_input_shape = nullptr;
1230 const int filter_width = filter_shape.Dims(2);
1231 const int filter_height = filter_shape.Dims(1);
1232 const bool need_dilated_im2col =
1233 dilation_width_factor != 1 || dilation_height_factor != 1;
1234 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1235 filter_width != 1 || filter_height != 1;
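// A 1x1 filter with unit stride (and no dilation) can be fed to the GEMM
// directly; any other geometry first materializes patches via Im2col or
// DilatedIm2col below.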
1236 if (need_dilated_im2col) {
1237 TFLITE_DCHECK(im2col_data);
1238 const int input_zero_point = -input_offset;
1239 TFLITE_DCHECK_GE(input_zero_point, 0);
1240 TFLITE_DCHECK_LE(input_zero_point, 255);
1241 DilatedIm2col(params, input_zero_point, input_shape, input_data,
1242 filter_shape, output_shape, im2col_data);
1243 gemm_input_data = im2col_data;
1244 gemm_input_shape = &im2col_shape;
1245 } else if (need_im2col) {
1246 TFLITE_DCHECK(im2col_data);
1247 const int input_zero_point = -input_offset;
1248 TFLITE_DCHECK_GE(input_zero_point, 0);
1249 TFLITE_DCHECK_LE(input_zero_point, 255);
1250 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1251 input_data, im2col_shape, im2col_data);
1252 gemm_input_data = im2col_data;
1253 gemm_input_shape = &im2col_shape;
1254 } else {
1255 TFLITE_DCHECK(!im2col_data);
1256 gemm_input_data = input_data;
1257 gemm_input_shape = &input_shape;
1258 }
1259
1260 const int gemm_input_rows = gemm_input_shape->Dims(3);
1261 // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784).
1262 // The root cause has not yet been identified. The same applies below for the
1263 // other calls commented out. This is a partial rollback of cl/196819423.
1264 // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1265 const int gemm_input_cols = gemm_input_shape->Dims(0) *
1266 gemm_input_shape->Dims(1) *
1267 gemm_input_shape->Dims(2);
1268 const int filter_rows = filter_shape.Dims(0);
1269 // See b/79927784.
1270 // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1271 const int filter_cols =
1272 filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
1273 const int output_rows = output_shape.Dims(3);
1274 // See b/79927784.
1275 // const int output_cols = FlatSizeSkipDim(output_shape, 3);
1276 const int output_cols =
1277 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1278 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1279 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1280 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1281 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1282
1283 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
1284 lhs_params.rows = filter_rows;
1285 lhs_params.cols = filter_cols;
1286 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1287 lhs_params.zero_point = -filter_offset;
1288 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
1289 rhs_params.rows = gemm_input_rows;
1290 rhs_params.cols = gemm_input_cols;
1291 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1292 rhs_params.zero_point = -input_offset;
1293 cpu_backend_gemm::MatrixParams<uint8> dst_params;
1294 dst_params.rows = output_rows;
1295 dst_params.cols = output_cols;
1296 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1297 dst_params.zero_point = output_offset;
1298 cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
1299 gemm_params.bias = bias_data;
1300 gemm_params.clamp_min = output_activation_min;
1301 gemm_params.clamp_max = output_activation_max;
1302 gemm_params.multiplier_fixedpoint = output_multiplier;
1303 gemm_params.multiplier_exponent = output_shift;
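// The cpu_backend_gemm epilogue handles the bias addition, the fixed-point
// requantization with (output_multiplier, output_shift), the output offset
// and the clamping, so no separate post-processing pass is needed here.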
1304 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1305 dst_params, output_data, gemm_params,
1306 cpu_backend_context);
1307 }
1308
1309 template <typename T>
1310 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
1311 const RuntimeShape& unextended_input_shape,
1312 const T* input_data,
1313 const RuntimeShape& unextended_output_shape,
1314 T* output_data) {
1315 ruy::profiler::ScopeLabel label("DepthToSpace");
1316
1317 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1318 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1319 const RuntimeShape input_shape =
1320 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1321 const RuntimeShape output_shape =
1322 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1323
1324 const int input_depth = input_shape.Dims(3);
1325 const int input_width = input_shape.Dims(2);
1326 const int input_height = input_shape.Dims(1);
1327
1328 const int output_depth = output_shape.Dims(3);
1329 const int batch_size = output_shape.Dims(0);
1330
1331 // Number of contiguous values that we can copy in one iteration.
1332 const int stride = op_params.block_size * output_depth;
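// For example, with block_size = 2 and input_depth = 8 the output depth is
// 8 / (2 * 2) = 2, so each memcpy below moves stride = 2 * 2 = 4 values.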
1333
1334 for (int batch = 0; batch < batch_size; ++batch) {
1335 for (int in_h = 0; in_h < input_height; ++in_h) {
1336 const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
1337 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1338 const T* src = input_ptr;
1339 for (int in_w = 0; in_w < input_width; ++in_w) {
1340 memcpy(output_data, src, stride * sizeof(T));
1341 output_data += stride;
1342 src += input_depth;
1343 }
1344 input_ptr += stride;
1345 }
1346 }
1347 }
1348 }
1349
1350 template <typename T>
1351 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
1352 const RuntimeShape& unextended_input_shape,
1353 const T* input_data,
1354 const RuntimeShape& unextended_output_shape,
1355 T* output_data) {
1356 ruy::profiler::ScopeLabel label("SpaceToDepth");
1357
1358 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1359 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1360 const RuntimeShape input_shape =
1361 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1362 const RuntimeShape output_shape =
1363 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1364
1365 const int output_depth = output_shape.Dims(3);
1366 const int output_width = output_shape.Dims(2);
1367 const int output_height = output_shape.Dims(1);
1368
1369 const int input_depth = input_shape.Dims(3);
1370 const int batch_size = input_shape.Dims(0);
1371
1372 // Number of contiguous values that we can copy in one iteration.
1373 const int stride = op_params.block_size * input_depth;
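// For example, with block_size = 2 and input_depth = 2 the output depth is
// 2 * 2 * 2 = 8, and each memcpy below moves stride = 2 * 2 = 4 values.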
1374
1375 for (int batch = 0; batch < batch_size; ++batch) {
1376 for (int out_h = 0; out_h < output_height; ++out_h) {
1377 T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
1378 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1379 T* dst = output_ptr;
1380 for (int out_w = 0; out_w < output_width; ++out_w) {
1381 memcpy(dst, input_data, stride * sizeof(T));
1382 input_data += stride;
1383 dst += output_depth;
1384 }
1385 output_ptr += stride;
1386 }
1387 }
1388 }
1389 }
1390
1391 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
1392 const RuntimeShape& output_shape, float* output_data) {
1393 ruy::profiler::ScopeLabel label("Relu (not fused)");
1394
1395 const auto input = MapAsVector(input_data, input_shape);
1396 auto output = MapAsVector(output_data, output_shape);
1397 output = input.cwiseMax(0.0f);
1398 }
1399
1400 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1401 const RuntimeShape& input_shape,
1402 const float* input_data,
1403 const RuntimeShape& output_shape,
1404 float* output_data, float epsilon = 1e-6) {
1405 ruy::profiler::ScopeLabel label("L2Normalization");
1406 const int trailing_dim = input_shape.DimensionsCount() - 1;
1407 const int outer_size =
1408 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1409 const int depth =
1410 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
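// For each inner row of `depth` values, divide by max(||row||_2, epsilon);
// the epsilon floor guards against division by zero for all-zero rows.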
1411 for (int i = 0; i < outer_size; ++i) {
1412 float squared_l2_norm = 0;
1413 for (int c = 0; c < depth; ++c) {
1414 const float val = input_data[c];
1415 squared_l2_norm += val * val;
1416 }
1417 float l2_norm = std::sqrt(squared_l2_norm);
1418 l2_norm = std::max(l2_norm, epsilon);
1419 for (int c = 0; c < depth; ++c) {
1420 *output_data = *input_data / l2_norm;
1421 ++output_data;
1422 ++input_data;
1423 }
1424 }
1425 }
1426
1427 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1428 const RuntimeShape& input_shape,
1429 const uint8* input_data,
1430 const RuntimeShape& output_shape,
1431 uint8* output_data) {
1432 ruy::profiler::ScopeLabel label("L2Normalization/8bit");
1433 const int trailing_dim = input_shape.DimensionsCount() - 1;
1434 const int depth =
1435 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1436 const int outer_size =
1437 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1438 const int32 input_zero_point = op_params.input_zero_point;
1439 for (int i = 0; i < outer_size; ++i) {
1440 int32 square_l2_norm = 0;
1441 for (int c = 0; c < depth; c++) {
1442 // Note that input_data advances by depth in the second pass below.
1443 int32 diff = input_data[c] - input_zero_point;
1444 square_l2_norm += diff * diff;
1445 }
1446 // TODO(b/29395854): add clamping to TOCO and the TF Lite kernel
1447 // for all-zero tensors in the input_data.
1448 int32 inv_l2norm_multiplier;
1449 int inv_l2norm_shift;
1450 GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
1451 &inv_l2norm_multiplier, &inv_l2norm_shift);
1452
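// The output is effectively quantized with zero point 128 and scale 1/128,
// so a normalized value in [-1, 1] maps to [0, 256] before clamping to
// [0, 255].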
1453 for (int c = 0; c < depth; c++) {
1454 int32 diff = *input_data - input_zero_point;
1455 int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
1456 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
1457 int32 unclamped_output_val = 128 + rescaled_diff;
1458 int32 output_val = std::min(255, std::max(0, unclamped_output_val));
1459 *output_data = static_cast<uint8>(output_val);
1460 ++input_data;
1461 ++output_data;
1462 }
1463 }
1464 }
1465
1466 inline void AddElementwise(int size, const ArithmeticParams& params,
1467 const float* input1_data, const float* input2_data,
1468 float* output_data) {
1469 int i = 0;
1470
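// When NEON is available, the loops below process 16 and then 4 floats per
// iteration; the trailing scalar loop handles any remainder.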
1471 #ifdef USE_NEON
1472 const auto activation_min = vdupq_n_f32(params.float_activation_min);
1473 const auto activation_max = vdupq_n_f32(params.float_activation_max);
1474 for (; i <= size - 16; i += 16) {
1475 auto a10 = vld1q_f32(input1_data + i);
1476 auto a11 = vld1q_f32(input1_data + i + 4);
1477 auto a12 = vld1q_f32(input1_data + i + 8);
1478 auto a13 = vld1q_f32(input1_data + i + 12);
1479 auto a20 = vld1q_f32(input2_data + i);
1480 auto a21 = vld1q_f32(input2_data + i + 4);
1481 auto a22 = vld1q_f32(input2_data + i + 8);
1482 auto a23 = vld1q_f32(input2_data + i + 12);
1483 auto x0 = vaddq_f32(a10, a20);
1484 auto x1 = vaddq_f32(a11, a21);
1485 auto x2 = vaddq_f32(a12, a22);
1486 auto x3 = vaddq_f32(a13, a23);
1487 x0 = vmaxq_f32(activation_min, x0);
1488 x1 = vmaxq_f32(activation_min, x1);
1489 x2 = vmaxq_f32(activation_min, x2);
1490 x3 = vmaxq_f32(activation_min, x3);
1491 x0 = vminq_f32(activation_max, x0);
1492 x1 = vminq_f32(activation_max, x1);
1493 x2 = vminq_f32(activation_max, x2);
1494 x3 = vminq_f32(activation_max, x3);
1495 vst1q_f32(output_data + i, x0);
1496 vst1q_f32(output_data + i + 4, x1);
1497 vst1q_f32(output_data + i + 8, x2);
1498 vst1q_f32(output_data + i + 12, x3);
1499 }
1500 for (; i <= size - 4; i += 4) {
1501 auto a1 = vld1q_f32(input1_data + i);
1502 auto a2 = vld1q_f32(input2_data + i);
1503 auto x = vaddq_f32(a1, a2);
1504 x = vmaxq_f32(activation_min, x);
1505 x = vminq_f32(activation_max, x);
1506 vst1q_f32(output_data + i, x);
1507 }
1508 #endif // NEON
1509
1510 for (; i < size; i++) {
1511 auto x = input1_data[i] + input2_data[i];
1512 output_data[i] = ActivationFunctionWithMinMax(
1513 x, params.float_activation_min, params.float_activation_max);
1514 }
1515 }
1516
1517 inline void Add(const ArithmeticParams& params,
1518 const RuntimeShape& input1_shape, const float* input1_data,
1519 const RuntimeShape& input2_shape, const float* input2_data,
1520 const RuntimeShape& output_shape, float* output_data) {
1521 ruy::profiler::ScopeLabel label("Add");
1522 const int flat_size =
1523 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1524 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1525 }
1526
1527 // Element-wise add that can often be used for the inner loop of a broadcast
1528 // add as well as for the non-broadcast add.
1529 inline void AddElementwise(int size, const ArithmeticParams& params,
1530 const uint8* input1_data, const uint8* input2_data,
1531 uint8* output_data) {
1532 ruy::profiler::ScopeLabel label("AddElementwise/8bit");
1533 int i = 0;
1534 TFLITE_DCHECK_GT(params.input1_offset, -256);
1535 TFLITE_DCHECK_GT(params.input2_offset, -256);
1536 TFLITE_DCHECK_LT(params.input1_offset, 256);
1537 TFLITE_DCHECK_LT(params.input2_offset, 256);
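// Quantized add recipe: apply the input offsets, left-shift both operands
// into higher precision, rescale each with its own (multiplier, shift) to a
// common scale, add, then rescale the sum to the output scale, add the
// output offset and clamp.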
1538 #ifdef USE_NEON
1539 const uint8x8_t output_activation_min_vector =
1540 vdup_n_u8(params.quantized_activation_min);
1541 const uint8x8_t output_activation_max_vector =
1542 vdup_n_u8(params.quantized_activation_max);
1543 for (; i <= size - 8; i += 8) {
1544 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
1545 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1546 const int16x8_t input1_val_s16 =
1547 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1548 const int16x8_t input2_val_s16 =
1549 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1550 const int16x8_t input1_val =
1551 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1552 const int16x8_t input2_val =
1553 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1554 const int16x4_t input1_val_high = vget_high_s16(input1_val);
1555 const int16x4_t input1_val_low = vget_low_s16(input1_val);
1556 const int16x4_t input2_val_high = vget_high_s16(input2_val);
1557 const int16x4_t input2_val_low = vget_low_s16(input2_val);
1558 int32x4_t x11 = vmovl_s16(input1_val_low);
1559 int32x4_t x12 = vmovl_s16(input1_val_high);
1560 int32x4_t x21 = vmovl_s16(input2_val_low);
1561 int32x4_t x22 = vmovl_s16(input2_val_high);
1562 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1563 x11 = vshlq_s32(x11, left_shift_dup);
1564 x12 = vshlq_s32(x12, left_shift_dup);
1565 x21 = vshlq_s32(x21, left_shift_dup);
1566 x22 = vshlq_s32(x22, left_shift_dup);
1567 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1568 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1569 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1570 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1571 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1572 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1573 x11 = vshlq_s32(x11, input1_shift_dup);
1574 x12 = vshlq_s32(x12, input1_shift_dup);
1575 x21 = vshlq_s32(x21, input2_shift_dup);
1576 x22 = vshlq_s32(x22, input2_shift_dup);
1577 int32x4_t s1 = vaddq_s32(x11, x21);
1578 int32x4_t s2 = vaddq_s32(x12, x22);
1579 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1580 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1581 using gemmlowp::RoundingDivideByPOT;
1582 s1 = RoundingDivideByPOT(s1, -params.output_shift);
1583 s2 = RoundingDivideByPOT(s2, -params.output_shift);
1584 const int16x4_t s1_narrowed = vmovn_s32(s1);
1585 const int16x4_t s2_narrowed = vmovn_s32(s2);
1586 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1587 vdupq_n_s16(params.output_offset));
1588 const uint8x8_t clamped =
1589 vmax_u8(output_activation_min_vector,
1590 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1591 vst1_u8(output_data + i, clamped);
1592 }
1593 #endif // NEON
1594
1595 for (; i < size; ++i) {
1596 const int32 input1_val = params.input1_offset + input1_data[i];
1597 const int32 input2_val = params.input2_offset + input2_data[i];
1598 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1599 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1600 const int32 scaled_input1_val =
1601 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1602 shifted_input1_val, params.input1_multiplier, params.input1_shift);
1603 const int32 scaled_input2_val =
1604 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1605 shifted_input2_val, params.input2_multiplier, params.input2_shift);
1606 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1607 const int32 raw_output =
1608 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1609 raw_sum, params.output_multiplier, params.output_shift) +
1610 params.output_offset;
1611 const int32 clamped_output =
1612 std::min(params.quantized_activation_max,
1613 std::max(params.quantized_activation_min, raw_output));
1614 output_data[i] = static_cast<uint8>(clamped_output);
1615 }
1616 }
1617
1618 // Scalar-broadcast add that can be used for the inner loop of a more general
1619 // broadcast add, so that, for example, a scalar broadcast with batch will
1620 // still be fast.
1621 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1622 uint8 input1_data, const uint8* input2_data,
1623 uint8* output_data) {
1624 using gemmlowp::RoundingDivideByPOT;
1625
1626 ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit");
1627 TFLITE_DCHECK_GT(params.input1_offset, -256);
1628 TFLITE_DCHECK_GT(params.input2_offset, -256);
1629 TFLITE_DCHECK_LT(params.input1_offset, 256);
1630 TFLITE_DCHECK_LT(params.input2_offset, 256);
1631
1632 int i = 0;
1633
1634 #ifdef USE_NEON
1635 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1636 const uint8x8_t output_activation_min_vector =
1637 vdup_n_u8(params.quantized_activation_min);
1638 const uint8x8_t output_activation_max_vector =
1639 vdup_n_u8(params.quantized_activation_max);
1640
1641 // Process broadcast scalar.
1642 const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
1643 const int16x8_t input1_val_s16 =
1644 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1645 const int16x8_t input1_val =
1646 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1647 const int16x4_t input1_val_high = vget_high_s16(input1_val);
1648 const int16x4_t input1_val_low = vget_low_s16(input1_val);
1649 int32x4_t x11 = vmovl_s16(input1_val_low);
1650 int32x4_t x12 = vmovl_s16(input1_val_high);
1651 x11 = vshlq_s32(x11, left_shift_dup);
1652 x12 = vshlq_s32(x12, left_shift_dup);
1653 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1654 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1655 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1656 x11 = vshlq_s32(x11, input1_shift_dup);
1657 x12 = vshlq_s32(x12, input1_shift_dup);
1658
1659 for (; i <= size - 8; i += 8) {
1660 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1661 const int16x8_t input2_val_s16 =
1662 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1663 const int16x8_t input2_val =
1664 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1665 const int16x4_t input2_val_high = vget_high_s16(input2_val);
1666 const int16x4_t input2_val_low = vget_low_s16(input2_val);
1667 int32x4_t x21 = vmovl_s16(input2_val_low);
1668 int32x4_t x22 = vmovl_s16(input2_val_high);
1669 x21 = vshlq_s32(x21, left_shift_dup);
1670 x22 = vshlq_s32(x22, left_shift_dup);
1671 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1672 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1673 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1674 x21 = vshlq_s32(x21, input2_shift_dup);
1675 x22 = vshlq_s32(x22, input2_shift_dup);
1676 int32x4_t s1 = vaddq_s32(x11, x21);
1677 int32x4_t s2 = vaddq_s32(x12, x22);
1678 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1679 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1680 s1 = RoundingDivideByPOT(s1, -params.output_shift);
1681 s2 = RoundingDivideByPOT(s2, -params.output_shift);
1682 const int16x4_t s1_narrowed = vmovn_s32(s1);
1683 const int16x4_t s2_narrowed = vmovn_s32(s2);
1684 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1685 vdupq_n_s16(params.output_offset));
1686 const uint8x8_t clamped =
1687 vmax_u8(output_activation_min_vector,
1688 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1689 vst1_u8(output_data + i, clamped);
1690 }
1691 #endif // NEON
1692
1693 if (i < size) {
1694 // Process broadcast scalar.
1695 const int32 input1_val = params.input1_offset + input1_data;
1696 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1697 const int32 scaled_input1_val =
1698 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1699 shifted_input1_val, params.input1_multiplier, params.input1_shift);
1700
1701 for (; i < size; ++i) {
1702 const int32 input2_val = params.input2_offset + input2_data[i];
1703 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1704 const int32 scaled_input2_val =
1705 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1706 shifted_input2_val, params.input2_multiplier,
1707 params.input2_shift);
1708 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1709 const int32 raw_output =
1710 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1711 raw_sum, params.output_multiplier, params.output_shift) +
1712 params.output_offset;
1713 const int32 clamped_output =
1714 std::min(params.quantized_activation_max,
1715 std::max(params.quantized_activation_min, raw_output));
1716 output_data[i] = static_cast<uint8>(clamped_output);
1717 }
1718 }
1719 }
1720
1721 // Scalar-broadcast add that can be used for the inner loop of a more general
1722 // broadcast add, so that, for example, a scalar broadcast with batch will
1723 // still be fast.
1724 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1725 float broadcast_value, const float* input2_data,
1726 float* output_data) {
1727 int i = 0;
1728 #ifdef USE_NEON
1729 const float32x4_t output_activation_min_vector =
1730 vdupq_n_f32(params.float_activation_min);
1731 const float32x4_t output_activation_max_vector =
1732 vdupq_n_f32(params.float_activation_max);
1733 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
1734 for (; i <= size - 4; i += 4) {
1735 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
1736
1737 const float32x4_t output =
1738 vaddq_f32(input2_val_original, broadcast_value_dup);
1739
1740 const float32x4_t clamped =
1741 vmaxq_f32(output_activation_min_vector,
1742 vminq_f32(output_activation_max_vector, output));
1743 vst1q_f32(output_data + i, clamped);
1744 }
1745 #endif // NEON
1746
1747 for (; i < size; ++i) {
1748 auto x = broadcast_value + input2_data[i];
1749 output_data[i] = ActivationFunctionWithMinMax(
1750 x, params.float_activation_min, params.float_activation_max);
1751 }
1752 }
1753
1754 inline void Add(const ArithmeticParams& params,
1755 const RuntimeShape& input1_shape, const uint8* input1_data,
1756 const RuntimeShape& input2_shape, const uint8* input2_data,
1757 const RuntimeShape& output_shape, uint8* output_data) {
1758 TFLITE_DCHECK_LE(params.quantized_activation_min,
1759 params.quantized_activation_max);
1760 ruy::profiler::ScopeLabel label("Add/8bit");
1761 const int flat_size =
1762 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1763
1764 TFLITE_DCHECK_GT(params.input1_offset, -256);
1765 TFLITE_DCHECK_GT(params.input2_offset, -256);
1766 TFLITE_DCHECK_LT(params.input1_offset, 256);
1767 TFLITE_DCHECK_LT(params.input2_offset, 256);
1768 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1769 }
1770
1771 inline void Add(const ArithmeticParams& params,
1772 const RuntimeShape& input1_shape, const int16* input1_data,
1773 const RuntimeShape& input2_shape, const int16* input2_data,
1774 const RuntimeShape& output_shape, int16* output_data) {
1775 ruy::profiler::ScopeLabel label("Add/Int16");
1776 TFLITE_DCHECK_LE(params.quantized_activation_min,
1777 params.quantized_activation_max);
1778
1779 const int input1_shift = params.input1_shift;
1780 const int flat_size =
1781 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1782 const int16 output_activation_min = params.quantized_activation_min;
1783 const int16 output_activation_max = params.quantized_activation_max;
1784
1785 TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
1786 TFLITE_DCHECK_LE(input1_shift, 0);
1787 TFLITE_DCHECK_LE(params.input2_shift, 0);
1788 const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
1789 const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
1790 const int input_right_shift =
1791 input1_shift == 0 ? -params.input2_shift : -input1_shift;
1792
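// One operand already carries the output scale (its shift is 0); the other is
// right-shifted to match before the two are combined with a saturating add in
// Q0.15 fixed point.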
1793 for (int i = 0; i < flat_size; i++) {
1794 // F0 uses 0 integer bits, range [-1, 1].
1795 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
1796
1797 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
1798 F0 scaled_input = F0::FromRaw(
1799 gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
1800 F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
1801 const int16 raw_output = result.raw();
1802 const int16 clamped_output = std::min(
1803 output_activation_max, std::max(output_activation_min, raw_output));
1804 output_data[i] = clamped_output;
1805 }
1806 }
1807
1808 template <typename T>
1809 inline typename std::enable_if<is_int32_or_int64<T>::value, void>::type Add(
1810 const ArithmeticParams& params, const RuntimeShape& input1_shape,
1811 const T* input1_data, const RuntimeShape& input2_shape,
1812 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
1813 ruy::profiler::ScopeLabel label("Add/int32or64");
1814
1815 T activation_min, activation_max;
1816 GetActivationParams(params, &activation_min, &activation_max);
1817
1818 auto input1_map = MapAsVector(input1_data, input1_shape);
1819 auto input2_map = MapAsVector(input2_data, input2_shape);
1820 auto output_map = MapAsVector(output_data, output_shape);
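// Fast paths: identical shapes use a single vectorized Eigen expression, and
// a flat-size-1 operand is handled as a scalar broadcast; all other broadcast
// patterns fall back to the generic reference implementation.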
1821 if (input1_shape == input2_shape) {
1822 output_map.array() = (input1_map.array() + input2_map.array())
1823 .cwiseMax(activation_min)
1824 .cwiseMin(activation_max);
1825 } else if (input2_shape.FlatSize() == 1) {
1826 auto scalar = input2_data[0];
1827 output_map.array() = (input1_map.array() + scalar)
1828 .cwiseMax(activation_min)
1829 .cwiseMin(activation_max);
1830 } else if (input1_shape.FlatSize() == 1) {
1831 auto scalar = input1_data[0];
1832 output_map.array() = (scalar + input2_map.array())
1833 .cwiseMax(activation_min)
1834 .cwiseMin(activation_max);
1835 } else {
1836 reference_ops::BroadcastAdd4DSlow<T>(params, input1_shape, input1_data,
1837 input2_shape, input2_data,
1838 output_shape, output_data);
1839 }
1840 }
1841
1842 template <typename T>
1843 inline void BroadcastAddDispatch(
1844 const ArithmeticParams& params, const RuntimeShape& input1_shape,
1845 const T* input1_data, const RuntimeShape& input2_shape,
1846 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
1847 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
1848 return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
1849 input2_data, output_shape, output_data);
1850 }
1851
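// Any remaining (non-generic) broadcast shape has already been reduced to the
// fivefold form expected by BinaryBroadcastFiveFold, which uses the
// elementwise and scalar-broadcast kernels passed in below for its inner
// loops.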
1852 BinaryBroadcastFiveFold(
1853 params, input1_shape, input1_data, input2_shape, input2_data,
1854 output_shape, output_data,
1855 static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
1856 T*)>(AddElementwise),
1857 static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
1858 AddScalarBroadcast));
1859 }
1860
1861 inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
1862 const RuntimeShape& unswitched_input1_shape,
1863 const uint8* unswitched_input1_data,
1864 const RuntimeShape& unswitched_input2_shape,
1865 const uint8* unswitched_input2_data,
1866 const RuntimeShape& output_shape,
1867 uint8* output_data) {
1868 BroadcastAddDispatch(unswitched_params, unswitched_input1_shape,
1869 unswitched_input1_data, unswitched_input2_shape,
1870 unswitched_input2_data, output_shape, output_data);
1871 }
1872
1873 inline void BroadcastAddFivefold(const ArithmeticParams& params,
1874 const RuntimeShape& unswitched_input1_shape,
1875 const float* unswitched_input1_data,
1876 const RuntimeShape& unswitched_input2_shape,
1877 const float* unswitched_input2_data,
1878 const RuntimeShape& output_shape,
1879 float* output_data) {
1880 BroadcastAddDispatch(params, unswitched_input1_shape, unswitched_input1_data,
1881 unswitched_input2_shape, unswitched_input2_data,
1882 output_shape, output_data);
1883 }
1884
1885 inline void MulElementwise(int size, const ArithmeticParams& params,
1886 const float* input1_data, const float* input2_data,
1887 float* output_data) {
1888 const float output_activation_min = params.float_activation_min;
1889 const float output_activation_max = params.float_activation_max;
1890
1891 int i = 0;
1892 #ifdef USE_NEON
1893 const auto activation_min = vdupq_n_f32(output_activation_min);
1894 const auto activation_max = vdupq_n_f32(output_activation_max);
1895 for (; i <= size - 16; i += 16) {
1896 auto a10 = vld1q_f32(input1_data + i);
1897 auto a11 = vld1q_f32(input1_data + i + 4);
1898 auto a12 = vld1q_f32(input1_data + i + 8);
1899 auto a13 = vld1q_f32(input1_data + i + 12);
1900 auto a20 = vld1q_f32(input2_data + i);
1901 auto a21 = vld1q_f32(input2_data + i + 4);
1902 auto a22 = vld1q_f32(input2_data + i + 8);
1903 auto a23 = vld1q_f32(input2_data + i + 12);
1904 auto x0 = vmulq_f32(a10, a20);
1905 auto x1 = vmulq_f32(a11, a21);
1906 auto x2 = vmulq_f32(a12, a22);
1907 auto x3 = vmulq_f32(a13, a23);
1908
1909 x0 = vmaxq_f32(activation_min, x0);
1910 x1 = vmaxq_f32(activation_min, x1);
1911 x2 = vmaxq_f32(activation_min, x2);
1912 x3 = vmaxq_f32(activation_min, x3);
1913 x0 = vminq_f32(activation_max, x0);
1914 x1 = vminq_f32(activation_max, x1);
1915 x2 = vminq_f32(activation_max, x2);
1916 x3 = vminq_f32(activation_max, x3);
1917
1918 vst1q_f32(output_data + i, x0);
1919 vst1q_f32(output_data + i + 4, x1);
1920 vst1q_f32(output_data + i + 8, x2);
1921 vst1q_f32(output_data + i + 12, x3);
1922 }
1923 for (; i <= size - 4; i += 4) {
1924 auto a1 = vld1q_f32(input1_data + i);
1925 auto a2 = vld1q_f32(input2_data + i);
1926 auto x = vmulq_f32(a1, a2);
1927
1928 x = vmaxq_f32(activation_min, x);
1929 x = vminq_f32(activation_max, x);
1930
1931 vst1q_f32(output_data + i, x);
1932 }
1933 #endif // NEON
1934
1935 for (; i < size; i++) {
1936 auto x = input1_data[i] * input2_data[i];
1937 output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
1938 output_activation_max);
1939 }
1940 }
1941
1942 inline void Mul(const ArithmeticParams& params,
1943 const RuntimeShape& input1_shape, const float* input1_data,
1944 const RuntimeShape& input2_shape, const float* input2_data,
1945 const RuntimeShape& output_shape, float* output_data) {
1946 ruy::profiler::ScopeLabel label("Mul");
1947
1948 const int flat_size =
1949 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1950 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
1951 }
1952
1953 inline void Mul(const ArithmeticParams& params,
1954 const RuntimeShape& input1_shape, const int32* input1_data,
1955 const RuntimeShape& input2_shape, const int32* input2_data,
1956 const RuntimeShape& output_shape, int32* output_data) {
1957 ruy::profiler::ScopeLabel label("Mul/int32/activation");
1958
1959 const int flat_size =
1960 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1961 const int32 output_activation_min = params.quantized_activation_min;
1962 const int32 output_activation_max = params.quantized_activation_max;
1963 for (int i = 0; i < flat_size; ++i) {
1964 output_data[i] = ActivationFunctionWithMinMax(
1965 input1_data[i] * input2_data[i], output_activation_min,
1966 output_activation_max);
1967 }
1968 }
1969
1970 inline void MulNoActivation(const ArithmeticParams& params,
1971 const RuntimeShape& input1_shape,
1972 const int32* input1_data,
1973 const RuntimeShape& input2_shape,
1974 const int32* input2_data,
1975 const RuntimeShape& output_shape,
1976 int32* output_data) {
1977 ruy::profiler::ScopeLabel label("Mul/int32");
1978
1979 auto input1_map = MapAsVector(input1_data, input1_shape);
1980 auto input2_map = MapAsVector(input2_data, input2_shape);
1981 auto output_map = MapAsVector(output_data, output_shape);
1982 if (input1_shape == input2_shape) {
1983 output_map.array() = input1_map.array() * input2_map.array();
1984 } else if (input2_shape.FlatSize() == 1) {
1985 auto scalar = input2_data[0];
1986 output_map.array() = input1_map.array() * scalar;
1987 } else if (input1_shape.FlatSize() == 1) {
1988 auto scalar = input1_data[0];
1989 output_map.array() = scalar * input2_map.array();
1990 } else {
1991 reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
1992 input2_shape, input2_data, output_shape,
1993 output_data);
1994 }
1995 }
1996
1997 inline void Mul(const ArithmeticParams& params,
1998 const RuntimeShape& input1_shape, const int16* input1_data,
1999 const RuntimeShape& input2_shape, const int16* input2_data,
2000 const RuntimeShape& output_shape, int16* output_data) {
2001 ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation");
2002 // This is a copy of the reference implementation. We do not currently have a
2003 // properly optimized version.
2004
2005 const int flat_size =
2006 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2007
2008 for (int i = 0; i < flat_size; i++) {
2009 // F0 uses 0 integer bits, range [-1, 1].
2010 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2011
2012 F0 unclamped_result =
2013 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2014 output_data[i] = unclamped_result.raw();
2015 }
2016 }
2017
2018 inline void Mul(const ArithmeticParams& params,
2019 const RuntimeShape& input1_shape, const int16* input1_data,
2020 const RuntimeShape& input2_shape, const int16* input2_data,
2021 const RuntimeShape& output_shape, uint8* output_data) {
2022 ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
2023 // This is a copy of the reference implementation. We do not currently have a
2024 // properly optimized version.
2025 const int32 output_activation_min = params.quantized_activation_min;
2026 const int32 output_activation_max = params.quantized_activation_max;
2027 const int32 output_offset = params.output_offset;
2028 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2029
2030 const int flat_size =
2031 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2032
2033 for (int i = 0; i < flat_size; i++) {
2034 // F0 uses 0 integer bits, range [-1, 1].
2035 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2036
2037 F0 unclamped_result =
2038 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2039 int16 rescaled_result =
2040 gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
2041 int16 clamped_result =
2042 std::min<int16>(output_activation_max - output_offset, rescaled_result);
2043 clamped_result =
2044 std::max<int16>(output_activation_min - output_offset, clamped_result);
2045 output_data[i] = output_offset + clamped_result;
2046 }
2047 }
2048
2049 // Element-wise mul that can often be used for the inner loop of a broadcast
2050 // Mul as well as for the non-broadcast Mul.
2051 inline void MulElementwise(int size, const ArithmeticParams& params,
2052 const uint8* input1_data, const uint8* input2_data,
2053 uint8* output_data) {
2054 int i = 0;
2055 TFLITE_DCHECK_GT(params.input1_offset, -256);
2056 TFLITE_DCHECK_LT(params.input1_offset, 256);
2057 TFLITE_DCHECK_GT(params.input2_offset, -256);
2058 TFLITE_DCHECK_LT(params.input2_offset, 256);
2059 TFLITE_DCHECK_GT(params.output_offset, -256);
2060 TFLITE_DCHECK_LT(params.output_offset, 256);
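// Quantized mul recipe: apply the input offsets, multiply, then requantize
// the product with (output_multiplier, output_shift), add the output offset
// and clamp. The NEON path splits output_shift into separate left and right
// shifts.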
2061 #ifdef USE_NEON
2062 const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
2063 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2064 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2065 const auto output_activation_min_vector =
2066 vdup_n_u8(params.quantized_activation_min);
2067 const auto output_activation_max_vector =
2068 vdup_n_u8(params.quantized_activation_max);
2069 const int left_shift = std::max(0, params.output_shift);
2070 const int right_shift = std::max(0, -params.output_shift);
2071 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2072 for (; i <= size - 8; i += 8) {
2073 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2074 const auto input1_val_original = vld1_u8(input1_data + i);
2075 const auto input2_val_original = vld1_u8(input2_data + i);
2076 const auto input1_val_s16 =
2077 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2078 const auto input2_val_s16 =
2079 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2080 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
2081 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2082
2083 const auto input1_val_low = vget_low_s16(input1_val);
2084 const auto input1_val_high = vget_high_s16(input1_val);
2085 const auto input2_val_low = vget_low_s16(input2_val);
2086 const auto input2_val_high = vget_high_s16(input2_val);
2087
2088 auto p1 = vmull_s16(input2_val_low, input1_val_low);
2089 auto p2 = vmull_s16(input2_val_high, input1_val_high);
2090
2091 p1 = vshlq_s32(p1, left_shift_vec);
2092 p2 = vshlq_s32(p2, left_shift_vec);
2093 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2094 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2095 using gemmlowp::RoundingDivideByPOT;
2096 p1 = RoundingDivideByPOT(p1, right_shift);
2097 p2 = RoundingDivideByPOT(p2, right_shift);
2098
2099 const auto p1_narrowed = vqmovn_s32(p1);
2100 const auto p2_narrowed = vqmovn_s32(p2);
2101 const auto p =
2102 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2103 const auto clamped =
2104 vmax_u8(output_activation_min_vector,
2105 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2106 vst1_u8(output_data + i, clamped);
2107 }
2108 #endif // NEON
2109
2110 for (; i < size; ++i) {
2111 const int32 input1_val = params.input1_offset + input1_data[i];
2112 const int32 input2_val = params.input2_offset + input2_data[i];
2113 const int32 unclamped_result =
2114 params.output_offset +
2115 MultiplyByQuantizedMultiplier(input1_val * input2_val,
2116 params.output_multiplier,
2117 params.output_shift);
2118 const int32 clamped_output =
2119 std::min(params.quantized_activation_max,
2120 std::max(params.quantized_activation_min, unclamped_result));
2121 output_data[i] = static_cast<uint8>(clamped_output);
2122 }
2123 }
2124
2125 // Broadcast mul that can often be used for the inner loop of a broadcast Mul.
2126 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2127 const uint8 broadcast_value,
2128 const uint8* input2_data, uint8* output_data) {
2129 const int16 input1_val = params.input1_offset + broadcast_value;
2130
2131 int i = 0;
2132 TFLITE_DCHECK_GT(params.input1_offset, -256);
2133 TFLITE_DCHECK_LT(params.input1_offset, 256);
2134 TFLITE_DCHECK_GT(params.input2_offset, -256);
2135 TFLITE_DCHECK_LT(params.input2_offset, 256);
2136 TFLITE_DCHECK_GT(params.output_offset, -256);
2137 TFLITE_DCHECK_LT(params.output_offset, 256);
2138 #ifdef USE_NEON
2139 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2140 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2141 const auto output_activation_min_vector =
2142 vdup_n_u8(params.quantized_activation_min);
2143 const auto output_activation_max_vector =
2144 vdup_n_u8(params.quantized_activation_max);
2145 const int left_shift = std::max(0, params.output_shift);
2146 const int right_shift = std::max(0, -params.output_shift);
2147 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2148 for (; i <= size - 8; i += 8) {
2149 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2150 const auto input2_val_original = vld1_u8(input2_data + i);
2151 const auto input2_val_s16 =
2152 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2153 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2154
2155 const auto input2_val_low = vget_low_s16(input2_val);
2156 const auto input2_val_high = vget_high_s16(input2_val);
2157
2158 auto p1 = vmull_n_s16(input2_val_low, input1_val);
2159 auto p2 = vmull_n_s16(input2_val_high, input1_val);
2160
2161 p1 = vshlq_s32(p1, left_shift_vec);
2162 p2 = vshlq_s32(p2, left_shift_vec);
2163 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2164 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2165 using gemmlowp::RoundingDivideByPOT;
2166 p1 = RoundingDivideByPOT(p1, right_shift);
2167 p2 = RoundingDivideByPOT(p2, right_shift);
2168
2169 const auto p1_narrowed = vmovn_s32(p1);
2170 const auto p2_narrowed = vmovn_s32(p2);
2171 const auto p =
2172 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2173 const auto clamped =
2174 vmax_u8(output_activation_min_vector,
2175 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2176 vst1_u8(output_data + i, clamped);
2177 }
2178 #endif // NEON
2179
2180 for (; i < size; ++i) {
2181 const int32 input2_val = params.input2_offset + input2_data[i];
2182 const int32 unclamped_result =
2183 params.output_offset +
2184 MultiplyByQuantizedMultiplier(input1_val * input2_val,
2185 params.output_multiplier,
2186 params.output_shift);
2187 const int32 clamped_output =
2188 std::min(params.quantized_activation_max,
2189 std::max(params.quantized_activation_min, unclamped_result));
2190 output_data[i] = static_cast<uint8>(clamped_output);
2191 }
2192 }
2193
2194 // Broadcast mul that can often be used for the inner loop of a broadcast Mul.
2195 // This function handles scalar_value (LHS) * vector_values (RHS).
2196 // Since this is a float function, the quantization input params do not matter here.
2197 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2198 const float broadcast_value,
2199 const float* input2_data, float* output_data) {
2200 int i = 0;
2201 #ifdef USE_NEON
2202 const float32x4_t output_activation_min_vector =
2203 vdupq_n_f32(params.float_activation_min);
2204 const float32x4_t output_activation_max_vector =
2205 vdupq_n_f32(params.float_activation_max);
2206 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2207 for (; i <= size - 4; i += 4) {
2208 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2209
2210 const float32x4_t output =
2211 vmulq_f32(input2_val_original, broadcast_value_dup);
2212
2213 const float32x4_t clamped =
2214 vmaxq_f32(output_activation_min_vector,
2215 vminq_f32(output_activation_max_vector, output));
2216 vst1q_f32(output_data + i, clamped);
2217 }
2218 #endif // NEON
2219
2220 for (; i < size; ++i) {
2221 float x = broadcast_value * input2_data[i];
2222 output_data[i] = ActivationFunctionWithMinMax(
2223 x, params.float_activation_min, params.float_activation_max);
2224 }
2225 }
2226
2227 inline void Mul(const ArithmeticParams& params,
2228 const RuntimeShape& input1_shape, const uint8* input1_data,
2229 const RuntimeShape& input2_shape, const uint8* input2_data,
2230 const RuntimeShape& output_shape, uint8* output_data) {
2231 TFLITE_DCHECK_LE(params.quantized_activation_min,
2232 params.quantized_activation_max);
2233 ruy::profiler::ScopeLabel label("Mul/8bit");
2234 const int flat_size =
2235 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2236
2237 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2238 }
2239
2240 template <typename T>
2241 inline void BroadcastMulDispatch(
2242 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2243 const T* input1_data, const RuntimeShape& input2_shape,
2244 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2245 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2246 return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
2247 input2_data, output_shape, output_data);
2248 }
2249
2250 BinaryBroadcastFiveFold(
2251 params, input1_shape, input1_data, input2_shape, input2_data,
2252 output_shape, output_data,
2253 static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2254 T*)>(MulElementwise),
2255 static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2256 MulSimpleBroadcast));
2257 }
2258
2259 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
2260 const RuntimeShape& unswitched_input1_shape,
2261 const uint8* unswitched_input1_data,
2262 const RuntimeShape& unswitched_input2_shape,
2263 const uint8* unswitched_input2_data,
2264 const RuntimeShape& output_shape,
2265 uint8* output_data) {
2266 BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
2267 unswitched_input1_data, unswitched_input2_shape,
2268 unswitched_input2_data, output_shape, output_data);
2269 }
2270
2271 inline void BroadcastMulFivefold(const ArithmeticParams& params,
2272 const RuntimeShape& unswitched_input1_shape,
2273 const float* unswitched_input1_data,
2274 const RuntimeShape& unswitched_input2_shape,
2275 const float* unswitched_input2_data,
2276 const RuntimeShape& output_shape,
2277 float* output_data) {
2278 BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2279 unswitched_input2_shape, unswitched_input2_data,
2280 output_shape, output_data);
2281 }
2282
2283 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
2284 // dimensionality if the runtime code does a single loop over one dimension
2285 // that handles broadcasting as the base case. The code generator would then
2286 // generate max(D1, D2) nested for loops.
2287 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
2288 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
2289 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
2290 // reference_ops.h.
2291 template <typename T, int N = 5>
2292 void BroadcastDivSlow(const ArithmeticParams& params,
2293 const RuntimeShape& unextended_input1_shape,
2294 const T* input1_data,
2295 const RuntimeShape& unextended_input2_shape,
2296 const T* input2_data,
2297 const RuntimeShape& unextended_output_shape,
2298 T* output_data) {
2299 ruy::profiler::ScopeLabel label("BroadcastDivSlow");
2300 T output_activation_min;
2301 T output_activation_max;
2302 GetActivationParams(params, &output_activation_min, &output_activation_max);
2303
2304 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2305 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2306 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2307
2308 NdArrayDesc<N> desc1;
2309 NdArrayDesc<N> desc2;
2310 NdArrayDesc<N> output_desc;
2311 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2312 unextended_input2_shape, &desc1, &desc2);
2313 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2314 &output_desc);
2315
2316 // In TensorFlow, the dimensions are canonically named (batch_number, row,
2317 // col, channel), with extents (batches, height, width, depth), and with the
2318 // trailing dimension changing most rapidly (the channel dimension has the
2319 // smallest stride, typically 1 element).
2320 //
2321 // In generated C code, we store arrays with the dimensions reversed: the
2322 // first dimension has the smallest stride.
2323 //
2324 // We name our variables by the TensorFlow convention, but generate C code
2325 // nesting the loops such that the innermost loop has the smallest stride, for
2326 // the best cache behavior.
2327 auto div_func = [&](int indexes[N]) {
2328 output_data[SubscriptToIndex(output_desc, indexes)] =
2329 ActivationFunctionWithMinMax(
2330 input1_data[SubscriptToIndex(desc1, indexes)] /
2331 input2_data[SubscriptToIndex(desc2, indexes)],
2332 output_activation_min, output_activation_max);
2333 };
2334 NDOpsHelper<N>(output_desc, div_func);
2335 }
2336
2337 // BroadcastDiv is intentionally duplicated from reference_ops.h.
2338 // For more details see the comment above the generic version of
2339 // BroadcastDivSlow.
2340 template <int N = 5>
2341 inline void BroadcastDivSlow(const ArithmeticParams& params,
2342 const RuntimeShape& unextended_input1_shape,
2343 const uint8* input1_data,
2344 const RuntimeShape& unextended_input2_shape,
2345 const uint8* input2_data,
2346 const RuntimeShape& unextended_output_shape,
2347 uint8* output_data) {
2348 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2349 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2350 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2351
2352 NdArrayDesc<N> desc1;
2353 NdArrayDesc<N> desc2;
2354 NdArrayDesc<N> output_desc;
2355 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2356 unextended_input2_shape, &desc1, &desc2);
2357 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2358 &output_desc);
2359
2360 TFLITE_DCHECK_GT(params.input1_offset, -256);
2361 TFLITE_DCHECK_LT(params.input1_offset, 256);
2362 TFLITE_DCHECK_GT(params.input2_offset, -256);
2363 TFLITE_DCHECK_LT(params.input2_offset, 256);
2364 TFLITE_DCHECK_GT(params.output_offset, -256);
2365 TFLITE_DCHECK_LT(params.output_offset, 256);
2366
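// Division is implemented as multiplication by the Q31 reciprocal of the
// offset-corrected divisor; the reciprocal's shift and the dividend's
// headroom are then folded into the final requantization shift.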
2367 auto div_func = [&](int indexes[N]) {
2368 int32 input1_val =
2369 params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
2370 int32 input2_val =
2371 params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
2372 TFLITE_DCHECK_NE(input2_val, 0);
2373 if (input2_val < 0) {
2374 // Invert the signs to avoid a negative input2_val, since input2_inv must
2375 // be positive to be used as the multiplier in MultiplyByQuantizedMultiplier.
2376 input1_val = -input1_val;
2377 input2_val = -input2_val;
2378 }
2379 int recip_shift;
2380 const int32 input2_inv = GetReciprocal(input2_val, 31, &recip_shift);
2381 const int headroom = CountLeadingSignBits(input1_val);
2382 const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
2383 input1_val, input2_inv, headroom);
2384 const int total_shift = params.output_shift - recip_shift - headroom;
2385 const int32 unclamped_result =
2386 params.output_offset +
2387 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2388 unscaled_quotient, params.output_multiplier, total_shift);
2389 const int32 clamped_output =
2390 std::min(params.quantized_activation_max,
2391 std::max(params.quantized_activation_min, unclamped_result));
2392 output_data[SubscriptToIndex(output_desc, indexes)] =
2393 static_cast<uint8>(clamped_output);
2394 };
2395 NDOpsHelper<N>(output_desc, div_func);
2396 }
2397
2398 template <typename T>
2399 inline void SubWithActivation(
2400 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2401 const T* input1_data, const RuntimeShape& input2_shape,
2402 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2403 ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
2404 TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
2405 auto input1_map = MapAsVector(input1_data, input1_shape);
2406 auto input2_map = MapAsVector(input2_data, input2_shape);
2407 auto output_map = MapAsVector(output_data, output_shape);
2408 T activation_min, activation_max;
2409 GetActivationParams(params, &activation_min, &activation_max);
2410 output_map.array() = (input1_map.array() - input2_map.array())
2411 .cwiseMin(activation_max)
2412 .cwiseMax(activation_min);
2413 }
2414
2415 inline void SubNonBroadcast(const ArithmeticParams& params,
2416 const RuntimeShape& input1_shape,
2417 const float* input1_data,
2418 const RuntimeShape& input2_shape,
2419 const float* input2_data,
2420 const RuntimeShape& output_shape,
2421 float* output_data) {
2422 ruy::profiler::ScopeLabel label("SubNonBroadcast");
2423 SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
2424 input2_data, output_shape, output_data);
2425 }
2426
2427 template <typename T>
2428 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
2429 const T* input1_data, const RuntimeShape& input2_shape,
2430 const T* input2_data, const RuntimeShape& output_shape,
2431 T* output_data) {
2432 ruy::profiler::ScopeLabel label("Sub");
2433
2434 auto input1_map = MapAsVector(input1_data, input1_shape);
2435 auto input2_map = MapAsVector(input2_data, input2_shape);
2436 auto output_map = MapAsVector(output_data, output_shape);
2437 if (input1_shape == input2_shape) {
2438 output_map.array() = input1_map.array() - input2_map.array();
2439 } else if (input1_shape.FlatSize() == 1) {
2440 auto scalar = input1_data[0];
2441 output_map.array() = scalar - input2_map.array();
2442 } else if (input2_shape.FlatSize() == 1) {
2443 auto scalar = input2_data[0];
2444 output_map.array() = input1_map.array() - scalar;
2445 } else {
2446 BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
2447 input2_data, output_shape, output_data);
2448 }
2449 }
2450
2451 inline void LstmCell(
2452 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2453 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2454 const float* prev_activ_data, const RuntimeShape& weights_shape,
2455 const float* weights_data, const RuntimeShape& unextended_bias_shape,
2456 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2457 const float* prev_state_data,
2458 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2459 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2460 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2461 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
2462 CpuBackendContext* cpu_backend_context) {
2463 ruy::profiler::ScopeLabel label("LstmCell");
2464 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2465 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2466 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2467 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2468 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2469 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2470 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2471 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2472 const RuntimeShape input_shape =
2473 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2474 const RuntimeShape prev_activ_shape =
2475 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2476 const RuntimeShape bias_shape =
2477 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2478 const RuntimeShape prev_state_shape =
2479 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2480 const RuntimeShape output_state_shape =
2481 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2482 const RuntimeShape output_activ_shape =
2483 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2484 const RuntimeShape concat_temp_shape =
2485 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2486 const RuntimeShape activ_temp_shape =
2487 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2488 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2489
2490 const int weights_dim_count = weights_shape.DimensionsCount();
2491 MatchingDim( // batches
2492 input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
2493 output_state_shape, 0, output_activ_shape, 0);
2494 MatchingDim( // height
2495 input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
2496 output_state_shape, 1, output_activ_shape, 1);
2497 MatchingDim( // width
2498 input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
2499 output_state_shape, 2, output_activ_shape, 2);
2500 const int input_depth = input_shape.Dims(3);
2501 const int prev_activ_depth = prev_activ_shape.Dims(3);
2502 const int total_input_depth = prev_activ_depth + input_depth;
2503 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2504 total_input_depth);
2505 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2506 const int intern_activ_depth =
2507 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2508 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2509 intern_activ_depth * total_input_depth);
2510 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2511 const int output_depth =
2512 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2513 3, output_activ_shape, 3);
2514 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2515
2516 // Concatenate prev_activ and input data together
2517 std::vector<float const*> concat_input_arrays_data;
2518 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
2519 concat_input_arrays_data.push_back(input_data);
2520 concat_input_arrays_data.push_back(prev_activ_data);
2521 concat_input_arrays_shapes.push_back(&input_shape);
2522 concat_input_arrays_shapes.push_back(&prev_activ_shape);
2523 tflite::ConcatenationParams concat_params;
2524 concat_params.axis = 3;
2525 concat_params.inputs_count = concat_input_arrays_data.size();
2526 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
2527 &(concat_input_arrays_data[0]), concat_temp_shape,
2528 concat_temp_data);
2529
2530 // Fully connected
2531 tflite::FullyConnectedParams fc_params;
2532 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
2533 fc_params.float_activation_max = std::numeric_limits<float>::max();
2534 fc_params.lhs_cacheable = false;
2535 fc_params.rhs_cacheable = false;
2536 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
2537 weights_data, bias_shape, bias_data, activ_temp_shape,
2538 activ_temp_data, cpu_backend_context);
2539
2540 // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
2541 // operations.
2542 ArrayMap<float> activ_temp_map =
2543 MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
2544 auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
2545 activ_temp_map.cols());
2546 auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
2547 activ_temp_map.cols());
2548 auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
2549 activ_temp_map.cols());
2550 auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
2551 activ_temp_map.cols());
2552 ArrayMap<const float> prev_state_map =
2553 MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
2554 ArrayMap<float> output_state_map =
2555 MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
2556 ArrayMap<float> output_activ_map =
2557 MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
2558
2559 // Combined memory state and final output calculation
2560 ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
2561 output_state_map =
2562 input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2563 new_input_sm.tanh() +
2564 forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2565 prev_state_map;
2566 output_activ_map =
2567 output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2568 output_state_map.tanh();
2569 }
2570
2571 template <int StateIntegerBits>
2572 inline void LstmCell(
2573 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2574 const uint8* input_data_uint8,
2575 const RuntimeShape& unextended_prev_activ_shape,
2576 const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
2577 const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
2578 const int32* bias_data_int32,
2579 const RuntimeShape& unextended_prev_state_shape,
2580 const int16* prev_state_data_int16,
2581 const RuntimeShape& unextended_output_state_shape,
2582 int16* output_state_data_int16,
2583 const RuntimeShape& unextended_output_activ_shape,
2584 uint8* output_activ_data_uint8,
2585 const RuntimeShape& unextended_concat_temp_shape,
2586 uint8* concat_temp_data_uint8,
2587 const RuntimeShape& unextended_activ_temp_shape,
2588 int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
2589 ruy::profiler::ScopeLabel label(
2590 "LstmCell/quantized (8bit external, 16bit internal)");
2591 int32 weights_zero_point = params.weights_zero_point;
2592 int32 accum_multiplier = params.accum_multiplier;
2593 int accum_shift = params.accum_shift;
2594 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2595 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2596 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2597 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2598 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2599 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2600 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2601 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2602 const RuntimeShape input_shape =
2603 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2604 const RuntimeShape prev_activ_shape =
2605 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2606 const RuntimeShape bias_shape =
2607 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2608 const RuntimeShape prev_state_shape =
2609 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2610 const RuntimeShape output_state_shape =
2611 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2612 const RuntimeShape output_activ_shape =
2613 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2614 const RuntimeShape concat_temp_shape =
2615 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2616 const RuntimeShape activ_temp_shape =
2617 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2618 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2619
2620 // Gather dimensions information, and perform consistency checks.
2621 const int weights_dim_count = weights_shape.DimensionsCount();
2622 const int outer_size = MatchingFlatSizeSkipDim(
2623 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
2624 output_activ_shape);
2625 const int input_depth = input_shape.Dims(3);
2626 const int prev_activ_depth = prev_activ_shape.Dims(3);
2627 const int total_input_depth = prev_activ_depth + input_depth;
2628 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2629 total_input_depth);
2630 const int intern_activ_depth =
2631 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2632 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2633 intern_activ_depth * total_input_depth);
2634 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2635 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2636 const int output_depth =
2637 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2638 3, output_activ_shape, 3);
2639 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2640 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
2641 const int fc_output_depth =
2642 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
2643 const int fc_accum_depth = total_input_depth;
2644 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
2645
2646 // Depth-concatenate prev_activ and input data together.
2647 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
2648 prev_activ_data_uint8};
2649 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
2650 &prev_activ_shape};
2651 tflite::ConcatenationParams concat_params;
2652 concat_params.axis = 3;
2653 concat_params.inputs_count = 2;
2654 Concatenation(concat_params, concat_input_arrays_shapes,
2655 concat_input_arrays_data, concat_temp_shape,
2656 concat_temp_data_uint8);
2657
2658 // Implementation of the fully connected node inside the LSTM cell.
2659 // The operands are 8-bit integers, the accumulators are internally 32-bit
2660 // integers, and the output is 16-bit fixed-point with 3 integer bits so
2661 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
2662 // is explained in the function comment above.
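// For example, with 3 integer bits (and hence 12 fractional bits in an
// int16), the raw value 4096 represents 4096 / 2^12 = 1.0, and the
// full-scale raw value 32767 represents roughly 7.9998, just under the top
// of the [-8, 8] range.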
2663 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
2664 lhs_params.rows = fc_output_depth;
2665 lhs_params.cols = fc_accum_depth;
2666 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
2667 lhs_params.zero_point = weights_zero_point;
2668 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
2669 rhs_params.rows = fc_accum_depth;
2670 rhs_params.cols = fc_batches;
2671 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
2672 rhs_params.zero_point = 128;
2673 cpu_backend_gemm::MatrixParams<int16> dst_params;
2674 dst_params.rows = fc_output_depth;
2675 dst_params.cols = fc_batches;
2676 dst_params.order = cpu_backend_gemm::Order::kColMajor;
2677 dst_params.zero_point = 0;
2678 cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
2679 gemm_params.bias = bias_data_int32;
2680 gemm_params.multiplier_fixedpoint = accum_multiplier;
2681 gemm_params.multiplier_exponent = accum_shift;
2682 cpu_backend_gemm::Gemm(
2683 lhs_params, weights_data_uint8, rhs_params, concat_temp_data_uint8,
2684 dst_params, activ_temp_data_int16, gemm_params, cpu_backend_context);
2685
2686 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
2687 // and muls, all done in 16-bit fixed-point.
2688 const int16* input_gate_input_ptr = activ_temp_data_int16;
2689 const int16* input_modulation_gate_input_ptr =
2690 activ_temp_data_int16 + output_depth;
2691 const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
2692 const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
2693 const int16* prev_state_ptr = prev_state_data_int16;
2694 int16* output_state_data_ptr = output_state_data_int16;
2695 uint8* output_activ_data_ptr = output_activ_data_uint8;
2696
2697 for (int b = 0; b < outer_size; ++b) {
2698 int c = 0;
2699 #ifdef GEMMLOWP_NEON
2700 for (; c <= output_depth - 8; c += 8) {
2701 // Define the fixed-point data types that we will use here. All use
2702 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
2703 // They only differ by the number of integral vs. fractional bits,
2704 // determining the range of values that they can represent.
2705 //
2706 // F0 uses 0 integer bits, range [-1, 1].
2707 // This is the return type of math functions such as tanh, logistic,
2708 // whose range is in [-1, 1].
2709 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
2710 // F3 uses 3 integer bits, range [-8, 8].
2711 // This is the range of the previous fully-connected node's output,
2712 // which is our input here.
2713 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
2714 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
2715 // 2^StateIntegerBits]. It's used to represent the internal state, whose
2716 // number of integer bits is currently dictated by the model. See comment
2717 // on the StateIntegerBits template parameter above.
2718 using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
2719 // Implementation of input gate, using fixed-point logistic function.
2720 F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
2721 input_gate_input_ptr += 8;
2722 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
2723 // Implementation of input modulation gate, using fixed-point tanh
2724 // function.
2725 F3 input_modulation_gate_input =
2726 F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
2727 input_modulation_gate_input_ptr += 8;
2728 F0 input_modulation_gate_output =
2729 gemmlowp::tanh(input_modulation_gate_input);
2730 // Implementation of forget gate, using fixed-point logistic function.
2731 F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
2732 forget_gate_input_ptr += 8;
2733 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
2734 // Implementation of output gate, using fixed-point logistic function.
2735 F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
2736 output_gate_input_ptr += 8;
2737 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
2738 // Implementation of internal multiplication nodes, still in fixed-point.
2739 F0 input_times_input_modulation =
2740 input_gate_output * input_modulation_gate_output;
2741 FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
2742 prev_state_ptr += 8;
2743 FS prev_state_times_forget_state = forget_gate_output * prev_state;
2744 // Implementation of internal addition node, saturating.
2745 FS new_state = gemmlowp::SaturatingAdd(
2746 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
2747 prev_state_times_forget_state);
2748 // Implementation of last internal Tanh node, still in fixed-point.
2749 // Since a Tanh fixed-point implementation is specialized for a given
2750 // number of integer bits, and each specialization can have a substantial
2751 // code size, and we already used above a Tanh on an input with 3 integer
2752 // bits, and per the table in the above function comment there is no
2753 // significant accuracy to be lost by clamping to [-8, +8] for a
2754 // 3-integer-bits representation, let us just do that. This helps people
2755 // porting this to targets where code footprint must be minimized.
2756 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
2757 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
2758 // Store the new internal state back to memory, as 16-bit integers.
2759 // Note: here we store the original value with StateIntegerBits, not
2760 // the rescaled 3-integer-bits value fed to tanh.
2761 vst1q_s16(output_state_data_ptr, new_state.raw());
2762 output_state_data_ptr += 8;
2763 // Down-scale the output activations to 8-bit integers, saturating,
2764 // and store back to memory.
2765 int16x8_t rescaled_output_activ =
2766 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
2767 int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
2768 uint8x8_t uint8_output_activ =
2769 vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
2770 vst1_u8(output_activ_data_ptr, uint8_output_activ);
2771 output_activ_data_ptr += 8;
2772 }
2773 #endif
2774 for (; c < output_depth; ++c) {
2775 // Define the fixed-point data types that we will use here. All use
2776 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
2777 // They only differ by the number of integral vs. fractional bits,
2778 // determining the range of values that they can represent.
2779 //
2780 // F0 uses 0 integer bits, range [-1, 1].
2781 // This is the return type of math functions such as tanh, logistic,
2782 // whose range is in [-1, 1].
2783 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2784 // F3 uses 3 integer bits, range [-8, 8].
2785 // This is the range of the previous fully-connected node's output,
2786 // which is our input here.
2787 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
2788 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
2789 // 2^StateIntegerBits]. It's used to represent the internal state, whose
2790 // number of integer bits is currently dictated by the model. See comment
2791 // on the StateIntegerBits template parameter above.
2792 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
2793 // Implementation of input gate, using fixed-point logistic function.
2794 F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
2795 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
2796 // Implementation of input modulation gate, using fixed-point tanh
2797 // function.
2798 F3 input_modulation_gate_input =
2799 F3::FromRaw(*input_modulation_gate_input_ptr++);
2800 F0 input_modulation_gate_output =
2801 gemmlowp::tanh(input_modulation_gate_input);
2802 // Implementation of forget gate, using fixed-point logistic function.
2803 F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
2804 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
2805 // Implementation of output gate, using fixed-point logistic function.
2806 F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
2807 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
2808 // Implementation of internal multiplication nodes, still in fixed-point.
2809 F0 input_times_input_modulation =
2810 input_gate_output * input_modulation_gate_output;
2811 FS prev_state = FS::FromRaw(*prev_state_ptr++);
2812 FS prev_state_times_forget_state = forget_gate_output * prev_state;
2813 // Implementation of internal addition node, saturating.
2814 FS new_state = gemmlowp::SaturatingAdd(
2815 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
2816 prev_state_times_forget_state);
2817 // Implementation of last internal Tanh node, still in fixed-point.
2818 // Since a Tanh fixed-point implementation is specialized for a given
2819 // number of integer bits, and each specialization can have a substantial
2820 // code size, and we already used above a Tanh on an input with 3 integer
2821 // bits, and per the table in the above function comment there is no
2822 // significant accuracy to be lost by clamping to [-8, +8] for a
2823 // 3-integer-bits representation, let us just do that. This helps people
2824 // porting this to targets where code footprint must be minimized.
2825 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
2826 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
2827 // Store the new internal state back to memory, as 16-bit integers.
2828 // Note: here we store the original value with StateIntegerBits, not
2829 // the rescaled 3-integer-bits value fed to tanh.
2830 *output_state_data_ptr++ = new_state.raw();
2831 // Down-scale the output activations to 8-bit integers, saturating,
2832 // and store back to memory.
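// Concretely: the F0 raw value carries 15 fractional bits, so dividing it
// by 2^8 leaves 7 fractional bits, mapping the (-1, 1) value range onto
// (-128, 128); the result is then clamped to [-128, 127] and shifted by
// +128 into the uint8 range.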
2833 int16 rescaled_output_activ =
2834 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
2835 int16 clamped_output_activ =
2836 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
2837 *output_activ_data_ptr++ = 128 + clamped_output_activ;
2838 }
2839 input_gate_input_ptr += 3 * output_depth;
2840 input_modulation_gate_input_ptr += 3 * output_depth;
2841 forget_gate_input_ptr += 3 * output_depth;
2842 output_gate_input_ptr += 3 * output_depth;
2843 }
2844 }
2845
2846 inline int NodeOffset(int b, int h, int w, int height, int width) {
2847 return (b * height + h) * width + w;
2848 }
2849
2850 inline bool AveragePool(const PoolParams& params,
2851 const RuntimeShape& input_shape,
2852 const float* input_data,
2853 const RuntimeShape& output_shape, float* output_data) {
2854 ruy::profiler::ScopeLabel label("AveragePool");
2855 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2856 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2857 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2858 const int input_height = input_shape.Dims(1);
2859 const int input_width = input_shape.Dims(2);
2860 const int output_height = output_shape.Dims(1);
2861 const int output_width = output_shape.Dims(2);
2862 const int stride_height = params.stride_height;
2863 const int stride_width = params.stride_width;
2864
2865 if (stride_height == 0) return false;
2866 if (stride_width == 0) return false;
2867
2868 // TODO(benoitjacob) make this a proper reference impl without Eigen!
2869 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
2870 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
2871 // TODO(benoitjacob) get rid of the dynamic memory allocation here!
2872 Eigen::VectorXf out_count(out_mat.cols());
2873 out_count.setZero();
2874 // Prefill the output to 0.
2875 out_mat.setZero();
2876 for (int b = 0; b < batches; ++b) {
2877 for (int h = 0; h < input_height; ++h) {
2878 for (int w = 0; w < input_width; ++w) {
2879 // (h_start, h_end) * (w_start, w_end) is the range that the input
2880 // vector projects to.
2881 int hpad = h + params.padding_values.height;
2882 int wpad = w + params.padding_values.width;
2883 int h_start = (hpad < params.filter_height)
2884 ? 0
2885 : (hpad - params.filter_height) / stride_height + 1;
2886 int h_end = std::min(hpad / stride_height + 1, output_height);
2887 int w_start = (wpad < params.filter_width)
2888 ? 0
2889 : (wpad - params.filter_width) / stride_width + 1;
2890 int w_end = std::min(wpad / stride_width + 1, output_width);
2891 // compute elementwise sum
2892 for (int ph = h_start; ph < h_end; ++ph) {
2893 for (int pw = w_start; pw < w_end; ++pw) {
2894 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
2895 out_mat.col(out_offset) +=
2896 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
2897 out_count(out_offset)++;
2898 }
2899 }
2900 }
2901 }
2902 }
2903 // Divide the output by the actual number of elements being averaged over
2904 TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
2905 out_mat.array().rowwise() /= out_count.transpose().array();
2906
2907 const int flat_size = output_shape.FlatSize();
2908 for (int i = 0; i < flat_size; ++i) {
2909 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
2910 params.float_activation_min,
2911 params.float_activation_max);
2912 }
2913
2914 return true;
2915 }
2916
2917 inline bool AveragePool(const PoolParams& params,
2918 const RuntimeShape& input_shape,
2919 const uint8* input_data,
2920 const RuntimeShape& output_shape, uint8* output_data) {
2921 ruy::profiler::ScopeLabel label("AveragePool/8bit");
2922
2923 // Here, and in other pooling ops, in order to maintain locality of reference,
2924 // to minimize some recalculations, and to load into NEON vector registers, we
2925 // use an inner loop down the depth. Since depths can be large and hence we
2926 // would need arbitrarily large temporary storage, we divide the work up into
2927 // depth tranches just within the batch loop.
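// For example, with a tranche size of 256, a depth of 600 is processed as
// three tranches of 256, 256 and 88 channels.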
2928 static constexpr int kPoolingAccTrancheSize = 256;
2929
2930 TFLITE_DCHECK_LE(params.quantized_activation_min,
2931 params.quantized_activation_max);
2932 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2933 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2934 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2935 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
2936 const int input_height = input_shape.Dims(1);
2937 const int input_width = input_shape.Dims(2);
2938 const int output_height = output_shape.Dims(1);
2939 const int output_width = output_shape.Dims(2);
2940 const int stride_height = params.stride_height;
2941 const int stride_width = params.stride_width;
2942
2943 uint32 acc[kPoolingAccTrancheSize];
2944 for (int batch = 0; batch < batches; ++batch) {
2945 // We proceed through the depth in tranches (see comment above). The
2946 // depth_base is the depth at the beginning of the tranche. The
2947 // tranche_depth is the depth dimension of the tranche.
2948 for (int depth_base = 0; depth_base < depth;
2949 depth_base += kPoolingAccTrancheSize) {
2950 const int tranche_depth =
2951 std::min(depth - depth_base, kPoolingAccTrancheSize);
2952 for (int out_y = 0; out_y < output_height; ++out_y) {
2953 for (int out_x = 0; out_x < output_width; ++out_x) {
2954 const int in_x_origin =
2955 (out_x * stride_width) - params.padding_values.width;
2956 const int in_y_origin =
2957 (out_y * stride_height) - params.padding_values.height;
2958 const int filter_x_start = std::max(0, -in_x_origin);
2959 const int filter_x_end =
2960 std::min(params.filter_width, input_width - in_x_origin);
2961 const int filter_y_start = std::max(0, -in_y_origin);
2962 const int filter_y_end =
2963 std::min(params.filter_height, input_height - in_y_origin);
2964 const int filter_count =
2965 (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
2966 if (filter_count == 0) return false;
2967 memset(acc, 0, tranche_depth * sizeof(acc[0]));
2968 const uint8* input_ptr =
2969 input_data + depth_base +
2970 depth * (in_x_origin +
2971 input_width * (in_y_origin + input_height * batch));
2972 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
2973 const uint8* input_row_ptr =
2974 input_ptr + depth * (fy * input_width + filter_x_start);
2975 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
2976 const uint8* input_channel_ptr = input_row_ptr;
2977 int channel = 0;
2978 #ifdef USE_NEON
2979 for (; channel <= tranche_depth - 16; channel += 16) {
2980 uint16x4_t acc_reg[4];
2981 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
2982 input_channel_ptr += 16;
2983 acc_reg[0] = vget_low_u16(vmovl_u8(vget_low_u8(input_reg)));
2984 acc_reg[1] = vget_high_u16(vmovl_u8(vget_low_u8(input_reg)));
2985 acc_reg[2] = vget_low_u16(vmovl_u8(vget_high_u8(input_reg)));
2986 acc_reg[3] = vget_high_u16(vmovl_u8(vget_high_u8(input_reg)));
2987 for (int i = 0; i < 4; i++) {
2988 vst1q_u32(
2989 acc + channel + 4 * i,
2990 vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
2991 }
2992 }
2993 for (; channel <= tranche_depth - 8; channel += 8) {
2994 uint16x4_t acc_reg[2];
2995 uint16x8_t input_reg = vmovl_u8(vld1_u8(input_channel_ptr));
2996 input_channel_ptr += 8;
2997 acc_reg[0] = vget_low_u16(input_reg);
2998 acc_reg[1] = vget_high_u16(input_reg);
2999 for (int i = 0; i < 2; i++) {
3000 vst1q_u32(
3001 acc + channel + 4 * i,
3002 vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3003 }
3004 }
3005 #endif
3006 for (; channel < tranche_depth; ++channel) {
3007 acc[channel] += *input_channel_ptr++;
3008 }
3009 input_row_ptr += depth;
3010 }
3011 }
3012 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3013 out_x, depth_base);
3014 int channel = 0;
3015 #ifdef USE_NEON
3016 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
3017 if (filter_count == FILTER_COUNT) { \
3018 for (; channel <= tranche_depth - 8; channel += 8) { \
3019 uint16 buf[8]; \
3020 for (int i = 0; i < 8; i++) { \
3021 buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
3022 } \
3023 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
3024 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
3025 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
3026 vst1_u8(output_ptr + channel, buf8); \
3027 } \
3028 }
3029 AVGPOOL_DIVIDING_BY(9)
3030 AVGPOOL_DIVIDING_BY(15)
3031 #undef AVGPOOL_DIVIDING_BY
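// Note that (acc + filter_count / 2) / filter_count is a round-to-nearest
// division; the macro above instantiates it for a few compile-time constant
// filter counts (presumably so the division can be strength-reduced), while
// the loop below handles arbitrary filter_count values.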
3032 for (; channel <= tranche_depth - 8; channel += 8) {
3033 uint16 buf[8];
3034 for (int i = 0; i < 8; i++) {
3035 buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
3036 }
3037 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
3038 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
3039 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
3040 vst1_u8(output_ptr + channel, buf8);
3041 }
3042 #endif
3043 for (; channel < tranche_depth; ++channel) {
3044 uint16 a = (acc[channel] + filter_count / 2) / filter_count;
3045 a = std::max<uint16>(a, params.quantized_activation_min);
3046 a = std::min<uint16>(a, params.quantized_activation_max);
3047 output_ptr[channel] = static_cast<uint8>(a);
3048 }
3049 }
3050 }
3051 }
3052 }
3053 return true;
3054 }
3055
3056 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3057 const float* input_data, const RuntimeShape& output_shape,
3058 float* output_data) {
3059 ruy::profiler::ScopeLabel label("MaxPool");
3060 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3061 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3062 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3063 const int input_height = input_shape.Dims(1);
3064 const int input_width = input_shape.Dims(2);
3065 const int output_height = output_shape.Dims(1);
3066 const int output_width = output_shape.Dims(2);
3067 const int stride_height = params.stride_height;
3068 const int stride_width = params.stride_width;
3069
3070 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3071 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3072 // Prefill the output to minimum representable float value
3073 out_mat.setConstant(std::numeric_limits<float>::lowest());
3074 for (int b = 0; b < batches; ++b) {
3075 for (int h = 0; h < input_height; ++h) {
3076 for (int w = 0; w < input_width; ++w) {
3077 // (h_start, h_end) * (w_start, w_end) is the range that the input
3078 // vector projects to.
3079 int hpad = h + params.padding_values.height;
3080 int wpad = w + params.padding_values.width;
3081 int h_start = (hpad < params.filter_height)
3082 ? 0
3083 : (hpad - params.filter_height) / stride_height + 1;
3084 int h_end = std::min(hpad / stride_height + 1, output_height);
3085 int w_start = (wpad < params.filter_width)
3086 ? 0
3087 : (wpad - params.filter_width) / stride_width + 1;
3088 int w_end = std::min(wpad / stride_width + 1, output_width);
3089 // compute elementwise max
3090 for (int ph = h_start; ph < h_end; ++ph) {
3091 for (int pw = w_start; pw < w_end; ++pw) {
3092 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3093 out_mat.col(out_offset) =
3094 out_mat.col(out_offset)
3095 .cwiseMax(in_mat.col(
3096 NodeOffset(b, h, w, input_height, input_width)));
3097 }
3098 }
3099 }
3100 }
3101 }
3102 const int flat_size = output_shape.FlatSize();
3103 for (int i = 0; i < flat_size; ++i) {
3104 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3105 params.float_activation_min,
3106 params.float_activation_max);
3107 }
3108 }
3109
3110 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3111 const uint8* input_data, const RuntimeShape& output_shape,
3112 uint8* output_data) {
3113 ruy::profiler::ScopeLabel label("MaxPool/8bit");
3114
3115 // Here, and in other pooling ops, in order to maintain locality of reference,
3116 // to minimize some recalculations, and to load into NEON vector registers, we
3117 // use an inner loop down the depth. Since depths can be large and hence we
3118 // would need arbitrarily large temporary storage, we divide the work up into
3119 // depth tranches just within the batch loop.
3120 static constexpr int kPoolingAccTrancheSize = 256;
3121
3122 TFLITE_DCHECK_LE(params.quantized_activation_min,
3123 params.quantized_activation_max);
3124 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3125 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3126 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3127 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3128 const int input_height = input_shape.Dims(1);
3129 const int input_width = input_shape.Dims(2);
3130 const int output_height = output_shape.Dims(1);
3131 const int output_width = output_shape.Dims(2);
3132 const int stride_height = params.stride_height;
3133 const int stride_width = params.stride_width;
3134
3135 uint8 acc[kPoolingAccTrancheSize];
3136 for (int batch = 0; batch < batches; ++batch) {
3137 // We proceed through the depth in tranches (see comment above). The
3138 // depth_base is the depth at the beginning of the tranche. The
3139 // tranche_depth is the depth dimension of the tranche.
3140 for (int depth_base = 0; depth_base < depth;
3141 depth_base += kPoolingAccTrancheSize) {
3142 const int tranche_depth =
3143 std::min(depth - depth_base, kPoolingAccTrancheSize);
3144 for (int out_y = 0; out_y < output_height; ++out_y) {
3145 for (int out_x = 0; out_x < output_width; ++out_x) {
3146 const int in_x_origin =
3147 (out_x * stride_width) - params.padding_values.width;
3148 const int in_y_origin =
3149 (out_y * stride_height) - params.padding_values.height;
3150 const int filter_x_start = std::max(0, -in_x_origin);
3151 const int filter_x_end =
3152 std::min(params.filter_width, input_width - in_x_origin);
3153 const int filter_y_start = std::max(0, -in_y_origin);
3154 const int filter_y_end =
3155 std::min(params.filter_height, input_height - in_y_origin);
3156 memset(acc, 0, tranche_depth * sizeof(acc[0]));
3157 const uint8* input_ptr =
3158 input_data + depth_base +
3159 depth * (in_x_origin +
3160 input_width * (in_y_origin + input_height * batch));
3161 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3162 const uint8* input_row_ptr =
3163 input_ptr + depth * (fy * input_width + filter_x_start);
3164 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3165 const uint8* input_channel_ptr = input_row_ptr;
3166 int channel = 0;
3167 #ifdef USE_NEON
3168 for (; channel <= tranche_depth - 16; channel += 16) {
3169 uint8x16_t acc_reg = vld1q_u8(acc + channel);
3170 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3171 input_channel_ptr += 16;
3172 acc_reg = vmaxq_u8(acc_reg, input_reg);
3173 vst1q_u8(acc + channel, acc_reg);
3174 }
3175
3176 for (; channel <= tranche_depth - 8; channel += 8) {
3177 uint8x8_t acc_reg = vld1_u8(acc + channel);
3178 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
3179 input_channel_ptr += 8;
3180 acc_reg = vmax_u8(acc_reg, input_reg);
3181 vst1_u8(acc + channel, acc_reg);
3182 }
3183 #endif
3184 for (; channel < tranche_depth; ++channel) {
3185 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
3186 }
3187 input_row_ptr += depth;
3188 }
3189 }
3190 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3191 out_x, depth_base);
3192 int channel = 0;
3193 #ifdef USE_NEON
3194 for (; channel <= tranche_depth - 16; channel += 16) {
3195 uint8x16_t a = vld1q_u8(acc + channel);
3196 a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
3197 a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
3198 vst1q_u8(output_ptr + channel, a);
3199 }
3200 for (; channel <= tranche_depth - 8; channel += 8) {
3201 uint8x8_t a = vld1_u8(acc + channel);
3202 a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
3203 a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
3204 vst1_u8(output_ptr + channel, a);
3205 }
3206 #endif
3207 for (; channel < tranche_depth; ++channel) {
3208 uint8 a = acc[channel];
3209 a = std::max<uint8>(a, params.quantized_activation_min);
3210 a = std::min<uint8>(a, params.quantized_activation_max);
3211 output_ptr[channel] = static_cast<uint8>(a);
3212 }
3213 }
3214 }
3215 }
3216 }
3217 }
3218
3219 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
3220 const float* input_data, const RuntimeShape& output_shape,
3221 float* output_data) {
3222 ruy::profiler::ScopeLabel label("L2Pool");
3223 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3224 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3225 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3226 const int input_height = input_shape.Dims(1);
3227 const int input_width = input_shape.Dims(2);
3228 const int output_height = output_shape.Dims(1);
3229 const int output_width = output_shape.Dims(2);
3230 const int stride_height = params.stride_height;
3231 const int stride_width = params.stride_width;
3232 // Actually carry out L2 Pool. Code is written in forward mode: we go through
3233 // the input values once and accumulate into all the pooled regions they map to.
3234 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3235 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3236 Eigen::VectorXf in_square(in_mat.rows());
3237 Eigen::VectorXf out_count(out_mat.cols());
3238 out_count.setZero();
3239 // Prefill the output to 0.
3240 out_mat.setZero();
3241 for (int b = 0; b < batches; ++b) {
3242 for (int h = 0; h < input_height; ++h) {
3243 for (int w = 0; w < input_width; ++w) {
3244 // (h_start, h_end) * (w_start, w_end) is the range that the input
3245 // vector projects to.
3246 const int hpad = h + params.padding_values.height;
3247 const int wpad = w + params.padding_values.width;
3248 const int h_start =
3249 (hpad < params.filter_height)
3250 ? 0
3251 : (hpad - params.filter_height) / stride_height + 1;
3252 const int h_end = std::min(hpad / stride_height + 1, output_height);
3253 const int w_start =
3254 (wpad < params.filter_width)
3255 ? 0
3256 : (wpad - params.filter_width) / stride_width + 1;
3257 const int w_end = std::min(wpad / stride_width + 1, output_width);
3258 // pre-compute square
3259 const int in_offset = w + input_width * (h + input_height * b);
3260 in_square =
3261 in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
3262 // compute elementwise sum of squares
3263 for (int ph = h_start; ph < h_end; ++ph) {
3264 for (int pw = w_start; pw < w_end; ++pw) {
3265 const int out_offset = pw + output_width * (ph + output_height * b);
3266 out_mat.col(out_offset) += in_square;
3267 out_count(out_offset)++;
3268 }
3269 }
3270 }
3271 }
3272 }
3273
3274 out_count = out_count.array().inverse();
3275 out_mat =
3276 (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3277
3278 const int flat_size = output_shape.FlatSize();
3279 for (int i = 0; i < flat_size; ++i) {
3280 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3281 params.float_activation_min,
3282 params.float_activation_max);
3283 }
3284 }
3285
3286 inline void LocalResponseNormalization(
3287 const tflite::LocalResponseNormalizationParams& op_params,
3288 const RuntimeShape& input_shape, const float* input_data,
3289 const RuntimeShape& output_shape, float* output_data) {
3290 ruy::profiler::ScopeLabel label("LocalResponseNormalization");
3291 MatchingFlatSize(input_shape, output_shape);
3292
3293 const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3294 auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3295
3296 // Carry out local response normalization, vector by vector.
3297 // Since the data are stored column-major, a row-wise operation would
3298 // probably not be memory-efficient anyway, so we do an explicit for loop
3299 // over the columns.
3300 const int double_range = op_params.range * 2;
3301 Eigen::VectorXf padded_square(data_in.rows() + double_range);
3302 padded_square.setZero();
3303 const float bias = op_params.bias;
3304 for (int r = 0; r < data_in.cols(); ++r) {
3305 // Do local response normalization for data_in(:, r)
3306 // first, compute the squares and store them in a buffer for repeated use
3307 padded_square.block(op_params.range, 0, data_in.rows(), 1) =
3308 data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
3309 // Then, compute the scale and write it to data_out
3310 float accumulated_scale = 0;
3311 for (int i = 0; i < double_range; ++i) {
3312 accumulated_scale += padded_square(i);
3313 }
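// At this point accumulated_scale holds the sum of the first 2 * range
// padded entries; the loop below slides a window of width 2 * range + 1
// down the column, adding the entry that enters the window and subtracting
// the one that leaves it.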
3314 for (int i = 0; i < data_in.rows(); ++i) {
3315 accumulated_scale += padded_square(i + double_range);
3316 data_out(i, r) = bias + accumulated_scale;
3317 accumulated_scale -= padded_square(i);
3318 }
3319 }
3320
3321 // In a few cases, the pow computation could benefit from speedups.
3322 if (op_params.beta == 1) {
3323 data_out.array() = data_in.array() * data_out.array().inverse();
3324 } else if (op_params.beta == 0.5f) {
3325 data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
3326 } else {
3327 data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
3328 }
3329 }
3330
3331 inline void SoftmaxImpl(const SoftmaxParams& params,
3332 const RuntimeShape& input_shape,
3333 const float* input_data,
3334 const RuntimeShape& output_shape, float* output_data,
3335 int start_batch, int end_batch) {
3336 ruy::profiler::ScopeLabel label("Softmax/Impl");
3337 MatchingFlatSize(input_shape, output_shape);
3338
3339 const int logit_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
3340 const MatrixMap<const float> in_mat(input_data + logit_size * start_batch,
3341 logit_size, end_batch - start_batch);
3342 MatrixMap<float> out_mat(output_data + logit_size * start_batch, logit_size,
3343 end_batch - start_batch);
3344 // Compute the exponential first, removing the max coefficient for numerical
3345 // stability.
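// Subtracting the per-column max does not change the result, since
// exp(x_i - m) / sum_j(exp(x_j - m)) == exp(x_i) / sum_j(exp(x_j)) for any m.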
3346 out_mat =
3347 (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
3348 // We are separating out the exp function so that exp can be vectorized.
3349 out_mat = out_mat.array().exp();
3350 // Normalize to get the activations.
3351 Eigen::Array<float, 1, Eigen::Dynamic> scale =
3352 out_mat.array().colwise().sum().inverse();
3353 out_mat.array().rowwise() *= scale;
3354 }
3355
3356 struct SoftmaxWorkerTask : cpu_backend_threadpool::Task {
3357 SoftmaxWorkerTask(const SoftmaxParams& params,
3358 const RuntimeShape& input_shape, const float* input_data,
3359 const RuntimeShape& output_shape, float* output_data,
3360 int start_batch, int end_batch)
3361 : params(params),
3362 input_shape(input_shape),
3363 input_data(input_data),
3364 output_shape(output_shape),
3365 output_data(output_data),
3366 start_batch(start_batch),
3367 end_batch(end_batch) {}
3368 void Run() override {
3369 SoftmaxImpl(params, input_shape, input_data, output_shape, output_data,
3370 start_batch, end_batch);
3371 }
3372
3373 private:
3374 const tflite::SoftmaxParams& params;
3375 const RuntimeShape& input_shape;
3376 const float* input_data;
3377 const RuntimeShape& output_shape;
3378 float* output_data;
3379 int start_batch;
3380 int end_batch;
3381 };
3382
3383 inline void Softmax(const SoftmaxParams& params,
3384 const RuntimeShape& input_shape, const float* input_data,
3385 const RuntimeShape& output_shape, float* output_data,
3386 CpuBackendContext* cpu_backend_context = nullptr) {
3387 ruy::profiler::ScopeLabel label("Softmax");
3388
3389 // We picture the softmax input as a 2-D matrix whose last dim is the logit
3390 // dim; the remaining dims are collapsed into the batch dim of that matrix.
3391 const int batch_size =
3392 FlatSizeSkipDim(input_shape, input_shape.DimensionsCount() - 1);
3393 constexpr int kMinBatchPerThread = 8;
3394 int thread_count = batch_size / kMinBatchPerThread;
3395 thread_count = thread_count > 0 ? thread_count : 1;
3396 const int capped_thread_count =
3397 cpu_backend_context == nullptr
3398 ? 1
3399 : std::min(thread_count, cpu_backend_context->max_num_threads());
3400 if (capped_thread_count == 1) {
3401 SoftmaxImpl(params, input_shape, input_data, output_shape, output_data, 0,
3402 batch_size);
3403 } else {
3404 std::vector<SoftmaxWorkerTask> tasks;
3405 // TODO(b/131746020) don't create new heap allocations every time.
3406 // At least we make it a single heap allocation by using reserve().
3407 tasks.reserve(capped_thread_count);
3408 int batch_start = 0;
3409 for (int i = 0; i < capped_thread_count; ++i) {
3410 // Try to distribute the tasks as evenly as possible.
3411 int batch_end =
3412 batch_start + (batch_size - batch_start) / (capped_thread_count - i);
3413 tasks.emplace_back(params, input_shape, input_data, output_shape,
3414 output_data, batch_start, batch_end);
3415 batch_start = batch_end;
3416 }
3417 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
3418 cpu_backend_context);
3419 }
3420 }
3421
3422 template <typename T>
3423 inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point) {
3424 const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
3425 return prob_rnd + zero_point;
3426 }
3427
3428 #if !__aarch64__
3429 // On ARM64, rounding (std::round) is faster than add + truncation, so this
3429 // specialization is only used on other targets.
3430 template <>
3431 inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled,
3432 int32_t zero_point) {
3433 return static_cast<int32_t>(prob_rescaled + 0.5f);
3434 }
3435 #endif
3436
3437 inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
3438 float beta) {
3439 const float scale = -input_scale * beta;
3440 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3441 for (int32_t val = 0; val <= max_uint8; ++val) {
3442 data->table[max_uint8 - val] = expf(scale * val);
3443 }
3444 }
3445
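// For uint8 inputs, the function below sets
// table_offset = &params.table[255 - max_val], so that table_offset[input_q]
// equals exp(-input_scale * beta * (max_val - input_q)), i.e.
// exp(beta * (x - x_max)) for the dequantized logit x; every summand is
// therefore in (0, 1].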
3446 template <typename In, typename Out>
3447 inline void Softmax(const SoftmaxParams& params,
3448 const RuntimeShape& input_shape, const In* input_data,
3449 const RuntimeShape& output_shape, Out* output_data) {
3450 const int trailing_dim = input_shape.DimensionsCount() - 1;
3451 const int excluding_last_dim =
3452 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3453 const int last_dim =
3454 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3455
3456 const int32_t clamp_max = std::numeric_limits<Out>::max();
3457 const int32_t clamp_min = std::numeric_limits<Out>::min();
3458 for (int i = 0; i < excluding_last_dim; ++i) {
3459 int32_t max_val = std::numeric_limits<In>::min();
3460 // Find max quantized value.
3461 for (int j = 0; j < last_dim; ++j) {
3462 max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
3463 }
3464
3465 float sum_exp = 0.0f;
3466 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3467 const float* table_offset = &params.table[max_uint8 - max_val];
3468 // Calculate normalizer sum(exp(x)).
3469 for (int j = 0; j < last_dim; ++j) {
3470 sum_exp += table_offset[input_data[j]];
3471 }
3472
3473 const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3474 // Normalize and quantize probabilities.
3475 for (int j = 0; j < last_dim; ++j) {
3476 const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
3477 const int32_t prob_quantized =
3478 QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
3479 output_data[j] = static_cast<Out>(
3480 std::max(std::min(clamp_max, prob_quantized), clamp_min));
3481 }
3482 input_data += last_dim;
3483 output_data += last_dim;
3484 }
3485 }
3486
3487 // Here's the softmax LUT optimization strategy:
3488 // For softmax, we can do a mathematically equivalent transformation:
3489 //
3490 // softmax(x) = e^x / sum(e^x, 0...n) ===> equals to
3491 // softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
3492 //
3493 // For quantization, `x` in our case is (input_q - input_zp) * input_s
3494 // For uint8 case (int8 can be handled similarly), the range is [0, 255]
3495 //
3496 // so if we let
3497 // CONST = (255 - input_zp) * input_s
3498 // then we will have:
3499 // softmax(x) = e^((input_q - 255) * input_s) --------- (1)
3500 // /
3501 // sum(e^((input_q - 255) * input_s), 0...n) -------- (2)
3502 //
3503 // the good thing about (1) is that it's within the range (0, 1], so we can
3504 // approximate its result with a uint16:
3505 // (1) = uint16_out * 1 / 2^16.
3506 //
3507 // so (1) is lookup_uint8_table(input_q) * 1 / 2^16,
3508 // and (2) is essentially the following:
3509 // sum(lookup_uint8_table(input_q), 0...n) / 2^16.
3510 //
3511 // since (output_q - output_zp) * output_s = softmax(x)
3512 // output_q = lookup_uint8_table(input_q)
3513 // /
3514 // (sum(lookup_uint8_table(input_q), 0...n) * output_s)
3515 // +
3516 // output_zp
3517 //
3518 // We can actually further improve the performance by using uint8 instead of
3519 // uint16. But then we may lose some accuracy, so we need to pay attention
3520 // to that.
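// As a quick sanity check of the table construction below: for the largest
// input value (val == 255), input_to_exp is 0, exp(0) == 1, and the stored
// 16-bit value is 65535, split into uint8_table1 (high byte) and
// uint8_table2 (low byte).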
3521 inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
3522 float input_scale, float beta) {
3523 const float scale = input_scale * beta;
3524 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3525 const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
3526
3527 for (int32_t val = 0; val <= max_uint8; ++val) {
3528 float input_to_exp = scale * (val - max_uint8);
3529 int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
3530 temp = std::min(max_uint16, temp);
3531 uint8_t part1 = temp >> 8;
3532 uint8_t part2 = temp & 0xff;
3533 data->uint8_table1[val] = static_cast<uint8_t>(part1);
3534 data->uint8_table2[val] = static_cast<uint8_t>(part2);
3535 }
3536 }
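
// Illustrative note (not part of the original kernel): the two uint8 tables
// jointly encode one uint16 fixed-point value per input byte,
//   (uint8_table1[val] << 8) + uint8_table2[val]
//       ~= round(expf(scale * (val - 255)) * 65535),
// which is exactly how SoftmaxInt8LUT reassembles them below. A worked
// example with an assumed scale = input_scale * beta = 0.1 and val = 245:
//   input_to_exp = 0.1f * (245 - 255)               = -1.0f
//   temp         = (int)(expf(-1.0f) * 65535 + 0.5) = 24109
//   part1        = 24109 >> 8                       = 94   // high byte
//   part2        = 24109 & 0xff                     = 45   // low byte
//   (94 << 8) + 45                                  = 24109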
3537
3538 inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
3539 int32_t max_val = std::numeric_limits<uint8_t>::min();
3540 int j = 0;
3541 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3542 uint8x16_t max_val_dup = vdupq_n_u8(max_val);
3543 uint8x16_t offset_dup = vdupq_n_u8(offset);
3544 for (; j <= size - 16; j += 16) {
3545 uint8x16_t input_value = vld1q_u8(input_data + j);
3546 input_value = veorq_u8(input_value, offset_dup);
3547 max_val_dup = vmaxq_u8(input_value, max_val_dup);
3548 }
3549 max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
3550 #endif
3551
3552 for (; j < size; ++j) {
3553 max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
3554 }
3555 return max_val;
3556 }
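
// Illustrative note (not part of the original kernel): when the input type is
// int8, `offset` is 0x80 and the XOR above maps the two's-complement bit
// pattern to a uint8 with the same ordering (x ^ 0x80 == x + 128):
//   int8 -128 (0x80) -> 0x00 =   0
//   int8    0 (0x00) -> 0x80 = 128
//   int8  127 (0x7f) -> 0xff = 255
// so an unsigned byte-wise maximum (vmaxq_u8 on NEON) over the XORed bytes
// yields the maximum of the signed values, already shifted into [0, 255].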
3557
3558 #ifdef USE_NEON
3559 // Value_to_store layout:
3560 // [high_high, high_low, low_high, low_low].
3561 inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
3562 const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
3563 vqmovn_s32(value_to_store.val[0]));
3564 const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
3565 vqmovn_s32(value_to_store.val[2]));
3566 const int8x16_t result =
3567 vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
3568 vst1q_s8(output, result);
3569 }
3570
3571 // Value_to_store layout:
3572 // [high_high, high_low, low_high, low_low].
3573 inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
3574 const uint16x8_t result_1 =
3575 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
3576 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
3577 const uint16x8_t result_2 =
3578 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
3579 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
3580 const uint8x16_t result =
3581 vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
3582 vst1q_u8(output, result);
3583 }
3584
3585 #endif
3586
3587 template <typename In, typename Out>
3588 inline void SoftmaxInt8LUT(const SoftmaxParams& params,
3589 const RuntimeShape& input_shape,
3590 const In* input_data,
3591 const RuntimeShape& output_shape, Out* output_data) {
3592 ruy::profiler::ScopeLabel label("SoftmaxInt8LUT");
3593
3594 const int trailing_dim = input_shape.DimensionsCount() - 1;
3595 const int excluding_last_dim =
3596 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3597 const int last_dim =
3598 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3599
3600 const int32_t clamp_max = std::numeric_limits<Out>::max();
3601 const int32_t clamp_min = std::numeric_limits<Out>::min();
3602
3603 // Offset is used to interpret the input data "correctly".
3604 // If the input is uint8, the data is unchanged.
3605 // If the input is int8, it is reinterpreted as uint8, so the offset is
3606 // applied to recover the original ordering.
3607 // e.g., int8 127 becomes 255 in uint8 once the offset is applied.
3608 uint8_t offset = 0;
3609 if (std::is_same<In, int8>::value) {
3610 offset = 0x80;
3611 }
3612
3613 const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
3614
3615 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3616 // This code uses ARM64-only instructions.
3617 // TODO(b/143709993): Port to ARMv7
3618
3619 // Load the tables into registers. (4*4 128-bit registers)
3620 uint8x16x4_t table1[4];
3621 table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
3622 table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
3623 table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
3624 table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
3625
3626 uint8x16x4_t table2[4];
3627 table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
3628 table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
3629 table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
3630 table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
3631 #endif
3632
3633 for (int i = 0; i < excluding_last_dim; ++i) {
3634 // Find max quantized value.
3635 int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
3636
3637 int32 sum_exp = 0;
3638 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3639 const uint8_t table_offset = max_uint8 - max_val;
3640
3641 // Calculate normalizer sum(exp(x)).
3642 int sum_j = 0;
3643 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3644 uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
3645 uint8x16_t offset_dup = vdupq_n_u8(offset);
3646 uint32x4_t sum_4 = vdupq_n_u32(0);
3647 const int multiplier_shift = 8;
3648 for (; sum_j <= last_dim - 16; sum_j += 16) {
3649 uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
3650 input_value = veorq_u8(input_value, offset_dup);
3651 input_value = vaddq_u8(input_value, table_offset_dup);
3652
3653 const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3654 const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3655
3656 uint16x8_t exp_value1 =
3657 vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3658 uint16x8_t exp_value2 =
3659 vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3660
3661 exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3662 exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3663
3664 sum_4 = vpadalq_u16(sum_4, exp_value1);
3665 sum_4 = vpadalq_u16(sum_4, exp_value2);
3666 }
3667 int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
3668 vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
3669 sum_exp += temp;
3670
3671 #endif
3672 for (; sum_j < last_dim; ++sum_j) {
3673 const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
3674
3675 uint8_t part1 = params.uint8_table1[index];
3676 uint8_t part2 = params.uint8_table2[index];
3677 sum_exp += ((part1 << 8) + part2);
3678 }
3679
3680 const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3681
3682 int32 multiplier, shift;
3683 QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
3684
3685 // Normalize and quantize probabilities.
3686 int j = 0;
3687 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3688 const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
3689 const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
3690 const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
3691
3692 for (; j <= last_dim - 16; j += 16) {
3693 uint8x16_t input_value = vld1q_u8(input_data_uint + j);
3694 input_value = veorq_u8(input_value, offset_dup);
3695 input_value = vaddq_u8(input_value, table_offset_dup);
3696
3697 const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3698 const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3699
3700 uint16x8_t exp_value1 =
3701 vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3702 uint16x8_t exp_value2 =
3703 vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3704
3705 exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3706 exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3707
3708 int32x4x4_t output_value;
3709 output_value.val[0] =
3710 vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
3711 output_value.val[1] =
3712 vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
3713 output_value.val[2] =
3714 vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
3715 output_value.val[3] =
3716 vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
3717
3718 int32x4x4_t temp_val =
3719 MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
3720
3721 temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
3722 temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
3723 temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
3724 temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
3725
3726 temp_val.val[0] =
3727 vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
3728 temp_val.val[1] =
3729 vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
3730 temp_val.val[2] =
3731 vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
3732 temp_val.val[3] =
3733 vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
3734
3735 StoreValue(temp_val, output_data + j);
3736 }
3737 #endif
3738 for (; j < last_dim; ++j) {
3739 const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
3740 const uint8_t part1 = params.uint8_table1[index];
3741 const uint8_t part2 = params.uint8_table2[index];
3742 const int32_t exp_value = (part1 << 8) + part2;
3743 const int32_t output_value =
3744 MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
3745
3746 output_data[j] = static_cast<Out>(std::max(
3747 std::min(clamp_max, output_value + params.zero_point), clamp_min));
3748 }
3749 input_data_uint += last_dim;
3750 output_data += last_dim;
3751 }
3752 }
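
// Minimal usage sketch for SoftmaxInt8LUT above (illustrative only; buffer
// names and quantization values are assumptions, not taken from any caller):
//
//   static uint8_t table1[256], table2[256];
//   SoftmaxParams params;
//   params.uint8_table1 = table1;
//   params.uint8_table2 = table2;
//   params.zero_point = -128;        // example int8 output zero point
//   params.scale = 1.0f / 256.0f;    // example int8 output scale
//   PopulateSoftmaxUInt8LookupTable(&params, /*input_scale=*/0.1f,
//                                   /*beta=*/1.0f);
//   RuntimeShape shape({2, 16});     // softmax over rows of 16 values
//   SoftmaxInt8LUT(params, shape, input_q, shape, output_q);  // int8 in/out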
3753
3754 inline void LogSoftmax(const SoftmaxParams& params,
3755 const RuntimeShape& input_shape, const float* input_data,
3756 const RuntimeShape& output_shape, float* output_data) {
3757 ruy::profiler::ScopeLabel label("LogSoftmax");
3758 const int trailing_dim = input_shape.DimensionsCount() - 1;
3759 const int outer_size =
3760 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3761 const int depth =
3762 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3763
3764 for (int i = 0; i < outer_size; ++i) {
3765 VectorMap<const float> block_input(input_data + i * depth, depth, 1);
3766 VectorMap<float> block_output(output_data + i * depth, depth, 1);
3767 // Find max element value which we'll use to ensure numerical stability
3768 // taking advantage of the following equality:
3769 // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
3770 const float max = block_input.maxCoeff();
3771 const float log_sum = std::log((block_input.array() - max).exp().sum());
3772 block_output = block_input.array() - max - log_sum;
3773 }
3774 }
3775
3776 // Backwards compatibility. Less optimized than below version.
3777 inline void LogSoftmax(const SoftmaxParams& params,
3778 const RuntimeShape& input_shape, const uint8* input_data,
3779 const RuntimeShape& output_shape, uint8* output_data) {
3780 reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
3781 output_data);
3782 }
3783
3784 // Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max), ...)),
3785 // as done in tf.nn.log_softmax, to prevent underflow and overflow.
3786 // This is in contrast to just computing log(softmax(x)).
3787 //
3788 // To handle quantization, first dequantize the inputs (by looking up
3789 // e^(input_scale * val); the zero point is ignored since it cancels
3790 // out during the subtraction inside the ln) and rescale to int8 at the end.
3791 //
3792 // Notably this makes use of float and is intended as the optimized
3793 // form for quantized execution on CPU. For a fully integer version,
3794 // see the reference op.
3795 //
3796 // TODO(tflite): notes for optimization:
3797 // 1) See if e^ is also bottleneck in the reference fully-integer
3798 // version and apply lookup there and compare.
3799 template <typename T>
3800 inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
3801 const RuntimeShape& input_shape, const T* input_data,
3802 const RuntimeShape& output_shape, T* output_data) {
3803 ruy::profiler::ScopeLabel label("LogSoftmax");
3804 const int trailing_dim = input_shape.DimensionsCount() - 1;
3805 const int excluding_last_dim =
3806 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3807 const int last_dim =
3808 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3809
3810 const int32_t clamp_max = std::numeric_limits<T>::max();
3811 const int32_t clamp_min = std::numeric_limits<T>::min();
3812
3813 for (int i = 0; i < excluding_last_dim; ++i) {
3814 T max_val = std::numeric_limits<T>::min();
3815 // Find max quantized value.
3816 for (int j = 0; j < last_dim; ++j) {
3817 max_val = std::max(max_val, input_data[j]);
3818 }
3819
3820 float sum_exp = 0.0f;
3821 const int32_t max_uint8 = std::numeric_limits<uint8>::max();
3822 // Offset into table to compute exp(scale*(x - xmax)) instead of
3823 // exp(scale*(x)) to prevent overflow.
3824 const float* table_offset = &params.table[max_uint8 - max_val];
3825 // Calculate sum(exp(scale*(x - x_max))).
3826 for (int j = 0; j < last_dim; ++j) {
3827 sum_exp += table_offset[input_data[j]];
3828 }
3829 const float log_sum_exp = std::log(sum_exp);
3830
3831 // params.scale is the output scale.
3832 const float scale = input_scale / params.scale;
3833 const float precomputed =
3834 (input_scale * max_val + log_sum_exp) / params.scale;
3835 for (int j = 0; j < last_dim; ++j) {
3836 // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) /
3837 // output_scale.
3838 const float log_prob = scale * input_data[j] - precomputed;
3839
3840 // TODO(tflite): look into better solution.
3841 // Use std::rint over std::round (which is used in
3842 // FakeQuant) since it's multiple times faster on tested arm32.
3843 const int32_t prob_quantized = std::rint(log_prob) + params.zero_point;
3844 output_data[j] = static_cast<T>(
3845 std::max(std::min(clamp_max, prob_quantized), clamp_min));
3846 }
3847 input_data += last_dim;
3848 output_data += last_dim;
3849 }
3850 }
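
// Minimal usage sketch for the quantized LogSoftmax above (illustrative only;
// it reuses the same exp() table as Softmax, filled with beta = 1). The
// quantization values follow the usual TFLite uint8 log-softmax convention
// (output scale 16/256, zero point 255), but are shown purely as an example:
//
//   static float exp_table[256];
//   SoftmaxParams params;
//   params.table = exp_table;
//   params.zero_point = 255;
//   params.scale = 16.0f / 256.0f;
//   PopulateSoftmaxLookupTable(&params, /*input_scale=*/0.1f, /*beta=*/1.0f);
//   RuntimeShape shape({1, 10});
//   LogSoftmax(params, /*input_scale=*/0.1f, shape, input_q, shape, output_q);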
3851
3852 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
3853 const RuntimeShape& output_shape, float* output_data) {
3854 ruy::profiler::ScopeLabel label("Logistic");
3855 auto input_map = MapAsVector(input_data, input_shape);
3856 auto output_map = MapAsVector(output_data, output_shape);
3857 output_map.array() =
3858 input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
3859 }
3860
3861 // Convenience version that allows, for example, generated-code calls to be
3862 // uniform between data types.
3863 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
3864 const float* input_data, const RuntimeShape& output_shape,
3865 float* output_data) {
3866 // Drop params: not needed.
3867 Logistic(input_shape, input_data, output_shape, output_data);
3868 }
3869
3870 inline void Logistic(const LogisticParams& params,
3871 const RuntimeShape& input_shape, const int16* input_data,
3872 const RuntimeShape& output_shape, int16* output_data) {
3873 ruy::profiler::ScopeLabel label("Logistic/Int16");
3874 const int flat_size = MatchingFlatSize(input_shape, output_shape);
3875
3879 int c = 0;
3880 const int16* input_data_ptr = input_data;
3881 int16* output_data_ptr = output_data;
3882 #ifdef GEMMLOWP_NEON
3883 {
3884 // F0 uses 0 integer bits, range [-1, 1].
3885 // This is the return type of math functions such as tanh, logistic,
3886 // whose range is in [-1, 1].
3887 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3888 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
3889 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3890
3891 for (; c <= flat_size - 16; c += 16) {
3892 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
3893 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
3894 F0 output0 = gemmlowp::logistic(input0);
3895 F0 output1 = gemmlowp::logistic(input1);
3896 vst1q_s16(output_data_ptr, output0.raw());
3897 vst1q_s16(output_data_ptr + 8, output1.raw());
3898
3899 input_data_ptr += 16;
3900 output_data_ptr += 16;
3901 }
3902 for (; c <= flat_size - 8; c += 8) {
3903 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
3904 F0 output = gemmlowp::logistic(input);
3905 vst1q_s16(output_data_ptr, output.raw());
3906
3907 input_data_ptr += 8;
3908 output_data_ptr += 8;
3909 }
3910 }
3911 #endif
3912 #ifdef GEMMLOWP_SSE4
3913 {
3914 // F0 uses 0 integer bits, range [-1, 1].
3915 // This is the return type of math functions such as tanh, logistic,
3916 // whose range is in [-1, 1].
3917 using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
3918 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
3919 using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
3920
3921 for (; c <= flat_size - 16; c += 16) {
3922 F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
3923 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
3924 F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
3925 reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
3926 F0 output0 = gemmlowp::logistic(input0);
3927 F0 output1 = gemmlowp::logistic(input1);
3928 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
3929 output0.raw().v);
3930 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
3931 output1.raw().v);
3932 input_data_ptr += 16;
3933 output_data_ptr += 16;
3934 }
3935 for (; c <= flat_size - 8; c += 8) {
3936 F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
3937 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
3938 F0 output = gemmlowp::logistic(input);
3939 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
3940 output.raw().v);
3941 input_data_ptr += 8;
3942 output_data_ptr += 8;
3943 }
3944 }
3945 #endif
3946
3947 {
3948 // F0 uses 0 integer bits, range [-1, 1].
3949 // This is the return type of math functions such as tanh, logistic,
3950 // whose range is in [-1, 1].
3951 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3952 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
3953 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3954
3955 for (; c < flat_size; ++c) {
3956 F3 input = F3::FromRaw(*input_data_ptr);
3957 F0 output = gemmlowp::logistic(input);
3958 *output_data_ptr = output.raw();
3959
3960 ++input_data_ptr;
3961 ++output_data_ptr;
3962 }
3963 }
3964 }
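
// Illustrative note (not part of the original kernel): in the gemmlowp
// fixed-point types used above, FixedPoint<int16, N> stores the real value
// raw / 2^(15 - N). So for the F3 input format (3 integer bits) and the F0
// output format (0 integer bits):
//   F3: raw  4096 represents  4096 / 2^12  = 1.0  (input range ~[-8, 8))
//   F0: raw 32767 represents 32767 / 2^15 ~= 1.0  (output range ~[-1, 1))
// e.g. an input raw value of 0 (real 0.0) yields logistic(0) = 0.5, i.e. an
// output raw value of about 16384.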
3965
3966 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
3967 const RuntimeShape& output_shape, float* output_data) {
3968 ruy::profiler::ScopeLabel label("Tanh");
3969 auto input_map = MapAsVector(input_data, input_shape);
3970 auto output_map = MapAsVector(output_data, output_shape);
3971 output_map.array() = input_map.array().tanh();
3972 }
3973
3974 // Convenience version that allows, for example, generated-code calls to be
3975 // uniform between data types.
3976 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
3977 const float* input_data, const RuntimeShape& output_shape,
3978 float* output_data) {
3979 // Drop params: not needed.
3980 Tanh(input_shape, input_data, output_shape, output_data);
3981 }
3982
3983 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
3984 const int16* input_data, const RuntimeShape& output_shape,
3985 int16* output_data) {
3986 ruy::profiler::ScopeLabel label("Tanh/Int16");
3987 const int input_left_shift = params.input_left_shift;
3988 // Support for shifts is limited until we have a parameterized version of
3989 // SaturatingRoundingMultiplyByPOT().
3990 TFLITE_DCHECK_GE(input_left_shift, 0);
3991 TFLITE_DCHECK_LE(input_left_shift, 1);
3992
3993 const int flat_size = MatchingFlatSize(input_shape, output_shape);
3994
3995 int c = 0;
3996 const int16* input_data_ptr = input_data;
3997 int16* output_data_ptr = output_data;
3998 #ifdef GEMMLOWP_NEON
3999 {
4000 // F0 uses 0 integer bits, range [-1, 1].
4001 // This is the return type of math functions such as tanh, logistic,
4002 // whose range is in [-1, 1].
4003 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4004 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4005 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4006
4007 if (input_left_shift == 0) {
4008 for (; c <= flat_size - 16; c += 16) {
4009 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4010 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4011 F0 output0 = gemmlowp::tanh(input0);
4012 F0 output1 = gemmlowp::tanh(input1);
4013 vst1q_s16(output_data_ptr, output0.raw());
4014 vst1q_s16(output_data_ptr + 8, output1.raw());
4015
4016 input_data_ptr += 16;
4017 output_data_ptr += 16;
4018 }
4019 for (; c <= flat_size - 8; c += 8) {
4020 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4021 F0 output = gemmlowp::tanh(input);
4022 vst1q_s16(output_data_ptr, output.raw());
4023
4024 input_data_ptr += 8;
4025 output_data_ptr += 8;
4026 }
4027 } else {
4028 for (; c <= flat_size - 16; c += 16) {
4029 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4030 vld1q_s16(input_data_ptr)));
4031 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4032 vld1q_s16(input_data_ptr + 8)));
4033 F0 output0 = gemmlowp::tanh(input0);
4034 F0 output1 = gemmlowp::tanh(input1);
4035 vst1q_s16(output_data_ptr, output0.raw());
4036 vst1q_s16(output_data_ptr + 8, output1.raw());
4037
4038 input_data_ptr += 16;
4039 output_data_ptr += 16;
4040 }
4041 for (; c <= flat_size - 8; c += 8) {
4042 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4043 vld1q_s16(input_data_ptr)));
4044 F0 output = gemmlowp::tanh(input);
4045 vst1q_s16(output_data_ptr, output.raw());
4046
4047 input_data_ptr += 8;
4048 output_data_ptr += 8;
4049 }
4050 }
4051 }
4052 #endif
4053 #ifdef GEMMLOWP_SSE4
4054 {
4055 // F0 uses 0 integer bits, range [-1, 1].
4056 // This is the return type of math functions such as tanh, logistic,
4057 // whose range is in [-1, 1].
4058 using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4059 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4060 using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4061
4062 if (input_left_shift == 0) {
4063 for (; c <= flat_size - 16; c += 16) {
4064 F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4065 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4066 F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4067 reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4068 F0 output0 = gemmlowp::tanh(input0);
4069 F0 output1 = gemmlowp::tanh(input1);
4070 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4071 output0.raw().v);
4072 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4073 output1.raw().v);
4074
4075 input_data_ptr += 16;
4076 output_data_ptr += 16;
4077 }
4078 for (; c <= flat_size - 8; c += 8) {
4079 F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4080 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4081 F0 output = gemmlowp::tanh(input);
4082 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4083 output.raw().v);
4084 input_data_ptr += 8;
4085 output_data_ptr += 8;
4086 }
4087 } else {
4088 for (; c <= flat_size - 16; c += 16) {
4089 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4090 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4091 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4092 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4093 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4094 reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
4095 F0 output0 = gemmlowp::tanh(input0);
4096 F0 output1 = gemmlowp::tanh(input1);
4097 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4098 output0.raw().v);
4099 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4100 output1.raw().v);
4101
4102 input_data_ptr += 16;
4103 output_data_ptr += 16;
4104 }
4105 for (; c <= flat_size - 8; c += 8) {
4106 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4107 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4108 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4109 F0 output = gemmlowp::tanh(input);
4110 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4111 output.raw().v);
4112 input_data_ptr += 8;
4113 output_data_ptr += 8;
4114 }
4115 }
4116 }
4117 #endif
4118
4119 {
4120 // F0 uses 0 integer bits, range [-1, 1].
4121 // This is the return type of math functions such as tanh, logistic,
4122 // whose range is in [-1, 1].
4123 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4124 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4125 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4126
4127 if (input_left_shift == 0) {
4128 for (; c < flat_size; ++c) {
4129 F3 input = F3::FromRaw(*input_data_ptr);
4130 F0 output = gemmlowp::tanh(input);
4131 *output_data_ptr = output.raw();
4132
4133 ++input_data_ptr;
4134 ++output_data_ptr;
4135 }
4136 } else {
4137 for (; c < flat_size; ++c) {
4138 F3 input = F3::FromRaw(
4139 gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
4140 F0 output = gemmlowp::tanh(input);
4141 *output_data_ptr = output.raw();
4142
4143 ++input_data_ptr;
4144 ++output_data_ptr;
4145 }
4146 }
4147 }
4148 }
4149
4150 template <typename SrcT, typename DstT>
4151 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
4152 const RuntimeShape& output_shape, DstT* output_data) {
4153 ruy::profiler::ScopeLabel label("Cast");
4154 auto input_map = MapAsVector(input_data, input_shape);
4155 auto output_map = MapAsVector(output_data, output_shape);
4156 output_map.array() = input_map.array().template cast<DstT>();
4157 }
4158
4159 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4160 const RuntimeShape& output_shape, float* output_data) {
4161 ruy::profiler::ScopeLabel label("Floor");
4162 auto input_map = MapAsVector(input_data, input_shape);
4163 auto output_map = MapAsVector(output_data, output_shape);
4164 output_map.array() = Eigen::floor(input_map.array());
4165 }
4166
4167 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
4168 const RuntimeShape& output_shape, float* output_data) {
4169 ruy::profiler::ScopeLabel label("Ceil");
4170 auto input_map = MapAsVector(input_data, input_shape);
4171 auto output_map = MapAsVector(output_data, output_shape);
4172 output_map.array() = Eigen::ceil(input_map.array());
4173 }
4174
4175 // Helper methods for BatchToSpaceND.
4176 // `spatial_index_dim` specifies post-crop offset index in this spatial
4177 // dimension, i.e. spatial offset introduced by flattening batch to spatial
4178 // dimension minus the crop size at beginning. `block_shape_dim` is the block
4179 // size in current dimension. `input_dim` and `output_dim` are input and output
4180 // size of BatchToSpaceND operation in current dimension.
4181 // Output start index is inclusive and end index is exclusive.
4182 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
4183 int input_dim, int output_dim, int* start_index,
4184 int* end_index) {
4185 // (*start_index) * block_shape_dim is effectively rounded up to the next
4186 // multiple of block_shape_dim by the integer division.
4187 *start_index =
4188 std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4189 // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
4190 // end_index is exclusive).
4191 *end_index = std::min(
4192 input_dim,
4193 (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4194 }
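
// Worked example for GetIndexRange (illustrative values): with
// spatial_index_dim = -2, block_shape_dim = 3, input_dim = 4, output_dim = 10:
//   *start_index = max(0, (2 + 3 - 1) / 3)         = 1
//   *end_index   = min(4, (10 - (-2) + 3 - 1) / 3) = min(4, 4) = 4
// so in = 1..3 are kept, mapping to out = in * 3 - 2 = 1, 4, 7 (all inside
// [0, 10)), while in = 0 would map to out = -2, which the crop removes.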
4195
4196 template <typename T>
4197 inline void BatchToSpaceND(
4198 const RuntimeShape& unextended_input1_shape, const T* input1_data,
4199 const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
4200 const RuntimeShape& unextended_input3_shape, const int32* crops_data,
4201 const RuntimeShape& unextended_output_shape, T* output_data) {
4202 ruy::profiler::ScopeLabel label("BatchToSpaceND");
4203
4204 TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
4205 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
4206 TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
4207 unextended_output_shape.DimensionsCount());
4208
4209 // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
4210 auto extend_shape = [](const RuntimeShape& shape) {
4211 if (shape.DimensionsCount() == 4) {
4212 return shape;
4213 }
4214 RuntimeShape new_shape(4, 1);
4215 new_shape.SetDim(0, shape.Dims(0));
4216 new_shape.SetDim(1, shape.Dims(1));
4217 new_shape.SetDim(3, shape.Dims(2));
4218 return new_shape;
4219 };
4220 const RuntimeShape input1_shape = extend_shape(unextended_input1_shape);
4221 const RuntimeShape output_shape = extend_shape(unextended_output_shape);
4222
4223 const int output_width = output_shape.Dims(2);
4224 const int output_height = output_shape.Dims(1);
4225 const int output_batch_size = output_shape.Dims(0);
4226
4227 const int depth = input1_shape.Dims(3);
4228 const int input_width = input1_shape.Dims(2);
4229 const int input_height = input1_shape.Dims(1);
4230 const int input_batch_size = input1_shape.Dims(0);
4231
4232 const int block_shape_height = block_shape_data[0];
4233 const int block_shape_width =
4234 unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
4235 const int crops_top = crops_data[0];
4236 const int crops_left =
4237 unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
4238
4239 for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
4240 const int out_batch = in_batch % output_batch_size;
4241 const int spatial_offset = in_batch / output_batch_size;
4242
4243 int in_h_start = 0;
4244 int in_h_end = 0;
4245 // GetIndexRange ensures start and end indices are in [0, output_height).
4246 GetIndexRange(spatial_offset / block_shape_width - crops_top,
4247 block_shape_height, input_height, output_height, &in_h_start,
4248 &in_h_end);
4249
4250 for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
4251 const int out_h = in_h * block_shape_height +
4252 spatial_offset / block_shape_width - crops_top;
4253 TFLITE_DCHECK_GE(out_h, 0);
4254 TFLITE_DCHECK_LT(out_h, output_height);
4255
4256 int in_w_start = 0;
4257 int in_w_end = 0;
4258 // GetIndexRange ensures start and end indices are in [0, output_width).
4259 GetIndexRange(spatial_offset % block_shape_width - crops_left,
4260 block_shape_width, input_width, output_width, &in_w_start,
4261 &in_w_end);
4262
4263 for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
4264 const int out_w = in_w * block_shape_width +
4265 spatial_offset % block_shape_width - crops_left;
4266 TFLITE_DCHECK_GE(out_w, 0);
4267 TFLITE_DCHECK_LT(out_w, output_width);
4268 T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
4269 const T* in =
4270 input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
4271 memcpy(out, in, depth * sizeof(T));
4272 }
4273 }
4274 }
4275 }
4276
4277 template <typename T>
4278 void TypedMemset(void* ptr, T value, size_t num) {
4279 // Optimization for common cases where memset() will suffice.
4280 if (value == 0 || std::is_same<T, uint8_t>::value) {
4281 memset(ptr, value, num * sizeof(T));
4282 } else {
4283 // Default implementation for cases where memset() will not preserve the
4284 // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
4285 char* pos = static_cast<char*>(ptr);
4286 for (size_t i = 0; i < num; ++i) {
4287 memcpy(pos, &value, sizeof(T));
4288 pos = pos + sizeof(T);
4289 }
4290 }
4291 }
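
// Usage sketch for TypedMemset (illustrative): zero values and one-byte types
// take the fast memset() path; everything else falls back to the element-wise
// copy, since e.g. 1.5f cannot be produced by replicating a single byte.
//
//   float buf[8];
//   TypedMemset<float>(buf, 0.0f, 8);       // memset() path (all-zero bytes)
//   TypedMemset<float>(buf, 1.5f, 8);       // element-wise copy path
//   uint8_t bytes[16];
//   TypedMemset<uint8_t>(bytes, 0x7f, 16);  // memset() path (one-byte type)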
4292
4293 // This makes heavy use of Offset, along with conditional branches. There may be
4294 // opportunities for improvement.
4295 //
4296 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
4297 // scalar input that provides the padding value. Therefore pad_value_ptr can be
4298 // equivalent to a simple input1_data. For Pad, it should point to a zero
4299 // value.
4300 //
4301 // Note that two typenames are required, so that T=P=int32 is considered a
4302 // specialization distinct from P=int32.
4303 template <typename T, typename P>
4304 inline void PadImpl(const tflite::PadParams& op_params,
4305 const RuntimeShape& input_shape, const T* input_data,
4306 const P* pad_value_ptr, const RuntimeShape& output_shape,
4307 T* output_data) {
4308 ruy::profiler::ScopeLabel label("PadImpl");
4309 const int max_supported_dims = 5;
4310 const RuntimeShape ext_input_shape =
4311 RuntimeShape::ExtendedShape(max_supported_dims, input_shape);
4312 const RuntimeShape ext_output_shape =
4313 RuntimeShape::ExtendedShape(max_supported_dims, output_shape);
4314 TFLITE_DCHECK_LE(op_params.left_padding_count, max_supported_dims);
4315 TFLITE_DCHECK_LE(op_params.right_padding_count, max_supported_dims);
4316
4317 // Pad kernels are limited to max 5 dimensions. Copy inputs so we can pad them
4318 // to 5 dims (yes, we are "padding the padding").
4319 std::vector<int> left_padding_copy(max_supported_dims, 0);
4320 const int left_padding_extend =
4321 max_supported_dims - op_params.left_padding_count;
4322 for (int i = 0; i < op_params.left_padding_count; ++i) {
4323 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4324 }
4325 std::vector<int> right_padding_copy(max_supported_dims, 0);
4326 const int right_padding_extend =
4327 max_supported_dims - op_params.right_padding_count;
4328 for (int i = 0; i < op_params.right_padding_count; ++i) {
4329 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4330 }
4331
4332 const int output_batch = ext_output_shape.Dims(0);
4333 const int output_spatial_dim1 = ext_output_shape.Dims(1);
4334 const int output_spatial_dim2 = ext_output_shape.Dims(2);
4335 const int output_spatial_dim3 = ext_output_shape.Dims(3);
4336 const int output_channel = ext_output_shape.Dims(4);
4337
4338 const int left_b_padding = left_padding_copy[0];
4339 const int left_s1_padding = left_padding_copy[1];
4340 const int left_s2_padding = left_padding_copy[2];
4341 const int left_s3_padding = left_padding_copy[3];
4342 const int left_c_padding = left_padding_copy[4];
4343
4344 const int right_b_padding = right_padding_copy[0];
4345 const int right_s1_padding = right_padding_copy[1];
4346 const int right_s2_padding = right_padding_copy[2];
4347 const int right_s3_padding = right_padding_copy[3];
4348 const int right_c_padding = right_padding_copy[4];
4349
4350 const int input_depth = ext_input_shape.Dims(4);
4351 const T pad_value = *pad_value_ptr;
4352
4353 if (left_b_padding != 0) {
4354 TypedMemset<T>(output_data, pad_value,
4355 left_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4356 output_spatial_dim3 * output_channel);
4357 }
4358 for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
4359 ++out_b) {
4360 if (left_s1_padding != 0) {
4361 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0, 0),
4362 pad_value,
4363 left_s1_padding * output_spatial_dim2 *
4364 output_spatial_dim3 * output_channel);
4365 }
4366 for (int out_p = left_s1_padding;
4367 out_p < output_spatial_dim1 - right_s1_padding; ++out_p) {
4368 if (left_s2_padding != 0) {
4369 TypedMemset<T>(
4370 output_data + Offset(ext_output_shape, out_b, out_p, 0, 0, 0),
4371 pad_value, left_s2_padding * output_spatial_dim3 * output_channel);
4372 }
4373 for (int out_h = left_s2_padding;
4374 out_h < output_spatial_dim2 - right_s2_padding; ++out_h) {
4375 if (left_s3_padding != 0) {
4376 TypedMemset<T>(
4377 output_data + Offset(ext_output_shape, out_b, out_p, out_h, 0, 0),
4378 pad_value, left_s3_padding * output_channel);
4379 }
4380 for (int out_w = left_s3_padding;
4381 out_w < output_spatial_dim3 - right_s3_padding; ++out_w) {
4382 if (left_c_padding != 0) {
4383 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_p,
4384 out_h, out_w, 0),
4385 pad_value, left_c_padding);
4386 }
4387
4388 T* out = output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4389 out_w, left_c_padding);
4390 const T* in = input_data +
4391 Offset(ext_input_shape, out_b - left_b_padding,
4392 out_p - left_s1_padding, out_h - left_s2_padding,
4393 out_w - left_s3_padding, 0);
4394 memcpy(out, in, input_depth * sizeof(T));
4395
4396 if (right_c_padding != 0) {
4397 TypedMemset<T>(
4398 output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4399 out_w, output_channel - right_c_padding),
4400 pad_value, right_c_padding);
4401 }
4402 }
4403 if (right_s3_padding != 0) {
4404 TypedMemset<T>(
4405 output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4406 output_spatial_dim3 - right_s3_padding, 0),
4407 pad_value, right_s3_padding * output_channel);
4408 }
4409 }
4410 if (right_s2_padding != 0) {
4411 TypedMemset<T>(
4412 output_data + Offset(ext_output_shape, out_b, out_p,
4413 output_spatial_dim2 - right_s2_padding, 0, 0),
4414 pad_value, right_s2_padding * output_spatial_dim3 * output_channel);
4415 }
4416 }
4417 if (right_s1_padding != 0) {
4418 TypedMemset<T>(
4419 output_data + Offset(ext_output_shape, out_b,
4420 output_spatial_dim1 - right_s1_padding, 0, 0, 0),
4421 pad_value,
4422 right_s1_padding * output_spatial_dim2 * output_spatial_dim3 *
4423 output_channel);
4424 }
4425 }
4426 if (right_b_padding != 0) {
4427 TypedMemset<T>(
4428 output_data + Offset(ext_output_shape, output_batch - right_b_padding,
4429 0, 0, 0, 0),
4430 pad_value,
4431 right_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4432 output_spatial_dim3 * output_channel);
4433 }
4434 }
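
// Minimal usage sketch for PadImpl via Pad() below (illustrative; shapes and
// padding amounts are assumptions): pad one pixel of zeros on each spatial
// side of a 1x4x4x3 float NHWC tensor, producing a 1x6x6x3 output.
//
//   tflite::PadParams op_params;
//   op_params.left_padding_count = 4;
//   op_params.right_padding_count = 4;
//   const int pads[4] = {0, 1, 1, 0};  // {batch, height, width, channel}
//   for (int i = 0; i < 4; ++i) {
//     op_params.left_padding[i] = pads[i];
//     op_params.right_padding[i] = pads[i];
//   }
//   const float pad_value = 0.0f;
//   Pad(op_params, RuntimeShape({1, 4, 4, 3}), input, &pad_value,
//       RuntimeShape({1, 6, 6, 3}), output);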
4435
4436 template <typename T, typename P>
4437 inline void Pad(const tflite::PadParams& op_params,
4438 const RuntimeShape& input_shape, const T* input_data,
4439 const P* pad_value_ptr, const RuntimeShape& output_shape,
4440 T* output_data) {
4441 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4442 output_data);
4443 }
4444
4445 // The second (pad-value) input can be int32 when, say, the first is uint8.
4446 template <typename T>
4447 inline void Pad(const tflite::PadParams& op_params,
4448 const RuntimeShape& input_shape, const T* input_data,
4449 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4450 T* output_data) {
4451 const T converted_pad_value = static_cast<T>(*pad_value_ptr);
4452 PadImpl(op_params, input_shape, input_data, &converted_pad_value,
4453 output_shape, output_data);
4454 }
4455
4456 // This version avoids conflicting template matching.
4457 template <>
4458 inline void Pad(const tflite::PadParams& op_params,
4459 const RuntimeShape& input_shape, const int32* input_data,
4460 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4461 int32* output_data) {
4462 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4463 output_data);
4464 }
4465
4466 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
4467 //
4468 // This pad requires that (a) left and right paddings are in the 4D patterns
4469 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
4470 // T is uint8.
4471 //
4472 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
4473 // scalar input that provides the padding value. Therefore pad_value_ptr can be
4474 // equivalent to a simple input1_data. For Pad, it should point to a zero
4475 // value.
4476 //
4477 // Note that two typenames are required, so that T=P=int32 is considered a
4478 // specialization distinct from P=int32.
4479 template <typename T, typename P>
4480 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
4481 const RuntimeShape& input_shape,
4482 const T* input_data, const P* pad_value_ptr,
4483 const RuntimeShape& output_shape,
4484 T* output_data) {
4485 ruy::profiler::ScopeLabel label("PadImageStyle");
4486 const RuntimeShape ext_input_shape =
4487 RuntimeShape::ExtendedShape(4, input_shape);
4488 const RuntimeShape ext_output_shape =
4489 RuntimeShape::ExtendedShape(4, output_shape);
4490 TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
4491 TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
4492
4493 // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
4494 // to 4 dims (yes, we are "padding the padding").
4495 std::vector<int> left_padding_copy(4, 0);
4496 const int left_padding_extend = 4 - op_params.left_padding_count;
4497 for (int i = 0; i < op_params.left_padding_count; ++i) {
4498 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4499 }
4500 std::vector<int> right_padding_copy(4, 0);
4501 const int right_padding_extend = 4 - op_params.right_padding_count;
4502 for (int i = 0; i < op_params.right_padding_count; ++i) {
4503 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4504 }
4505 // The following padding restrictions are contractual requirements, and
4506 // embody what it means for a padding op to be "image-style".
4507 TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
4508 TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
4509 TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
4510 TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
4511
4512 const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
4513 const int output_height = ext_output_shape.Dims(1);
4514 const int output_width = ext_output_shape.Dims(2);
4515 const int input_height = ext_input_shape.Dims(1);
4516 const int input_width = ext_input_shape.Dims(2);
4517 const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
4518
4519 const int left_h_padding = left_padding_copy[1];
4520 const int left_w_padding = left_padding_copy[2];
4521 const int right_h_padding = right_padding_copy[1];
4522 const int right_w_padding = right_padding_copy[2];
4523
4524 TFLITE_DCHECK_EQ(output_height,
4525 input_height + left_h_padding + right_h_padding);
4526 TFLITE_DCHECK_EQ(output_width,
4527 input_width + left_w_padding + right_w_padding);
4528
4529 const T pad_value = *pad_value_ptr;
4530 const int top_block_size = left_h_padding * output_width * depth;
4531 const size_t num_top_block_bytes = top_block_size * sizeof(T);
4532 const int bottom_block_size = right_h_padding * output_width * depth;
4533 const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
4534 const int left_blocks_size = left_w_padding * depth;
4535 const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
4536 const int right_blocks_size = right_w_padding * depth;
4537 const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
4538 const int inner_line_size = input_width * depth;
4539 const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
4540
4541 if (input_height == 0) {
4542 memset(output_data, pad_value,
4543 num_top_block_bytes + num_bottom_block_bytes);
4544 } else {
4545 for (int i = 0; i < batch; ++i) {
4546 // For each image in the batch, apply the top padding, then iterate
4547 // through rows, then apply the bottom padding.
4548 //
4549 // By unwinding one iteration, we can combine the first left-margin
4550 // padding with the top padding, and the last right-margin padding with
4551 // the bottom padding.
4552 memset(output_data, pad_value,
4553 num_top_block_bytes + num_left_block_bytes);
4554 output_data += top_block_size + left_blocks_size;
4555 memcpy(output_data, input_data, num_inner_line_bytes);
4556 input_data += inner_line_size;
4557 output_data += inner_line_size;
4558 // One iteration unwound.
4559 // Unwinding this loop affords the opportunity to reorder the loop work
4560 // and hence combine memset() calls.
4561 //
4562 // Before unwinding:
4563 // for (int j = 0; j < input_height; ++j) {
4564 // // Pad on left, copy central data, pad on right.
4565 // memset(output_data, pad_value, num_left_block_bytes);
4566 // output_data += left_blocks_size;
4567 // memcpy(output_data, input_data, num_inner_line_bytes);
4568 // input_data += inner_line_size;
4569 // output_data += inner_line_size;
4570 // memset(output_data, pad_value, num_right_block_bytes);
4571 // output_data += right_blocks_size;
4572 // }
4573 for (int j = 1; j < input_height; ++j) {
4574 memset(output_data, pad_value,
4575 num_right_block_bytes + num_left_block_bytes);
4576 output_data += right_blocks_size + left_blocks_size;
4577 memcpy(output_data, input_data, num_inner_line_bytes);
4578 input_data += inner_line_size;
4579 output_data += inner_line_size;
4580 }
4581 memset(output_data, pad_value,
4582 num_right_block_bytes + num_bottom_block_bytes);
4583 output_data += right_blocks_size + bottom_block_size;
4584 }
4585 }
4586 }
4587
4588 template <typename T, typename P>
4589 inline void PadImageStyle(const tflite::PadParams& op_params,
4590 const RuntimeShape& input_shape, const T* input_data,
4591 const P* pad_value_ptr,
4592 const RuntimeShape& output_shape, T* output_data) {
4593 reference_ops::PadImageStyle(op_params, input_shape, input_data,
4594 pad_value_ptr, output_shape, output_data);
4595 }
4596
4597 template <typename P>
4598 inline void PadImageStyle(const tflite::PadParams& op_params,
4599 const RuntimeShape& input_shape,
4600 const uint8* input_data, const P* pad_value_ptr,
4601 const RuntimeShape& output_shape,
4602 uint8* output_data) {
4603 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4604 output_shape, output_data);
4605 }
4606
4607 template <typename P>
4608 inline void PadImageStyle(const tflite::PadParams& op_params,
4609 const RuntimeShape& input_shape,
4610 const float* input_data, const P* pad_value_ptr,
4611 const RuntimeShape& output_shape,
4612 float* output_data) {
4613 const float converted_pad_value = static_cast<float>(*pad_value_ptr);
4614 if (converted_pad_value == 0.0f) {
4615 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4616 output_shape, output_data);
4617 } else {
4618 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4619 output_data);
4620 }
4621 }
4622
4623 template <typename T>
4624 inline void Slice(const tflite::SliceParams& op_params,
4625 const RuntimeShape& input_shape,
4626 const RuntimeShape& output_shape,
4627 SequentialTensorWriter<T>* writer) {
4628 ruy::profiler::ScopeLabel label("Slice");
4629 const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
4630 TFLITE_DCHECK_LE(op_params.begin_count, 5);
4631 TFLITE_DCHECK_LE(op_params.size_count, 5);
4632 const int begin_count = op_params.begin_count;
4633 const int size_count = op_params.size_count;
4634 // We front-pad the begin and size vectors.
4635 std::array<int, 5> start;
4636 std::array<int, 5> stop;
4637 for (int i = 0; i < 5; ++i) {
4638 int padded_i = 5 - i;
4639 start[i] =
4640 begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
4641 stop[i] =
4642 (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
4643 ? ext_shape.Dims(i)
4644 : start[i] + op_params.size[size_count - padded_i];
4645 }
4646
4647 for (int i0 = start[0]; i0 < stop[0]; ++i0) {
4648 for (int i1 = start[1]; i1 < stop[1]; ++i1) {
4649 for (int i2 = start[2]; i2 < stop[2]; ++i2) {
4650 for (int i3 = start[3]; i3 < stop[3]; ++i3) {
4651 const int len = stop[4] - start[4];
4652 if (len > 0)
4653 writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
4654 }
4655 }
4656 }
4657 }
4658 }
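
// Worked example for the begin/size front-padding above (illustrative): for a
// 2-D input of shape {4, 6} with op_params.begin = {1, 0} and
// op_params.size = {2, -1} (begin_count = size_count = 2), the shape is
// extended to {1, 1, 1, 4, 6} and the loops run over
//   start = {0, 0, 0, 1, 0}
//   stop  = {1, 1, 1, 3, 6}
// i.e. rows 1..2 are copied in full (-1 means "to the end of that axis").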
4659
4660 template <typename T>
4661 inline void Slice(const tflite::SliceParams& op_params,
4662 const RuntimeShape& input_shape, const T* input_data,
4663 const RuntimeShape& output_shape, T* output_data) {
4664 SequentialTensorWriter<T> writer(input_data, output_data);
4665 return Slice(op_params, input_shape, output_shape, &writer);
4666 }
4667
4668 template <typename T>
4669 inline void Slice(const tflite::SliceParams& op_params,
4670 const RuntimeShape& input_shape, const TfLiteTensor* input,
4671 const RuntimeShape& output_shape, TfLiteTensor* output) {
4672 SequentialTensorWriter<T> writer(input, output);
4673 return Slice(op_params, input_shape, output_shape, &writer);
4674 }
4675
4676 // Note: This implementation is only optimized for the case where the inner
4677 // stride == 1.
4678 template <typename T>
4679 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
4680 const RuntimeShape& unextended_input_shape,
4681 const RuntimeShape& unextended_output_shape,
4682 SequentialTensorWriter<T>* writer) {
4683 using strided_slice::LoopCondition;
4684 using strided_slice::StartForAxis;
4685 using strided_slice::StopForAxis;
4686
4687 ruy::profiler::ScopeLabel label("StridedSlice");
4688
4689 // Note that the output_shape is not used herein.
4690 tflite::StridedSliceParams params_copy = op_params;
4691
4692 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
4693 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
4694 const RuntimeShape input_shape =
4695 RuntimeShape::ExtendedShape(5, unextended_input_shape);
4696 const RuntimeShape output_shape =
4697 RuntimeShape::ExtendedShape(5, unextended_output_shape);
4698
4699 // Reverse and pad to 5 dimensions because that is what the runtime code
4700 // requires (i.e. all shapes must be 5D and are given backwards).
4701 strided_slice::StridedSlicePadIndices(&params_copy, 5);
4702
4703 const int start_0 = StartForAxis(params_copy, input_shape, 0);
4704 const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
4705 const int start_1 = StartForAxis(params_copy, input_shape, 1);
4706 const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
4707 const int start_2 = StartForAxis(params_copy, input_shape, 2);
4708 const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
4709 const int start_3 = StartForAxis(params_copy, input_shape, 3);
4710 const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
4711 const int start_4 = StartForAxis(params_copy, input_shape, 4);
4712 const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
4713 const bool inner_stride_is_1 = params_copy.strides[4] == 1;
4714
4715 for (int offset_0 = start_0 * input_shape.Dims(1),
4716 end_0 = stop_0 * input_shape.Dims(1),
4717 step_0 = params_copy.strides[0] * input_shape.Dims(1);
4718 !LoopCondition(offset_0, end_0, params_copy.strides[0]);
4719 offset_0 += step_0) {
4720 for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
4721 end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
4722 step_1 = params_copy.strides[1] * input_shape.Dims(2);
4723 !LoopCondition(offset_1, end_1, params_copy.strides[1]);
4724 offset_1 += step_1) {
4725 for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
4726 end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
4727 step_2 = params_copy.strides[2] * input_shape.Dims(3);
4728 !LoopCondition(offset_2, end_2, params_copy.strides[2]);
4729 offset_2 += step_2) {
4730 for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
4731 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
4732 step_3 = params_copy.strides[3] * input_shape.Dims(4);
4733 !LoopCondition(offset_3, end_3, params_copy.strides[3]);
4734 offset_3 += step_3) {
4735 // When the stride is 1, the inner loop is equivalent to the
4736 // optimized slice inner loop. Otherwise, it is identical to the
4737 // strided_slice reference implementation inner loop.
4738 if (inner_stride_is_1) {
4739 const int len = stop_4 - start_4;
4740 if (len > 0) {
4741 writer->WriteN(offset_3 + start_4, len);
4742 }
4743 } else {
4744 for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
4745 !LoopCondition(offset_4, end_4, params_copy.strides[4]);
4746 offset_4 += params_copy.strides[4]) {
4747 writer->Write(offset_4);
4748 }
4749 }
4750 }
4751 }
4752 }
4753 }
4754 }
4755
4756 template <typename T>
4757 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
4758 const RuntimeShape& unextended_input_shape,
4759 const T* input_data,
4760 const RuntimeShape& unextended_output_shape,
4761 T* output_data) {
4762 SequentialTensorWriter<T> writer(input_data, output_data);
4763 StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
4764 &writer);
4765 }
4766
4767 template <typename T>
4768 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
4769 const RuntimeShape& unextended_input_shape,
4770 const TfLiteTensor* input,
4771 const RuntimeShape& unextended_output_shape,
4772 TfLiteTensor* output) {
4773 SequentialTensorWriter<T> writer(input, output);
4774 StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
4775 &writer);
4776 }
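
// Illustrative usage sketch (hypothetical buffers and values, not part of
// this header): extracting elements [1, 4) of a length-8 float vector with
// stride 1 using the pointer-based overload above.
//
//   tflite::StridedSliceParams params = {};
//   params.start_indices_count = 1;
//   params.start_indices[0] = 1;
//   params.stop_indices_count = 1;
//   params.stop_indices[0] = 4;
//   params.strides_count = 1;
//   params.strides[0] = 1;
//   const RuntimeShape input_shape({8});
//   const RuntimeShape output_shape({3});
//   // input_data points to 8 floats, output_data to 3 floats.
//   StridedSlice<float>(params, input_shape, input_data, output_shape,
//                       output_data);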
4777
4778 template <typename T>
4779 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
4780 const T* input2_data, const RuntimeShape& output_shape,
4781 T* output_data) {
4782 ruy::profiler::ScopeLabel label("TensorFlowMinimum");
4783 auto input1_map = MapAsVector(input1_data, input1_shape);
4784 auto output_map = MapAsVector(output_data, output_shape);
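// Only the first element of input2 is read: this fast path assumes the
// second operand is a scalar that is broadcast against input1.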
4785 auto min_value = input2_data[0];
4786 output_map.array() = input1_map.array().min(min_value);
4787 }
4788
4789 // Convenience version that allows, for example, generated-code calls to be
4790 // the same as other binary ops.
4791 template <typename T>
4792 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
4793 const RuntimeShape&, const T* input2_data,
4794 const RuntimeShape& output_shape, T* output_data) {
4795 // Drop shape of second input: not needed.
4796 Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
4797 }
4798
4799 template <typename T>
4800 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
4801 const T* input2_data, const RuntimeShape& output_shape,
4802 T* output_data) {
4803 ruy::profiler::ScopeLabel label("TensorFlowMaximum");
4804 auto input1_map = MapAsVector(input1_data, input1_shape);
4805 auto output_map = MapAsVector(output_data, output_shape);
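// Only the first element of input2 is read: this fast path assumes the
// second operand is a scalar that is broadcast against input1.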
4806 auto max_value = input2_data[0];
4807 output_map.array() = input1_map.array().max(max_value);
4808 }
4809
4810 // Convenience version that allows, for example, generated-code calls to be
4811 // the same as other binary ops.
4812 template <typename T>
4813 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
4814 const RuntimeShape&, const T* input2_data,
4815 const RuntimeShape& output_shape, T* output_data) {
4816 // Drop shape of second input: not needed.
4817 Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
4818 }
4819
4820 template <typename T>
4821 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
4822 const RuntimeShape& input_shape, const T* input_data,
4823 const RuntimeShape& filter_shape,
4824 const RuntimeShape& output_shape, T* im2col_data) {
4825 ruy::profiler::ScopeLabel label("TransposeIm2col");
4826 const int stride_width = params.stride_width;
4827 const int stride_height = params.stride_height;
4828 const int pad_width = params.padding_values.width;
4829 const int pad_height = params.padding_values.height;
4830 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4831 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
4832 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4833 TFLITE_DCHECK(im2col_data);
4834
4835 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4836 const int input_height = input_shape.Dims(1);
4837 const int input_width = input_shape.Dims(2);
4838 const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
4839 const int filter_height = filter_shape.Dims(1);
4840 const int filter_width = filter_shape.Dims(2);
4841 const int output_height = output_shape.Dims(1);
4842 const int output_width = output_shape.Dims(2);
4843 MatchingDim(output_shape, 3, filter_shape, 0); // output_depth
4844
4845 // Construct the MxN sized im2col matrix.
4846 // The rows, M, are sub-ordered B x H x W.
4847 const RuntimeShape row_shape({1, batches, output_height, output_width});
4848 // The columns, N, are sub-ordered Kh x Kw x Din
4849 const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
4850 // Use dimensions M and N to construct dims for indexing directly into im2col
4851 const RuntimeShape im2col_shape(
4852 {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
4853
4854 // Build the im2col matrix by looping through all the input pixels,
4855 // computing their influence on the output, rather than looping through all
4856 // the output pixels. We therefore must initialize the im2col array to zero.
4857 // This is potentially inefficient because we subsequently overwrite bytes
4858 // set here. However, in practice memset is very fast and its cost is negligible.
4859 memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
4860
4861 // Loop through the output batches
4862 for (int batch = 0; batch < batches; ++batch) {
4863 // Loop through input pixels one at a time.
4864 for (int in_y = 0; in_y < input_height; ++in_y) {
4865 for (int in_x = 0; in_x < input_width; ++in_x) {
4866 // Loop through the output pixels this input pixel influences.
4867 const int out_x_origin = (in_x * stride_width) - pad_width;
4868 const int out_y_origin = (in_y * stride_height) - pad_height;
4869 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
4870 const int out_y = out_y_origin + filter_y;
4871 // Is output pixel within height bounds?
4872 if ((out_y >= 0) && (out_y < output_height)) {
4873 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
4874 const int out_x = out_x_origin + filter_x;
4875 // Is output pixel within width bounds?
4876 if ((out_x >= 0) && (out_x < output_width)) {
4877 // Copy the input elements of this pixel
4878 T const* src =
4879 input_data + Offset(input_shape, batch, in_y, in_x, 0);
4880 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
4881 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
4882 T* dst = im2col_data +
4883 Offset(im2col_shape, 0, 0, row_offset, col_offset);
4884 memcpy(dst, src, input_depth * sizeof(T));
4885 }
4886 }
4887 }
4888 }
4889 }
4890 }
4891 }
4892 }
4893
4894 // Returns in 'im_data' (assumed to be zero-initialized) the image in storage
4895 // order (height, width, depth), constructed from patches in 'col_data', which
4896 // is required to be in storage order (out_height * out_width, filter_height,
4897 // filter_width, in_depth). Implementation by Yangqing Jia (jiayq).
4898 // Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
4899 template <typename T>
4900 void Col2im(const T* col_data, const int depth, const int height,
4901 const int width, const int filter_h, const int filter_w,
4902 const int pad_t, const int pad_l, const int pad_b, const int pad_r,
4903 const int stride_h, const int stride_w, T* im_data) {
4904 ruy::profiler::ScopeLabel label("Col2im");
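// height_col / width_col are the number of patch positions along each
// spatial axis of the (padded) image.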
4905 int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
4906 int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
4907 int h_pad = -pad_t;
4908 for (int h = 0; h < height_col; ++h) {
4909 int w_pad = -pad_l;
4910 for (int w = 0; w < width_col; ++w) {
4911 T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
4912 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
4913 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
4914 if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
4915 // TODO(andydavis) Vectorize this loop (if compiler does not).
4916 for (int i = 0; i < depth; ++i) {
4917 im_patch_data[i] += col_data[i];
4918 }
4919 }
4920 im_patch_data += depth;
4921 col_data += depth;
4922 }
4923 // Advance to the next image row (the loop above already covered filter_w columns).
4924 im_patch_data += depth * (width - filter_w);
4925 }
4926 w_pad += stride_w;
4927 }
4928 h_pad += stride_h;
4929 }
4930 }
4931
4932 // TODO(b/188008864) Optimize this function by combining outer loops.
4933 template <typename T>
4934 void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
4935 const int height, const int width, const int depth) {
4936 if (bias_data) {
4937 for (int n = 0; n < batch_size; ++n) {
4938 for (int h = 0; h < height; ++h) {
4939 for (int w = 0; w < width; ++w) {
4940 for (int d = 0; d < depth; ++d) {
4941 im_data[d] += bias_data[d];
4942 }
4943 im_data += depth;
4944 }
4945 }
4946 }
4947 }
4948 }
4949
4950 // TransposeConvV2 expects the weights in HWOI order.
4951 inline void TransposeConvV2(
4952 const ConvParams& params, const RuntimeShape& input_shape,
4953 const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
4954 const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
4955 const float* bias_data, const RuntimeShape& output_shape,
4956 float* const output_data, const RuntimeShape& col2im_shape,
4957 float* col2im_data, CpuBackendContext* cpu_backend_context) {
4958 ruy::profiler::ScopeLabel label("TransposeConvV2/float");
4959 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4960 TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
4961 TFLITE_DCHECK(col2im_data);
4962 TFLITE_DCHECK(hwoi_ordered_filter_data);
4963
4964 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
4965 const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
4966 const int output_height = output_shape.Dims(1);
4967 const int output_width = output_shape.Dims(2);
4968 const int output_image_size = output_height * output_width;
4969 const int input_depth =
4970 MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
4971 const int output_depth =
4972 MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
4973 const int input_offset = input_image_size * input_depth;
4974 const int output_offset = output_image_size * output_depth;
4975
4976 const int filter_height = hwoi_ordered_filter_shape.Dims(0);
4977 const int filter_width = hwoi_ordered_filter_shape.Dims(1);
4978 const int padding_top = params.padding_values.height;
4979 const int padding_bottom =
4980 params.padding_values.height + params.padding_values.height_offset;
4981 const int padding_left = params.padding_values.width;
4982 const int padding_right =
4983 params.padding_values.width + params.padding_values.width_offset;
4984 const int stride_height = params.stride_height;
4985 const int stride_width = params.stride_width;
4986
4987 const int hwoi_ordered_filter_total_size =
4988 filter_height * filter_width * output_depth;
4989
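// The GEMM below multiplies the HWOI filter, viewed as a row-major
// (filter_height * filter_width * output_depth) x input_depth matrix, with
// each batch's input, viewed as a column-major input_depth x
// (input_height * input_width) matrix. Col2im then scatters each resulting
// column (one per input pixel) back into the output image.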
4990 cpu_backend_gemm::MatrixParams<float> lhs_params;
4991 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
4992 lhs_params.rows = hwoi_ordered_filter_total_size;
4993 lhs_params.cols = input_depth;
4994 float* output_data_p = output_data;
4995 std::fill_n(output_data, output_offset * batch_size, 0.0f);
4996 for (int i = 0; i < batch_size; ++i) {
4997 cpu_backend_gemm::MatrixParams<float> rhs_params;
4998 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
4999 rhs_params.rows = input_depth;
5000 rhs_params.cols = input_image_size;
5001 cpu_backend_gemm::MatrixParams<float> dst_params;
5002 dst_params.order = cpu_backend_gemm::Order::kColMajor;
5003 dst_params.rows = hwoi_ordered_filter_total_size;
5004 dst_params.cols = input_image_size;
5005 cpu_backend_gemm::GemmParams<float, float> gemm_params;
5006 cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5007 input_data + input_offset * i, dst_params,
5008 col2im_data, gemm_params, cpu_backend_context);
5009
5010 Col2im(col2im_data, output_depth, output_height, output_width,
5011 filter_height, filter_width, padding_top, padding_left,
5012 padding_bottom, padding_right, stride_height, stride_width,
5013 output_data_p);
5014 output_data_p += output_offset;
5015 }
5016 output_data_p = output_data;
5017 BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
5018 output_depth);
5019 }
5020
5021 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
5022 int32_t output_zp, int32_t* scratch, uint8_t* output) {
5023 ruy::profiler::ScopeLabel label("Quantize/uint8");
5024 int i = 0;
5025 const int32_t output_min = std::numeric_limits<uint8_t>::min();
5026 const int32_t output_max = std::numeric_limits<uint8_t>::max();
5027
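// The NEON path below quantizes 16 values per iteration; any remaining
// elements are handled by the scalar loop at the end.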
5028 #ifdef USE_NEON
5029 const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
5030 const int32x4_t max_val_dup = vdupq_n_s32(output_max);
5031 const int32x4_t min_val_dup = vdupq_n_s32(output_min);
5032
5033 using gemmlowp::RoundingDivideByPOT;
5034 using gemmlowp::SaturatingRoundingDoublingHighMul;
5035
5036 for (; i <= total_size - 16; i += 16) {
5037 int32x4x4_t scratch_val;
5038 scratch_val.val[0] = vld1q_s32(scratch + i);
5039 scratch_val.val[1] = vld1q_s32(scratch + i + 4);
5040 scratch_val.val[2] = vld1q_s32(scratch + i + 8);
5041 scratch_val.val[3] = vld1q_s32(scratch + i + 12);
5042
5043 int32x4x4_t temp_val =
5044 MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
5045
5046 temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
5047 temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
5048 temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
5049 temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
5050
5051 temp_val.val[0] =
5052 vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
5053 temp_val.val[1] =
5054 vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
5055 temp_val.val[2] =
5056 vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
5057 temp_val.val[3] =
5058 vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
5059
5060 const uint16x8_t result_1 =
5061 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[0])),
5062 vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[1])));
5063 const uint16x8_t result_2 =
5064 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[2])),
5065 vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[3])));
5066 const uint8x16_t result =
5067 vcombine_u8(vqmovn_u16(result_1), vqmovn_u16(result_2));
5068 vst1q_u8(output + i, result);
5069 }
5070 #endif
5071 for (; i < total_size; ++i) {
5072 int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
5073 temp += output_zp;
5074 if (temp > output_max) {
5075 temp = output_max;
5076 }
5077 if (temp < output_min) {
5078 temp = output_min;
5079 }
5080 output[i] = static_cast<uint8_t>(temp);
5081 }
5082 }
5083
5084 // Single-rounding MultiplyByQuantizedMultiplier
5085 #if TFLITE_SINGLE_ROUNDING
5086 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5087 int32_t channel_size, int32_t total_size,
5088 int32_t output_zp, int32_t output_min, int32_t output_max,
5089 int32_t* scratch, int8_t* output) {
5090 ruy::profiler::ScopeLabel label("Quantize/int8");
5091
5092 // Here we're trying to quantize the raw accumulators:
5093 // output_channels
5094 // data data data data data
5095 // rows data data data data data
5096 // data data data data data
5097 // ....
5098 //
5099 // To avoid repeatedly reloading the multipliers & shifts, once a block of
5100 // them is loaded we quantize the raw accumulators of every row before
5101 // moving on to the next block of channels.
5102 #ifdef USE_NEON
5103 const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5104 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5105 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5106 const int32x4_t minus_ones = vdupq_n_s32(-1);
5107 #endif
5108
5109 TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5110 const int32_t rows = total_size / channel_size;
5111
5112 int c = 0;
5113
5114 #ifdef USE_NEON
5115 for (; c <= channel_size - 8; c += 8) {
5116 int32x4_t out_shift_1 = vld1q_s32(shift + c);
5117 int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5118
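// Decompose each shift into a non-negative left shift, applied before the
// saturating doubling high-mul, and a rounding right shift of at least one
// bit, applied afterwards, so that rounding happens exactly once.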
5119 int32x4_t right_shift_1 = vminq_s32(out_shift_1, minus_ones);
5120 int32x4_t right_shift_2 = vminq_s32(out_shift_2, minus_ones);
5121
5122 int32x4_t left_shift_1 = vsubq_s32(out_shift_1, right_shift_1);
5123 int32x4_t left_shift_2 = vsubq_s32(out_shift_2, right_shift_2);
5124
5125 int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5126 int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5127 for (int n = 0; n < rows; ++n) {
5128 int loc = n * channel_size + c;
5129 int32x4_t acc_1 = vld1q_s32(scratch + loc);
5130 int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5131
5132 // Saturating Doubling High Mul.
5133 acc_1 = vshlq_s32(acc_1, left_shift_1);
5134 acc_1 = vqdmulhq_s32(acc_1, out_mul_1);
5135 acc_2 = vshlq_s32(acc_2, left_shift_2);
5136 acc_2 = vqdmulhq_s32(acc_2, out_mul_2);
5137
5138 // Rounding Dividing By POT.
5139 acc_1 = vrshlq_s32(acc_1, right_shift_1);
5140 acc_2 = vrshlq_s32(acc_2, right_shift_2);
5141
5142 // Add the output offset.
5143 acc_1 = vaddq_s32(acc_1, output_offset_vec);
5144 acc_2 = vaddq_s32(acc_2, output_offset_vec);
5145
5146 // Apply the activation function.
5147 acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5148 acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5149 acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5150 acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5151
5152 // Saturating cast to int8 and store to destination.
5153 const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5154 const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5155 const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5156 const int8x8_t res_s8 = vqmovn_s16(res_s16);
5157 vst1_s8(output + loc, res_s8);
5158 }
5159 }
5160
5161 #endif // USE_NEON
5162 // Handle leftover values, one by one. This is very slow.
5163 for (; c < channel_size; c++) {
5164 for (int n = 0; n < rows; ++n) {
5165 int loc = n * channel_size + c;
5166 int32 acc = scratch[loc];
5167 acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5168 acc += output_zp;
5169 acc = std::max(acc, output_min);
5170 acc = std::min(acc, output_max);
5171 output[loc] = static_cast<int8>(acc);
5172 }
5173 }
5174 }
5175
5176 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5177 int32_t channel_size, int32_t total_size,
5178 int32_t output_zp, int32_t output_min, int32_t output_max,
5179 int32_t* scratch, int16_t* output) {
5180 ruy::profiler::ScopeLabel label("Quantize(Single-rounding)/int16");
5181
5182 // Here we're trying to quantize the raw accumulators:
5183 // output_channels
5184 // data data data data data
5185 // rows data data data data data
5186 // data data data data data
5187 // ....
5188 //
5189 // To avoid repeatedly reloading the multipliers & shifts, once a block of
5190 // them is loaded we quantize the raw accumulators of every row before
5191 // moving on to the next block of channels.
5192 #ifdef USE_NEON
5193 const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5194 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5195 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5196 const int32x4_t minus_ones = vdupq_n_s32(-1);
5197 #endif
5198
5199 TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5200 const int32_t rows = total_size / channel_size;
5201
5202 int c = 0;
5203
5204 #ifdef USE_NEON
5205 for (; c <= channel_size - 8; c += 8) {
5206 int32x4_t out_shift_1 = vld1q_s32(shift + c);
5207 int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5208
5209 int32x4_t right_shift_1 = vminq_s32(out_shift_1, minus_ones);
5210 int32x4_t right_shift_2 = vminq_s32(out_shift_2, minus_ones);
5211
5212 int32x4_t left_shift_1 = vsubq_s32(out_shift_1, right_shift_1);
5213 int32x4_t left_shift_2 = vsubq_s32(out_shift_2, right_shift_2);
5214
5215 int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5216 int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5217 for (int n = 0; n < rows; ++n) {
5218 int loc = n * channel_size + c;
5219 int32x4_t acc_1 = vld1q_s32(scratch + loc);
5220 int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5221
5222 // Saturating Doubling High Mul.
5223 acc_1 = vshlq_s32(acc_1, left_shift_1);
5224 acc_1 = vqdmulhq_s32(acc_1, out_mul_1);
5225 acc_2 = vshlq_s32(acc_2, left_shift_2);
5226 acc_2 = vqdmulhq_s32(acc_2, out_mul_2);
5227
5228 // Rounding Dividing By POT.
5229 acc_1 = vrshlq_s32(acc_1, right_shift_1);
5230 acc_2 = vrshlq_s32(acc_2, right_shift_2);
5231
5232 // Add the output offset.
5233 acc_1 = vaddq_s32(acc_1, output_offset_vec);
5234 acc_2 = vaddq_s32(acc_2, output_offset_vec);
5235
5236 // Apply the activation function.
5237 acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5238 acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5239 acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5240 acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5241
5242 // Saturating cast to int16 and store to destination.
5243 const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5244 const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5245 vst1_s16(reinterpret_cast<int16_t*>(output) + loc, acc_s16_1);
5246 vst1_s16(reinterpret_cast<int16_t*>(output) + loc + 4, acc_s16_2);
5247 }
5248 }
5249
5250 #endif // USE_NEON
5251 // Handle leftover values, one by one. This is very slow.
5252 for (; c < channel_size; c++) {
5253 for (int n = 0; n < rows; ++n) {
5254 int loc = n * channel_size + c;
5255 int32 acc = scratch[loc];
5256 acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5257 acc += output_zp;
5258 acc = std::max(acc, output_min);
5259 acc = std::min(acc, output_max);
5260 output[loc] = static_cast<int16>(acc);
5261 }
5262 }
5263 }
5264 // Double-rounding MultiplyByQuantizedMultiplier
5265 #else
5266 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5267 int32_t channel_size, int32_t total_size,
5268 int32_t output_zp, int32_t output_min, int32_t output_max,
5269 int32_t* scratch, int8_t* output) {
5270 ruy::profiler::ScopeLabel label("Quantize/int8");
5271
5272 // Here we're trying to quantize the raw accumulators:
5273 // output_channels
5274 // data data data data data
5275 // rows data data data data data
5276 // data data data data data
5277 // ....
5278 //
5279 // To avoid repeatedly reloading the multipliers & shifts, once a block of
5280 // them is loaded we quantize the raw accumulators of every row before
5281 // moving on to the next block of channels.
5282 #ifdef USE_NEON
5283 const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5284 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5285 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5286 const int32x4_t zeros = vdupq_n_s32(0);
5287 #endif
5288
5289 TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5290 const int32_t rows = total_size / channel_size;
5291
5292 int c = 0;
5293
5294 #ifdef USE_NEON
5295 using gemmlowp::RoundingDivideByPOT;
5296 for (; c <= channel_size - 8; c += 8) {
5297 int32x4_t out_shift_1 = vld1q_s32(shift + c);
5298 int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5299 int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5300 int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5301
5302 // The right shift is implemented as a left shift by a negative amount.
5303 int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5304 int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5305
5306 int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5307 int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5308 for (int n = 0; n < rows; ++n) {
5309 int loc = n * channel_size + c;
5310 int32x4_t acc_1 = vld1q_s32(scratch + loc);
5311 int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5312
5313 // Saturating Rounding Doubling High Mul.
5314 acc_1 = vshlq_s32(acc_1, left_shift_1);
5315 acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5316 acc_2 = vshlq_s32(acc_2, left_shift_2);
5317 acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5318
5319 // Rounding Dividing By POT.
5320 acc_1 = vrshlq_s32(acc_1, right_shift_1);
5321 acc_2 = vrshlq_s32(acc_2, right_shift_2);
5322
5323 // Add the output offset.
5324 acc_1 = vaddq_s32(acc_1, output_offset_vec);
5325 acc_2 = vaddq_s32(acc_2, output_offset_vec);
5326
5327 // Apply the activation function.
5328 acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5329 acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5330 acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5331 acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5332
5333 // Saturating cast to int8 and store to destination.
5334 const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5335 const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5336 const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5337 const int8x8_t res_s8 = vqmovn_s16(res_s16);
5338 vst1_s8(output + loc, res_s8);
5339 }
5340 }
5341
5342 #endif // USE_NEON
5343 // Handle leftover values, one by one. This is very slow.
5344 for (; c < channel_size; c++) {
5345 for (int n = 0; n < rows; ++n) {
5346 int loc = n * channel_size + c;
5347 int32 acc = scratch[loc];
5348 acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5349 acc += output_zp;
5350 acc = std::max(acc, output_min);
5351 acc = std::min(acc, output_max);
5352 output[loc] = static_cast<int8>(acc);
5353 }
5354 }
5355 }
5356
5357 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5358 int32_t channel_size, int32_t total_size,
5359 int32_t output_zp, int32_t output_min, int32_t output_max,
5360 int32_t* scratch, int16_t* output) {
5361 ruy::profiler::ScopeLabel label("Quantize(Double-rounding)/int16");
5362
5363 // Here we're trying to quantize the raw accumulators:
5364 // output_channels
5365 // data data data data data
5366 // rows data data data data data
5367 // data data data data data
5368 // ....
5369 //
5370 // In order to minimize the reload of the multipliers & shifts, once we load
5371 // the multipliers & shifts, we load & quantize the raw accumulators for every
5372 // row.
5373 #ifdef USE_NEON
5374 const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5375 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5376 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5377 const int32x4_t zeros = vdupq_n_s32(0);
5378 #endif
5379
5380 TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5381 const int32_t rows = total_size / channel_size;
5382
5383 int c = 0;
5384
5385 #ifdef USE_NEON
5386 using gemmlowp::RoundingDivideByPOT;
5387 for (; c <= channel_size - 8; c += 8) {
5388 int32x4_t out_shift_1 = vld1q_s32(shift + c);
5389 int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5390 int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5391 int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5392
5393 // The right shift is implemented as a left shift by a negative amount.
5394 int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5395 int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5396
5397 int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5398 int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5399 for (int n = 0; n < rows; ++n) {
5400 int loc = n * channel_size + c;
5401 int32x4_t acc_1 = vld1q_s32(scratch + loc);
5402 int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5403
5404 // Saturating Rounding Doubling High Mul.
5405 acc_1 = vshlq_s32(acc_1, left_shift_1);
5406 acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5407 acc_2 = vshlq_s32(acc_2, left_shift_2);
5408 acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5409
5410 // Rounding Dividing By POT.
5411 acc_1 = vrshlq_s32(acc_1, right_shift_1);
5412 acc_2 = vrshlq_s32(acc_2, right_shift_2);
5413
5414 // Add the output offset.
5415 acc_1 = vaddq_s32(acc_1, output_offset_vec);
5416 acc_2 = vaddq_s32(acc_2, output_offset_vec);
5417
5418 // Apply the activation function.
5419 acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5420 acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5421 acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5422 acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5423
5424 // Saturating cast to int16 and store to destination.
5425 const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5426 const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5427 vst1_s16(reinterpret_cast<int16_t*>(output) + loc, acc_s16_1);
5428 vst1_s16(reinterpret_cast<int16_t*>(output) + loc + 4, acc_s16_2);
5429 }
5430 }
5431
5432 #endif // USE_NEON
5433 // Handle leftover values, one by one. This is very slow.
5434 for (; c < channel_size; c++) {
5435 for (int n = 0; n < rows; ++n) {
5436 int loc = n * channel_size + c;
5437 int32 acc = scratch[loc];
5438 acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5439 acc += output_zp;
5440 acc = std::max(acc, output_min);
5441 acc = std::min(acc, output_max);
5442 output[loc] = static_cast<int16>(acc);
5443 }
5444 }
5445 }
5446 #endif // TFLITE_SINGLE_ROUNDING
5447
5448 // TransposeConvV2 expects the weights in HWOI order.
5449 inline void TransposeConvV2(
5450 const ConvParams& params, const RuntimeShape& input_shape,
5451 const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5452 const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5453 const int32* bias_data, const RuntimeShape& output_shape,
5454 uint8_t* output_data, const RuntimeShape& col2im_shape,
5455 int32_t* col2im_data, int32_t* scratch_data,
5456 CpuBackendContext* cpu_backend_context) {
5457 ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
5458 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5459 TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5460 TFLITE_DCHECK(col2im_data);
5461 TFLITE_DCHECK(hwoi_ordered_filter_data);
5462
5463 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5464 const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5465 const int output_height = output_shape.Dims(1);
5466 const int output_width = output_shape.Dims(2);
5467 const int output_image_size = output_height * output_width;
5468 const int input_depth =
5469 MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5470 const int output_depth =
5471 MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5472 const int input_offset = input_image_size * input_depth;
5473 const int output_offset = output_image_size * output_depth;
5474
5475 const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5476 const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5477 const int padding_top = params.padding_values.height;
5478 const int padding_bottom =
5479 params.padding_values.height + params.padding_values.height_offset;
5480 const int padding_left = params.padding_values.width;
5481 const int padding_right =
5482 params.padding_values.width + params.padding_values.width_offset;
5483 const int stride_height = params.stride_height;
5484 const int stride_width = params.stride_width;
5485
5486 const int hwoi_ordered_filter_total_size =
5487 filter_height * filter_width * output_depth;
5488
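// Same GEMM + Col2im structure as the float TransposeConvV2 above, but
// accumulation happens in the int32 scratch buffer; the result is
// bias-added and then requantized to uint8 at the end.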
5489 cpu_backend_gemm::MatrixParams<uint8_t> lhs_params;
5490 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5491 lhs_params.rows = hwoi_ordered_filter_total_size;
5492 lhs_params.cols = input_depth;
5493 lhs_params.zero_point = -params.weights_offset;
5494
5495 int32_t* scratch_data_p = scratch_data;
5496 std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
5497 for (int i = 0; i < batch_size; ++i) {
5498 cpu_backend_gemm::MatrixParams<uint8_t> rhs_params;
5499 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5500 rhs_params.rows = input_depth;
5501 rhs_params.cols = input_image_size;
5502 rhs_params.zero_point = -params.input_offset;
5503
5504 cpu_backend_gemm::MatrixParams<int32_t> dst_params;
5505 dst_params.order = cpu_backend_gemm::Order::kColMajor;
5506 dst_params.rows = hwoi_ordered_filter_total_size;
5507 dst_params.cols = input_image_size;
5508
5509 cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
5510 cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5511 input_data + input_offset * i, dst_params,
5512 col2im_data, gemm_params, cpu_backend_context);
5513
5514 Col2im(col2im_data, output_depth, output_height, output_width,
5515 filter_height, filter_width, padding_top, padding_left,
5516 padding_bottom, padding_right, stride_height, stride_width,
5517 scratch_data_p);
5518
5519 scratch_data_p += output_offset;
5520 }
5521 scratch_data_p = scratch_data;
5522 BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
5523 output_depth);
5524
5525 Quantize(params.output_multiplier, params.output_shift,
5526 output_shape.FlatSize(), params.output_offset, scratch_data,
5527 output_data);
5528 }
5529
5530 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
5531 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
5532 // reference version. Debug checks are in place to test if this occurs.
5533 // NOTE: If align_corners or half_pixel_centers is true, we use the reference
5534 // version.
5535 inline void ResizeNearestNeighbor(
5536 const tflite::ResizeNearestNeighborParams& op_params,
5537 const RuntimeShape& unextended_input_shape, const uint8* input_data,
5538 const RuntimeShape& output_size_shape, const int32* output_size_data,
5539 const RuntimeShape& unextended_output_shape, uint8* output_data) {
5540 if (op_params.align_corners || op_params.half_pixel_centers) {
5541 // TODO(b/149823713): Add support for align_corners & half_pixel_centers in
5542 // this kernel.
5543 reference_ops::ResizeNearestNeighbor(
5544 op_params, unextended_input_shape, input_data, output_size_shape,
5545 output_size_data, unextended_output_shape, output_data);
5546 return;
5547 }
5548 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5549 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5550
5551 const RuntimeShape input_shape =
5552 RuntimeShape::ExtendedShape(4, unextended_input_shape);
5553 const RuntimeShape output_shape =
5554 RuntimeShape::ExtendedShape(4, unextended_output_shape);
5555
5556 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5557 int32 input_height = input_shape.Dims(1);
5558 int32 input_width = input_shape.Dims(2);
5559 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5560
5561 // The TensorFlow version of this op allows resizing along the width and
5562 // height axes only.
5563 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5564 int32 output_height = output_size_data[0];
5565 int32 output_width = output_size_data[1];
5566
5567 // Convert scales to fixed-point with 16 fractional bits. We add 1 as an
5568 // error factor and to avoid zero scales. For example, with input_height = 1,
5569 // output_height = 3, the float scaling factor would be non-zero at 1/3.
5570 // With fixed-point, this is zero.
5571 int32 height_scale = (input_height << 16) / output_height + 1;
5572 int32 width_scale = (input_width << 16) / output_width + 1;
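// Illustrative example (hypothetical values): with input_height = 2 and
// output_height = 4, height_scale = (2 << 16) / 4 + 1 = 32769, so for
// y = 3 we get (3 * 32769) >> 16 = 1, which matches floor(3 * 2 / 4).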
5573
5574 const int col_offset = input_shape.Dims(3);
5575 const int row_offset = input_shape.Dims(2) * col_offset;
5576 const int batch_offset = input_shape.Dims(1) * row_offset;
5577
5578 const uint8* input_ptr = input_data;
5579 uint8* output_ptr = output_data;
5580 for (int b = 0; b < batches; ++b) {
5581 for (int y = 0; y < output_height; ++y) {
5582 int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
5583 // Check offset calculation is the same as the reference version. See
5584 // function comment for details. We check using a non-float version of:
5585 // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
5586 // / output_height)));
5587 TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
5588 TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
5589 const uint8* y_input_ptr = input_ptr + in_y * row_offset;
5590 for (int x = 0; x < output_width; ++x) {
5591 int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
5592 // Check offset calculation is the same as the reference version. See
5593 // function comment for details. We check using a non-float version of:
5594 // TFLITE_DCHECK_EQ(in_x,
5595 // std::floor(x * (static_cast<float>(input_width)
5596 // / output_width)));
5597 TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
5598 TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
5599 const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
5600 memcpy(output_ptr, x_input_ptr, depth);
5601 output_ptr += depth;
5602 }
5603 }
5604 input_ptr += batch_offset;
5605 }
5606 }
5607
5608 template <typename input_type, typename output_type>
5609 inline void Requantize(const input_type* input_data, int32_t size,
5610 int32_t effective_scale_multiplier,
5611 int32_t effective_scale_shift, int32_t input_zeropoint,
5612 int32_t output_zeropoint, output_type* output_data) {
5613 reference_ops::Requantize(input_data, size, effective_scale_multiplier,
5614 effective_scale_shift, input_zeropoint,
5615 output_zeropoint, output_data);
5616 }
5617
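// The specializations below provide NEON fast paths (when USE_NEON is
// defined) for the four 8-bit input/output combinations; each handles
// leftover elements with a scalar loop.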
5618 template <>
5619 inline void Requantize<int8_t, uint8_t>(const int8_t* input_data, int32_t size,
5620 int32_t effective_scale_multiplier,
5621 int32_t effective_scale_shift,
5622 int32_t input_zeropoint,
5623 int32_t output_zeropoint,
5624 uint8_t* output_data) {
5625 ruy::profiler::ScopeLabel label("Requantize/Int8ToUint8");
5626
5627 static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5628 static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5629
5630 int i = 0;
5631 #ifdef USE_NEON
5632 // Constants.
5633 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5634 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5635 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5636 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5637
5638 for (; i <= size - 16; i += 16) {
5639 const int8x16_t input_vec = vld1q_s8(input_data + i);
5640 const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5641 const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5642 int32x4x4_t input;
5643 input.val[0] = vmovl_s16(vget_low_s16(first_half));
5644 input.val[1] = vmovl_s16(vget_high_s16(first_half));
5645 input.val[2] = vmovl_s16(vget_low_s16(second_half));
5646 input.val[3] = vmovl_s16(vget_high_s16(second_half));
5647 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5648 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5649 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5650 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5651
5652 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5653 input, effective_scale_multiplier, effective_scale_shift);
5654
5655 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5656 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5657 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5658 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5659 result.val[0] =
5660 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5661 result.val[1] =
5662 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5663 result.val[2] =
5664 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5665 result.val[3] =
5666 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5667
5668 const uint32x4_t result_val_1_unsigned =
5669 vreinterpretq_u32_s32(result.val[0]);
5670 const uint32x4_t result_val_2_unsigned =
5671 vreinterpretq_u32_s32(result.val[1]);
5672 const uint32x4_t result_val_3_unsigned =
5673 vreinterpretq_u32_s32(result.val[2]);
5674 const uint32x4_t result_val_4_unsigned =
5675 vreinterpretq_u32_s32(result.val[3]);
5676
5677 const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5678 const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5679 const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5680 const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5681 const uint16x8_t output_first_half =
5682 vcombine_u16(narrowed_val_1, narrowed_val_2);
5683 const uint16x8_t output_second_half =
5684 vcombine_u16(narrowed_val_3, narrowed_val_4);
5685 const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5686 const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5687 const uint8x16_t narrowed_result =
5688 vcombine_u8(narrowed_first_half, narrowed_second_half);
5689 vst1q_u8(output_data + i, narrowed_result);
5690 }
5691
5692 #endif
5693 for (; i < size; ++i) {
5694 const int32_t input = input_data[i] - input_zeropoint;
5695 const int32_t output =
5696 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5697 effective_scale_shift) +
5698 output_zeropoint;
5699 const int32_t clamped_output =
5700 std::max(std::min(output, kMaxOutput), kMinOutput);
5701 output_data[i] = static_cast<uint8_t>(clamped_output);
5702 }
5703 }
5704
5705 template <>
5706 inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
5707 int32_t effective_scale_multiplier,
5708 int32_t effective_scale_shift,
5709 int32_t input_zeropoint,
5710 int32_t output_zeropoint,
5711 int8_t* output_data) {
5712 ruy::profiler::ScopeLabel label("Requantize/Uint8ToInt8");
5713
5714 static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5715 static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5716
5717 int i = 0;
5718 #ifdef USE_NEON
5719 // Constants.
5720 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5721 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5722 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5723 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5724
5725 for (; i <= size - 16; i += 16) {
5726 const uint8x16_t input_vec = vld1q_u8(input_data + i);
5727 const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5728 const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5729 int32x4x4_t input;
5730 input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5731 input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5732 input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5733 input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5734 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5735 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5736 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5737 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5738
5739 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5740 input, effective_scale_multiplier, effective_scale_shift);
5741
5742 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5743 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5744 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5745 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5746 result.val[0] =
5747 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5748 result.val[1] =
5749 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5750 result.val[2] =
5751 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5752 result.val[3] =
5753 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5754
5755 const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5756 const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5757 const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5758 const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5759 const int16x8_t output_first_half =
5760 vcombine_s16(narrowed_val_1, narrowed_val_2);
5761 const int16x8_t output_second_half =
5762 vcombine_s16(narrowed_val_3, narrowed_val_4);
5763 const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5764 const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5765 const int8x16_t narrowed_result =
5766 vcombine_s8(narrowed_first_half, narrowed_second_half);
5767 vst1q_s8(output_data + i, narrowed_result);
5768 }
5769
5770 #endif
5771 for (; i < size; ++i) {
5772 const int32_t input = input_data[i] - input_zeropoint;
5773 const int32_t output =
5774 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5775 effective_scale_shift) +
5776 output_zeropoint;
5777 const int32_t clamped_output =
5778 std::max(std::min(output, kMaxOutput), kMinOutput);
5779 output_data[i] = static_cast<int8_t>(clamped_output);
5780 }
5781 }
5782
5783 template <>
5784 inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
5785 int32_t effective_scale_multiplier,
5786 int32_t effective_scale_shift,
5787 int32_t input_zeropoint,
5788 int32_t output_zeropoint,
5789 int8_t* output_data) {
5790 ruy::profiler::ScopeLabel label("Requantize/Int8ToInt8");
5791
5792 static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5793 static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5794
5795 int i = 0;
5796 #ifdef USE_NEON
5797 // Constants.
5798 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5799 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5800 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5801 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5802
5803 for (; i <= size - 16; i += 16) {
5804 const int8x16_t input_vec = vld1q_s8(input_data + i);
5805 const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5806 const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5807 int32x4x4_t input;
5808 input.val[0] = vmovl_s16(vget_low_s16(first_half));
5809 input.val[1] = vmovl_s16(vget_high_s16(first_half));
5810 input.val[2] = vmovl_s16(vget_low_s16(second_half));
5811 input.val[3] = vmovl_s16(vget_high_s16(second_half));
5812
5813 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5814 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5815 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5816 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5817
5818 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5819 input, effective_scale_multiplier, effective_scale_shift);
5820
5821 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5822 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5823 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5824 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5825 result.val[0] =
5826 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5827 result.val[1] =
5828 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5829 result.val[2] =
5830 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5831 result.val[3] =
5832 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5833
5834 const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5835 const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5836 const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5837 const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5838 const int16x8_t output_first_half =
5839 vcombine_s16(narrowed_val_1, narrowed_val_2);
5840 const int16x8_t output_second_half =
5841 vcombine_s16(narrowed_val_3, narrowed_val_4);
5842 const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5843 const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5844 const int8x16_t narrowed_result =
5845 vcombine_s8(narrowed_first_half, narrowed_second_half);
5846 vst1q_s8(output_data + i, narrowed_result);
5847 }
5848
5849 #endif
5850 for (; i < size; ++i) {
5851 const int32_t input = input_data[i] - input_zeropoint;
5852 const int32_t output =
5853 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5854 effective_scale_shift) +
5855 output_zeropoint;
5856 const int32_t clamped_output =
5857 std::max(std::min(output, kMaxOutput), kMinOutput);
5858 output_data[i] = static_cast<int8_t>(clamped_output);
5859 }
5860 }
5861
5862 template <>
5863 inline void Requantize<uint8_t, uint8_t>(
5864 const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
5865 int32_t effective_scale_shift, int32_t input_zeropoint,
5866 int32_t output_zeropoint, uint8_t* output_data) {
5867 ruy::profiler::ScopeLabel label("Requantize/Uint8ToUint8");
5868
5869 static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5870 static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5871
5872 int i = 0;
5873 #ifdef USE_NEON
5874 // Constants.
5875 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5876 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5877 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5878 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5879
5880 for (; i <= size - 16; i += 16) {
5881 const uint8x16_t input_vec = vld1q_u8(input_data + i);
5882 const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5883 const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5884 int32x4x4_t input;
5885 input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5886 input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5887 input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5888 input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5889 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5890 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5891 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5892 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5893
5894 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5895 input, effective_scale_multiplier, effective_scale_shift);
5896
5897 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5898 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5899 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5900 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5901 result.val[0] =
5902 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5903 result.val[1] =
5904 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5905 result.val[2] =
5906 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5907 result.val[3] =
5908 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5909
5910 const uint32x4_t result_val_1_unsigned =
5911 vreinterpretq_u32_s32(result.val[0]);
5912 const uint32x4_t result_val_2_unsigned =
5913 vreinterpretq_u32_s32(result.val[1]);
5914 const uint32x4_t result_val_3_unsigned =
5915 vreinterpretq_u32_s32(result.val[2]);
5916 const uint32x4_t result_val_4_unsigned =
5917 vreinterpretq_u32_s32(result.val[3]);
5918
5919 const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5920 const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5921 const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5922 const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5923 const uint16x8_t output_first_half =
5924 vcombine_u16(narrowed_val_1, narrowed_val_2);
5925 const uint16x8_t output_second_half =
5926 vcombine_u16(narrowed_val_3, narrowed_val_4);
5927 const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5928 const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5929 const uint8x16_t narrowed_result =
5930 vcombine_u8(narrowed_first_half, narrowed_second_half);
5931 vst1q_u8(output_data + i, narrowed_result);
5932 }
5933
5934 #endif
5935 for (; i < size; ++i) {
5936 const int32_t input = input_data[i] - input_zeropoint;
5937 const int32_t output =
5938 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5939 effective_scale_shift) +
5940 output_zeropoint;
5941 const int32_t clamped_output =
5942 std::max(std::min(output, kMaxOutput), kMinOutput);
5943 output_data[i] = static_cast<uint8_t>(clamped_output);
5944 }
5945 }
5946
5947 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
5948 const RuntimeShape& output_shape, float* output_data) {
5949 ruy::profiler::ScopeLabel label("HardSwish/Float");
5950 auto size = MatchingFlatSize(input_shape, output_shape);
5951 int i = 0;
5952 #ifdef USE_NEON
5953 const float32x4_t zero = vdupq_n_f32(0.0f);
5954 const float32x4_t three = vdupq_n_f32(3.0f);
5955 const float32x4_t six = vdupq_n_f32(6.0f);
5956 const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
5957
5958 for (; i <= size - 16; i += 16) {
5959 // 4x partially unrolled version of the loop below. Refer to its comments.
5960 const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
5961 const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
5962 const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
5963 const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
5964 const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
5965 const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
5966 const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
5967 const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
5968 const float32x4_t in_reluish_0 =
5969 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
5970 const float32x4_t in_reluish_1 =
5971 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
5972 const float32x4_t in_reluish_2 =
5973 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
5974 const float32x4_t in_reluish_3 =
5975 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
5976 const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
5977 const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
5978 const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
5979 const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
5980 vst1q_f32(output_data + i + 0, product_0);
5981 vst1q_f32(output_data + i + 4, product_1);
5982 vst1q_f32(output_data + i + 8, product_2);
5983 vst1q_f32(output_data + i + 12, product_3);
5984 }
5985 for (; i <= size - 4; i += 4) {
5986 // The expression to be computed is:
5987 // out = one_sixth * in * min(six, max(zero, (in + three)))
5988 // We structure the AST to have two roughly balanced, independent branches:
5989 // - Multiplication: in_scaled = one_sixth * in.
5990 // - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
5991 // Then the remaining multiplication at the root of the tree.
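// For example (illustration only): in = 1.0f gives in_scaled = 1/6 and
// in_reluish = min(6, max(0, 1 + 3)) = 4, so out = 4/6 ~= 0.6667, which is
// hard_swish(1) = 1 * relu6(1 + 3) / 6.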
5992 const float32x4_t in = vld1q_f32(input_data + i);
5993 const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
5994 const float32x4_t in_reluish =
5995 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
5996 const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
5997 vst1q_f32(output_data + i, product);
5998 }
5999 #endif
6000 for (; i < size; i++) {
6001 const float in = input_data[i];
6002 output_data[i] =
6003 in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
6004 }
6005 }
6006
6007 #ifdef USE_NEON
6008 inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
6009 // Narrow values down to 8 bit unsigned, saturating.
6010 uint8x8_t res8 = vqmovun_s16(src);
6011 // Store results to destination.
6012 vst1_u8(dst, res8);
6013 }
6014
6015 inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
6016 // Narrow values down to 8 bit unsigned, saturating.
6017 int8x8_t res8 = vqmovn_s16(src);
6018 // Store results to destination.
6019 vst1_s8(dst, res8);
6020 }
6021 #endif
6022
6023 template <typename T>
6024 inline void HardSwish(const HardSwishParams& params,
6025 const RuntimeShape& input_shape, const T* input_data,
6026 const RuntimeShape& output_shape, T* output_data) {
6027 ruy::profiler::ScopeLabel label("HardSwish/Quantized");
6028
6029 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6030
6031 int i = 0;
6032 // This code heavily uses NEON saturating left shifts (vqshl*) with shift
6033 // amounts that can be zero, in which case we rely on the correct behavior
6034 // of a left shift by zero returning just its first operand unmodified.
6035 // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
6036 // buggy in the case of zero shift amounts, see b/137199585. That is why
6037 // this NEON code path is restricted to true ARM NEON, excluding
6038 // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating
6039 // left shifts is slow scalar code, so there may not be much benefit in
6040 // running that over just plain reference code.
6041 //
6042 // TODO(b/137199585): revisit when this is fixed.
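// As a concrete illustration of the zero-shift reliance (not additional
// logic): when params.reluish_multiplier_exponent == 0, both
// positive_reluish_multiplier_exponent_minus_one and
// positive_reluish_multiplier_exponent_last_bit below are all-zero vectors,
// so the corresponding vqshlq_s16 calls must act as the identity,
//   vqshlq_s16(x, vdupq_n_s16(0)) == x,
// which holds on real ARM NEON but not in arm_neon_sse.h (b/137199585).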
6043 #ifdef __ARM_NEON
6044 const int16x8_t positive_reluish_multiplier_exponent_minus_one =
6045 vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
6046 const int16x8_t positive_reluish_multiplier_exponent_last_bit =
6047 vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
6048 const int16x8_t negative_reluish_multiplier_exponent =
6049 vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
6050 const int16x8_t constant_32767 = vdupq_n_s16(32767);
6051 const int16x8_t output_multiplier_exponent =
6052 vdupq_n_s16(params.output_multiplier_exponent);
6053 const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
6054 // 4x unrolled version of the NEON loop below. Read that first.
6055 for (; i <= flat_size - 32; i += 32) {
6056 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6057 const int16x8x2_t input_value_0_1 =
6058 Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6059 const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
6060 input_data + i + 16, params.input_zero_point);
6061 const int16x8_t input_value_on_hires_input_scale_0 =
6062 vshlq_n_s16(input_value_0_1.val[0], 7);
6063 const int16x8_t input_value_on_hires_input_scale_1 =
6064 vshlq_n_s16(input_value_0_1.val[1], 7);
6065 const int16x8_t input_value_on_hires_input_scale_2 =
6066 vshlq_n_s16(input_value_2_3.val[0], 7);
6067 const int16x8_t input_value_on_hires_input_scale_3 =
6068 vshlq_n_s16(input_value_2_3.val[1], 7);
6069 const int16x8_t input_value_on_preshift_output_scale_0 =
6070 vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
6071 params.output_multiplier_fixedpoint_int16);
6072 const int16x8_t input_value_on_preshift_output_scale_1 =
6073 vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
6074 params.output_multiplier_fixedpoint_int16);
6075 const int16x8_t input_value_on_preshift_output_scale_2 =
6076 vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
6077 params.output_multiplier_fixedpoint_int16);
6078 const int16x8_t input_value_on_preshift_output_scale_3 =
6079 vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
6080 params.output_multiplier_fixedpoint_int16);
6081 int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
6082 int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
6083 int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
6084 int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
6085 reluish_value_0 = vqshlq_s16(
6086 reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
6087 reluish_value_1 = vqshlq_s16(
6088 reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
6089 reluish_value_2 = vqshlq_s16(
6090 reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
6091 reluish_value_3 = vqshlq_s16(
6092 reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
6093 reluish_value_0 = vqrdmulhq_n_s16(
6094 reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
6095 reluish_value_1 = vqrdmulhq_n_s16(
6096 reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
6097 reluish_value_2 = vqrdmulhq_n_s16(
6098 reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
6099 reluish_value_3 = vqrdmulhq_n_s16(
6100 reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
6101 reluish_value_0 = vqshlq_s16(reluish_value_0,
6102 positive_reluish_multiplier_exponent_last_bit);
6103 reluish_value_1 = vqshlq_s16(reluish_value_1,
6104 positive_reluish_multiplier_exponent_last_bit);
6105 reluish_value_2 = vqshlq_s16(reluish_value_2,
6106 positive_reluish_multiplier_exponent_last_bit);
6107 reluish_value_3 = vqshlq_s16(reluish_value_3,
6108 positive_reluish_multiplier_exponent_last_bit);
6109 reluish_value_0 =
6110 vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
6111 reluish_value_1 =
6112 vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
6113 reluish_value_2 =
6114 vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
6115 reluish_value_3 =
6116 vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
6117 reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
6118 reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
6119 reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
6120 reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
6121 const int16x8_t preshift_output_value_0 =
6122 vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
6123 const int16x8_t preshift_output_value_1 =
6124 vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
6125 const int16x8_t preshift_output_value_2 =
6126 vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
6127 const int16x8_t preshift_output_value_3 =
6128 vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
6129 int16x8_t output_value_0 =
6130 vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
6131 int16x8_t output_value_1 =
6132 vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
6133 int16x8_t output_value_2 =
6134 vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
6135 int16x8_t output_value_3 =
6136 vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
6137 output_value_0 = vaddq_s16(output_value_0, output_zero_point);
6138 output_value_1 = vaddq_s16(output_value_1, output_zero_point);
6139 output_value_2 = vaddq_s16(output_value_2, output_zero_point);
6140 output_value_3 = vaddq_s16(output_value_3, output_zero_point);
6141 SaturateAndStore(output_value_0, output_data + i);
6142 SaturateAndStore(output_value_1, output_data + i + 8);
6143 SaturateAndStore(output_value_2, output_data + i + 16);
6144 SaturateAndStore(output_value_3, output_data + i + 24);
6145 }
6146 // NEON version of reference_ops::HardSwish. Read that first.
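// For readers without the reference code at hand, a rough scalar
// transliteration of what each lane of this loop computes (a sketch for
// documentation only; the authoritative scalar path is
// reference_ops::HardSwish):
//   int16_t x     = input - params.input_zero_point;
//   int16_t hires = x << 7;                                    // vshlq_n_s16
//   int16_t v     = SaturatingRoundingDoublingHighMul(         // vqrdmulhq_n_s16
//                       hires, params.output_multiplier_fixedpoint_int16);
//   int16_t r     = hires;
//   r = saturating left shift by max(0, reluish_multiplier_exponent - 1);    // vqshlq_s16
//   r = SaturatingRoundingDoublingHighMul(r, reluish_multiplier_fixedpoint_int16);
//   r = saturating left shift by (reluish_multiplier_exponent > 0 ? 1 : 0);  // vqshlq_s16
//   r = rounding right shift by -min(0, reluish_multiplier_exponent);        // vrshlq_s16
//   r = (r + 32767 + 1) >> 1;              // vrhaddq_s16: map [-1, 1] to [0, 1] in Q15
//   int16_t y = saturating doubling high mul of (r, v);        // vqdmulhq_s16
//   y = rounding shift by output_multiplier_exponent;          // vrshlq_s16
//   output = saturate_cast<T>(y + params.output_zero_point);   // SaturateAndStore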
6147 for (; i <= flat_size - 8; i += 8) {
6148 using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
6149 const int16x8_t input_value =
6150 Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6151 const int16x8_t input_value_on_hires_input_scale =
6152 vshlq_n_s16(input_value, 7);
6153 const int16x8_t input_value_on_preshift_output_scale =
6154 vqrdmulhq_n_s16(input_value_on_hires_input_scale,
6155 params.output_multiplier_fixedpoint_int16);
6156 int16x8_t reluish_value = input_value_on_hires_input_scale;
6157 reluish_value = vqshlq_s16(reluish_value,
6158 positive_reluish_multiplier_exponent_minus_one);
6159 reluish_value = vqrdmulhq_n_s16(reluish_value,
6160 params.reluish_multiplier_fixedpoint_int16);
6161 reluish_value = vqshlq_s16(reluish_value,
6162 positive_reluish_multiplier_exponent_last_bit);
6163 reluish_value =
6164 vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
6165 reluish_value = vrhaddq_s16(reluish_value, constant_32767);
6166 const int16x8_t preshift_output_value =
6167 vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
6168 int16x8_t output_value =
6169 vrshlq_s16(preshift_output_value, output_multiplier_exponent);
6170 output_value = vaddq_s16(output_value, output_zero_point);
6171 SaturateAndStore(output_value, output_data + i);
6172 }
6173 #endif
6174 // TODO(b/137208495): revisit when unit tests cover reference code.
6175 // Fall back to reference_ops::HardSwish. In general we have preferred
6176 // to duplicate such scalar code rather than call reference code to handle
6177 // leftovers, thinking that code duplication was not a big concern.
6178 // However, most of our unit tests happen to test only optimized code,
6179 // and the quantized HardSwish implementation is nontrivial enough that
6180 // I really want test coverage for the reference code.
6181 if (i < flat_size) {
6182 const RuntimeShape leftover_shape{flat_size - i};
6183 reference_ops::HardSwish(params, leftover_shape, input_data + i,
6184 leftover_shape, output_data + i);
6185 }
6186 }
6187
6188 template <typename T>
6189 inline void IntegerExponentPow(const ArithmeticParams& params,
6190 const RuntimeShape& unextended_base_shape,
6191 const T* base_data, const int exponent,
6192 const RuntimeShape& unextended_output_shape,
6193 T* output_data) {
6194 TFLITE_DCHECK_GE(exponent, 1);
6195 if (exponent == 1) {
6196 // copy data over.
6197 std::memcpy(output_data, base_data,
6198 unextended_base_shape.FlatSize() * sizeof(T));
6199 } else {
6200 IntegerExponentPow(params, unextended_base_shape, base_data, exponent / 2,
6201 unextended_output_shape, output_data);
6202 Mul(params, unextended_base_shape, output_data, unextended_base_shape,
6203 output_data, unextended_output_shape, output_data);
6204 if (exponent % 2 == 1) {
6205 Mul(params, unextended_base_shape, base_data, unextended_base_shape,
6206 output_data, unextended_output_shape, output_data);
6207 }
6208 }
6209 }
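// Example of the recursion (illustration only): exponent == 5 first computes
// base^2 via the recursive call with exponent / 2 == 2, squares it with Mul to
// get base^4, and then, because 5 is odd, multiplies by base once more to get
// base^5. Only O(log2(exponent)) Mul calls are needed overall.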
6210
6211 template <typename T>
6212 inline void BroadcastPow4D(const RuntimeShape& unextended_input1_shape,
6213 const T* input1_data,
6214 const RuntimeShape& unextended_input2_shape,
6215 const T* input2_data,
6216 const RuntimeShape& unextended_output_shape,
6217 T* output_data) {
6218 ruy::profiler::ScopeLabel label("PowBroadcast");
6219
6220 if (unextended_input2_shape.FlatSize() == 1) {
6221 static const float epsilon = 1e-5;
6222 const T exponent = input2_data[0];
6223 const int int_exponent = static_cast<int>(std::round(exponent));
6224 if ((std::abs(input2_data[0] - int_exponent) < epsilon) &&
6225 (int_exponent >= 1)) {
6226 ArithmeticParams params;
6227 if (std::is_same<T, float>::value) {
6228 params.float_activation_max = std::numeric_limits<float>::max();
6229 params.float_activation_min = std::numeric_limits<float>::lowest();
6230 } else if (std::is_same<T, int>::value) {
6231 params.quantized_activation_max = std::numeric_limits<int>::max();
6232 params.quantized_activation_min = std::numeric_limits<int>::lowest();
6233 }
6234 IntegerExponentPow(params, unextended_input1_shape, input1_data,
6235 int_exponent, unextended_output_shape, output_data);
6236 return;
6237 }
6238 }
6239 reference_ops::BroadcastPow4DSlow(unextended_input1_shape, input1_data,
6240 unextended_input2_shape, input2_data,
6241 unextended_output_shape, output_data);
6242 }
6243
6244 #ifdef USE_NEON
6245
6246 inline void ScaleWithNewZeroPoint(const int32x4_t input,
6247 const float32x4_t scale_dup,
6248 const float32x4_t zero_times_scale_dup,
6249 float32x4_t* output) {
6250 #ifdef __ARM_FEATURE_FMA
6251 *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
6252 #else
6253 *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
6254 zero_times_scale_dup);
6255 #endif
6256 }
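// Per lane this computes
//   *output = float(input) * scale + (-zero_point * scale)
//           = scale * (input - zero_point),
// i.e. the usual dequantization formula, with the zero-point term folded into
// a precomputed addend (zero_times_scale_dup) so that a single FMA suffices
// when __ARM_FEATURE_FMA is available.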
6257
6258 #endif // USE_NEON
6259
6260 inline void Dequantize(const tflite::DequantizationParams& op_params,
6261 const RuntimeShape& input_shape,
6262 const uint8_t* input_data,
6263 const RuntimeShape& output_shape, float* output_data) {
6264 ruy::profiler::ScopeLabel label("Dequantize/Uint8");
6265 const int32 zero_point = op_params.zero_point;
6266 const double scale = op_params.scale;
6267 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6268
6269 int i = 0;
6270 #ifdef USE_NEON
6271 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6272 const float32x4_t zero_times_scale_dup =
6273 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6274 for (; i <= flat_size - 8; i += 8) {
6275 const uint8x8_t input_u8 = vld1_u8(input_data + i);
6276 const uint16x8_t input_u16 = vmovl_u8(input_u8);
6277 const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
6278 const int16x4_t input_s16_low = vget_low_s16(input_s16);
6279 const int16x4_t input_s16_high = vget_high_s16(input_s16);
6280 const int32x4_t val_low = vmovl_s16(input_s16_low);
6281 const int32x4_t val_high = vmovl_s16(input_s16_high);
6282
6283 float32x4_t result_low, result_high;
6284 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6285 &result_low);
6286 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6287 &result_high);
6288
6289 vst1q_f32(output_data + i, result_low);
6290 vst1q_f32(output_data + i + 4, result_high);
6291 }
6292 #endif // NEON
6293 for (; i < flat_size; ++i) {
6294 const int32 val = input_data[i];
6295 const float result = static_cast<float>(scale * (val - zero_point));
6296 output_data[i] = result;
6297 }
6298 }
6299
6300 inline void Dequantize(const tflite::DequantizationParams& op_params,
6301 const RuntimeShape& input_shape,
6302 const int8_t* input_data,
6303 const RuntimeShape& output_shape, float* output_data) {
6304 ruy::profiler::ScopeLabel label("Dequantize/Int8");
6305 const int32 zero_point = op_params.zero_point;
6306 const double scale = op_params.scale;
6307 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6308
6309 int i = 0;
6310 #ifdef USE_NEON
6311 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6312 const float32x4_t zero_times_scale_dup =
6313 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6314 for (; i <= flat_size - 8; i += 8) {
6315 const int8x8_t input_s8 = vld1_s8(input_data + i);
6316 const int16x8_t input_s16 = vmovl_s8(input_s8);
6317 const int16x4_t input_s16_low = vget_low_s16(input_s16);
6318 const int16x4_t input_s16_high = vget_high_s16(input_s16);
6319 const int32x4_t val_low = vmovl_s16(input_s16_low);
6320 const int32x4_t val_high = vmovl_s16(input_s16_high);
6321
6322 float32x4_t result_low, result_high;
6323 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6324 &result_low);
6325 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6326 &result_high);
6327
6328 vst1q_f32(output_data + i, result_low);
6329 vst1q_f32(output_data + i + 4, result_high);
6330 }
6331 #endif // NEON
6332 for (; i < flat_size; ++i) {
6333 const int32 val = input_data[i];
6334 const float result = static_cast<float>(scale * (val - zero_point));
6335 output_data[i] = result;
6336 }
6337 }
6338
6339 inline void Dequantize(const tflite::DequantizationParams& op_params,
6340 const RuntimeShape& input_shape,
6341 const int16_t* input_data,
6342 const RuntimeShape& output_shape, float* output_data) {
6343 ruy::profiler::ScopeLabel label("Dequantize/Int16");
6344 const int32 zero_point = op_params.zero_point;
6345 const double scale = op_params.scale;
6346 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6347
6348 int i = 0;
6349 #ifdef USE_NEON
6350 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6351 const float32x4_t zero_times_scale_dup =
6352 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6353 for (; i <= flat_size - 8; i += 8) {
6354 const int16x4_t input_s16_low = vld1_s16(input_data + i);
6355 const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
6356 const int32x4_t val_low = vmovl_s16(input_s16_low);
6357 const int32x4_t val_high = vmovl_s16(input_s16_high);
6358
6359 float32x4_t result_low, result_high;
6360 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6361 &result_low);
6362 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6363 &result_high);
6364
6365 vst1q_f32(output_data + i, result_low);
6366 vst1q_f32(output_data + i + 4, result_high);
6367 }
6368 #endif // NEON
6369 for (; i < flat_size; ++i) {
6370 const int32 val = input_data[i];
6371 const float result = static_cast<float>(scale * (val - zero_point));
6372 output_data[i] = result;
6373 }
6374 }
6375
6376 inline void Dequantize(const RuntimeShape& input_shape,
6377 const Eigen::half* input_data,
6378 const RuntimeShape& output_shape, float* output_data) {
6379 reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
6380 }
6381
6382 template <typename T>
6383 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6384 const RuntimeShape& input_shape,
6385 const float* input_data,
6386 const RuntimeShape& output_shape, T* output_data) {
6387 reference_ops::AffineQuantize(op_params, input_shape, input_data,
6388 output_shape, output_data);
6389 }
6390
6391 template <>
6392 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6393 const RuntimeShape& input_shape,
6394 const float* input_data,
6395 const RuntimeShape& output_shape,
6396 int8_t* output_data) {
6397 ruy::profiler::ScopeLabel label("Quantize/Int8");
6398 const int32 zero_point = op_params.zero_point;
6399 const double scale = static_cast<double>(op_params.scale);
6400 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6401 static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
6402 static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
6403
6404 int i = 0;
6405 #ifdef USE_NEON
6406 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6407 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6408 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6409 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6410
6411 for (; i <= flat_size - 8; i += 8) {
6412 const float* src_data_ptr = input_data + i;
6413 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6414 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6415
6416 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6417 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6418
6419 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6420 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6421
6422 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6423 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6424
6425 // Clamp the values to fit the target type's range.
6426 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6427 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6428 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6429 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6430
6431 const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6432 const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6433 const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
6434 const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
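// Note: the non-saturating narrows above (vmovn_s32 / vmovn_s16) are safe
// only because the values were already clamped to [min_val, max_val].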
6435 vst1_s8(output_data + i, combined_val_narrowed);
6436 }
6437 #endif // NEON
6438
6439 for (; i < flat_size; ++i) {
6440 const float val = input_data[i];
6441 const int32 unclamped =
6442 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6443 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6444 output_data[i] = clamped;
6445 }
6446 }
6447
6448 template <>
6449 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6450 const RuntimeShape& input_shape,
6451 const float* input_data,
6452 const RuntimeShape& output_shape,
6453 uint8_t* output_data) {
6454 ruy::profiler::ScopeLabel label("Quantize/Uint8");
6455 const int32 zero_point = op_params.zero_point;
6456 const double scale = static_cast<double>(op_params.scale);
6457 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6458 static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
6459 static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
6460
6461 int i = 0;
6462 #ifdef USE_NEON
6463 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6464 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6465 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6466 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6467
6468 for (; i <= flat_size - 8; i += 8) {
6469 const float* src_data_ptr = input_data + i;
6470 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6471 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6472
6473 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6474 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6475
6476 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6477 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6478
6479 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6480 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6481
6482 // Clamp the values to fit the target type's range.
6483 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6484 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6485 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6486 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6487
6488 const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
6489 const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
6490 const uint16x8_t combined_val =
6491 vcombine_u16(narrowed_val_0, narrowed_val_1);
6492 const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
6493 vst1_u8(output_data + i, combined_val_narrowed);
6494 }
6495 #endif // NEON
6496
6497 for (; i < flat_size; ++i) {
6498 const float val = input_data[i];
6499 const int32 unclamped =
6500 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6501 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6502 output_data[i] = clamped;
6503 }
6504 }
6505
6506 template <>
6507 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6508 const RuntimeShape& input_shape,
6509 const float* input_data,
6510 const RuntimeShape& output_shape,
6511 int16_t* output_data) {
6512 ruy::profiler::ScopeLabel label("Quantize/Int16");
6513 const int32 zero_point = op_params.zero_point;
6514 const double scale = static_cast<double>(op_params.scale);
6515 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6516 static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
6517 static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
6518
6519 int i = 0;
6520 #ifdef USE_NEON
6521 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6522 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6523 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6524 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6525
6526 for (; i <= flat_size - 8; i += 8) {
6527 const float* src_data_ptr = input_data + i;
6528 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6529 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6530
6531 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6532 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6533
6534 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6535 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6536
6537 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6538 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6539
6540 // Clamp the values to fit the target type's range.
6541 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6542 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6543 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6544 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6545
6546 const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6547 const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6548 vst1_s16(output_data + i, narrowed_val_0);
6549 vst1_s16(output_data + i + 4, narrowed_val_1);
6550 }
6551 #endif // NEON
6552
6553 for (; i < flat_size; ++i) {
6554 const float val = input_data[i];
6555 const int32 unclamped =
6556 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6557 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6558 output_data[i] = clamped;
6559 }
6560 }
6561
6562 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6563 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6564 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6565 #ifdef GEMMLOWP_NEON
6566
6567 inline int16x8x4_t SaturatingRounding(
6568 int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2,
6569 int16x8_t input_val_3, int input_left_shift, int input_multiplier) {
6570 // This performs what is expressed in the scalar code as
6571 // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6572 // static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6573 // static_cast<int16>(input_multiplier));
6574 const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift);
6575 const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup);
6576 const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup);
6577 const int16x8_t input_val_shifted_2 = vshlq_s16(input_val_2, left_shift_dup);
6578 const int16x8_t input_val_shifted_3 = vshlq_s16(input_val_3, left_shift_dup);
6579 int16x8x4_t result;
6580 result.val[0] = vqrdmulhq_n_s16(input_val_shifted_0, input_multiplier);
6581 result.val[1] = vqrdmulhq_n_s16(input_val_shifted_1, input_multiplier);
6582 result.val[2] = vqrdmulhq_n_s16(input_val_shifted_2, input_multiplier);
6583 result.val[3] = vqrdmulhq_n_s16(input_val_shifted_3, input_multiplier);
6584 return result;
6585 }
6586
6587 // 4-bit fixed point is enough here since the nonlinearity saturates quickly:
6588 // logistic(16) is within about 1e-7 of 1, i.e. about 7 decimal places.
6589 inline int16x8x4_t FixedPoint4Logistic(int16x8x4_t input_val) {
6590 // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
6591 using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6592 using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6593 const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6594 const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6595 const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6596 const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6597
6598 // TODO(b/134622898) Implement a low-accuracy version of logistic. In this
6599 // method, gemmlowp::logistic accounts for about 80% of the execution time.
6600 // The current implementation is roughly 12-bit accurate in the 16-bit fixed
6601 // point case, so there is still room for improvement before the error bounds
6602 // are reached.
6603 const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
6604 const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
6605 const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
6606 const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
6607
6608 // Divide by 2^7 as in the scalar code
6609 int16x8x4_t result;
6610 result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 7);
6611 result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 7);
6612 result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 7);
6613 result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 7);
6614 return result;
6615 }
6616
6617 // 4-bit fixed point is enough for tanh since tanh(16) differs from 1 by less
6618 // than 1e-11, i.e. they agree to at least 11 decimal places.
6619 inline int16x8x4_t FixedPoint4Tanh(int16x8x4_t input_val) {
6620 // Invoke gemmlowp::tanh on FixedPoint wrapping int16x8_t
6621 using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6622 using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6623 const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6624 const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6625 const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6626 const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6627
6628 // TODO(b/134622898) Implement a low-accuracy version of tanh. In this
6629 // method, gemmlowp::tanh accounts for about 80% of the execution time.
6630 // The current implementation is roughly 12-bit accurate in the 16-bit fixed
6631 // point case, so there is still room for improvement before the error bounds
6632 // are reached.
6633 const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
6634 const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
6635 const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
6636 const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
6637
6638 // Divide by 2^8 as in the scalar code
6639 int16x8x4_t result;
6640 result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 8);
6641 result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 8);
6642 result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 8);
6643 result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 8);
6644 return result;
6645 }
6646
6647 inline uint8x16x2_t CalculateUnsignedClampingWithRangeBitMasks(
6648 int16x8x2_t input_val, int16x8_t range_radius_dup,
6649 int16x8_t neg_range_radius_dup) {
6650 const uint16x8_t mask_rightclamp_0 =
6651 vcgtq_s16(input_val.val[0], range_radius_dup);
6652 const uint16x8_t mask_rightclamp_1 =
6653 vcgtq_s16(input_val.val[1], range_radius_dup);
6654
6655 const uint16x8_t mask_leftclamp_0 =
6656 vcgeq_s16(input_val.val[0], neg_range_radius_dup);
6657 const uint16x8_t mask_leftclamp_1 =
6658 vcgeq_s16(input_val.val[1], neg_range_radius_dup);
6659
6660 uint8x16x2_t result;
6661 result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6662 vshrn_n_u16(mask_leftclamp_1, 8));
6663 result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6664 vshrn_n_u16(mask_rightclamp_1, 8));
6665 return result;
6666 }
6667
6668 inline uint8x16x2_t CalculateSignedClampingWithRangeBitMasks(
6669 int16x8x2_t input_val, int16x8_t range_radius_dup,
6670 int16x8_t neg_range_radius_dup) {
6671 const uint16x8_t mask_rightclamp_0 =
6672 vcgtq_s16(input_val.val[0], range_radius_dup);
6673 const uint16x8_t mask_rightclamp_1 =
6674 vcgtq_s16(input_val.val[1], range_radius_dup);
6675
6676 const uint16x8_t mask_leftclamp_0 =
6677 vcltq_s16(input_val.val[0], neg_range_radius_dup);
6678 const uint16x8_t mask_leftclamp_1 =
6679 vcltq_s16(input_val.val[1], neg_range_radius_dup);
6680
6681 uint8x16x2_t result;
6682 result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6683 vshrn_n_u16(mask_leftclamp_1, 8));
6684 result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6685 vshrn_n_u16(mask_rightclamp_1, 8));
6686 return result;
6687 }
6688
6689 inline void ClampWithRangeAndStore(uint8_t* output_dst, uint8x16_t input_val,
6690 uint8x16x2_t masks_clamp) {
6691 // Store back to memory
6692 vst1q_u8(output_dst, vandq_u8(vorrq_u8(input_val, masks_clamp.val[1]),
6693 masks_clamp.val[0]));
6694 }
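// How the masks work (as produced by CalculateUnsignedClampingWithRangeBitMasks
// above): each lane of masks_clamp.val[1] is 0xFF where the centered input
// exceeded input_range_radius, so the OR forces those lanes to 255; each lane
// of masks_clamp.val[0] is 0xFF where the input was >= -input_range_radius and
// 0x00 below it, so the AND forces those latter lanes to 0. This reproduces,
// branch-free, the scalar clamping shown in the callers' leftover loops.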
6695
6696 inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val,
6697 uint8x16x2_t masks_clamp) {
6698 static const int8x16_t max_dup = vdupq_n_s8(127);
6699 static const int8x16_t min_dup = vdupq_n_s8(-128);
6700 // Store back to memory
6701 vst1q_s8(output_dst,
6702 vbslq_s8(masks_clamp.val[1], max_dup,
6703 vbslq_s8(masks_clamp.val[0], min_dup, input_val)));
6704 }
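// For the signed case the masks come from
// CalculateSignedClampingWithRangeBitMasks, so vbslq_s8 selects, per lane,
// 127 where the input exceeded input_range_radius, else -128 where it was
// below -input_range_radius, else the computed value.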
6705
6706 #endif // GEMMLOWP_NEON
6707
6708 inline void Tanh16bitPrecision(const TanhParams& params,
6709 const RuntimeShape& input_shape,
6710 const uint8* input_data,
6711 const RuntimeShape& output_shape,
6712 uint8* output_data) {
6713 // Note that this is almost the exact same code as in Logistic().
6714 ruy::profiler::ScopeLabel label("Tanh/Uint8");
6715 const int32 input_zero_point = params.input_zero_point;
6716 const int32 input_range_radius = params.input_range_radius;
6717 const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6718 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6719 const int size = MatchingFlatSize(input_shape, output_shape);
6720
6721 int c = 0;
6722 int16_t output_zero_point = 128;
6723
6724 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6725 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6726 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6727 #ifdef GEMMLOWP_NEON
6728 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6729 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6730 const int16x8_t output_zero_point_s16 = vdupq_n_s16(output_zero_point);
6731
6732 // Handle 32 values at a time
6733 for (; c <= size - 32; c += 32) {
6734 // Read input uint8 values, cast to int16 and subtract input_zero_point
6735 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6736 const int16x8x2_t input_val_centered_0_1 =
6737 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6738 const int16x8x2_t input_val_centered_2_3 =
6739 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6740
6741 // Prepare the bit masks that we will use at the end to implement the logic
6742 // that was expressed in the scalar code with branching:
6743 // if (input_val_centered < -input_range_radius) {
6744 // output_val = 0;
6745 // } else if (input_val_centered > input_range_radius) {
6746 // output_val = 255;
6747 // } else {
6748 // ...
6749 uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6750 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6751 uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6752 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6753
6754 int16x8x4_t input_val_rescaled = SaturatingRounding(
6755 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6756 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6757 input_left_shift, input_multiplier);
6758
6759 int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6760
6761 // Add the output zero point
6762 output_val_s16.val[0] =
6763 vaddq_s16(output_val_s16.val[0], output_zero_point_s16);
6764 output_val_s16.val[1] =
6765 vaddq_s16(output_val_s16.val[1], output_zero_point_s16);
6766 output_val_s16.val[2] =
6767 vaddq_s16(output_val_s16.val[2], output_zero_point_s16);
6768 output_val_s16.val[3] =
6769 vaddq_s16(output_val_s16.val[3], output_zero_point_s16);
6770
6771 // Cast output values to uint8, saturating
6772 uint8x16_t output_val_u8_0_1 = vcombine_u8(
6773 vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
6774 uint8x16_t output_val_u8_2_3 = vcombine_u8(
6775 vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
6776
6777 ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
6778 ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
6779 masks_clamp_2_3);
6780 }
6781 #endif // GEMMLOWP_NEON
6782 // Leftover loop: handle one value at a time with scalar code.
6783 for (; c < size; ++c) {
6784 const uint8 input_val_u8 = input_data[c];
6785 const int16 input_val_centered =
6786 static_cast<int16>(input_val_u8) - input_zero_point;
6787 uint8 output_val;
6788 if (input_val_centered < -input_range_radius) {
6789 output_val = 0;
6790 } else if (input_val_centered > input_range_radius) {
6791 output_val = 255;
6792 } else {
6793 using gemmlowp::SaturatingRoundingDoublingHighMul;
6794 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6795 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6796 static_cast<int16>(input_multiplier));
6797 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6798 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6799 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6800 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6801 using gemmlowp::RoundingDivideByPOT;
6802 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6803 output_val_s16 += output_zero_point;
6804 if (output_val_s16 == 256) {
6805 output_val_s16 = 255;
6806 }
6807 TFLITE_DCHECK_GE(output_val_s16, 0);
6808 TFLITE_DCHECK_LE(output_val_s16, 255);
6809 output_val = static_cast<uint8>(output_val_s16);
6810 }
6811 output_data[c] = output_val;
6812 }
6813 }
6814
6815 inline void Tanh16bitPrecision(const TanhParams& params,
6816 const RuntimeShape& input_shape,
6817 const int8* input_data,
6818 const RuntimeShape& output_shape,
6819 int8* output_data) {
6820 // Note that this is almost the exact same code as in Logistic().
6821 ruy::profiler::ScopeLabel label("Tanh/Int8");
6822 const int32 input_zero_point = params.input_zero_point;
6823 const int32 input_range_radius = params.input_range_radius;
6824 const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6825 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6826 const int size = MatchingFlatSize(input_shape, output_shape);
6827
6828 int c = 0;
6829 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6830 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6831 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6832 #ifdef GEMMLOWP_NEON
6833 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6834 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6835
6836 // Handle 32 values at a time
6837 for (; c <= size - 32; c += 32) {
6838 // Read input int8 values, cast to int16 and subtract input_zero_point
6839 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6840 const int16x8x2_t input_val_centered_0_1 =
6841 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6842 const int16x8x2_t input_val_centered_2_3 =
6843 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6844
6845 // Prepare the bit masks that we will use at the end to implement the logic
6846 // that was expressed in the scalar code with branching:
6847 // if (input_val_centered < -input_range_radius) {
6848 // output_val = -128;
6849 // } else if (input_val_centered > input_range_radius) {
6850 // output_val = 127;
6851 // } else {
6852 // ...
6853 uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
6854 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6855 uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
6856 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6857
6858 int16x8x4_t input_val_rescaled = SaturatingRounding(
6859 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6860 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6861 input_left_shift, input_multiplier);
6862
6863 int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6864
6865 // Cast output values to int8, saturating
6866 int8x16_t output_val_s8_0_1 = vcombine_s8(
6867 vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
6868 int8x16_t output_val_s8_2_3 = vcombine_s8(
6869 vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
6870
6871 ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
6872 ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
6873 masks_clamp_2_3);
6874 }
6875 #endif // GEMMLOWP_NEON
6876 // Leftover loop: handle one value at a time with scalar code.
6877 for (; c < size; ++c) {
6878 const int8 input_val_s8 = input_data[c];
6879 const int16 input_val_centered =
6880 static_cast<int16>(input_val_s8) - input_zero_point;
6881 int8 output_val;
6882 if (input_val_centered <= -input_range_radius) {
6883 output_val = -128;
6884 } else if (input_val_centered >= input_range_radius) {
6885 output_val = 127;
6886 } else {
6887 using gemmlowp::SaturatingRoundingDoublingHighMul;
6888 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6889 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6890 static_cast<int16>(input_multiplier));
6891 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6892 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6893 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6894 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6895 using gemmlowp::RoundingDivideByPOT;
6896 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6897 if (output_val_s16 == 128) {
6898 output_val_s16 = 127;
6899 }
6900 TFLITE_DCHECK_GE(output_val_s16, -128);
6901 TFLITE_DCHECK_LE(output_val_s16, 127);
6902 output_val = static_cast<int8>(output_val_s16);
6903 }
6904 output_data[c] = output_val;
6905 }
6906 }
6907
6908 inline void Logistic16bitPrecision(const LogisticParams& params,
6909 const RuntimeShape& input_shape,
6910 const uint8* input_data,
6911 const RuntimeShape& output_shape,
6912 uint8* output_data) {
6913 ruy::profiler::ScopeLabel label("Logistic/Uint8");
6914 const int32 input_zero_point = params.input_zero_point;
6915 const int32 input_range_radius = params.input_range_radius;
6916 const int32 input_multiplier = params.input_multiplier;
6917 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6918 const int size = MatchingFlatSize(input_shape, output_shape);
6919
6920 int c = 0;
6921 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6922 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6923 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6924 #ifdef GEMMLOWP_NEON
6925 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6926 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6927
6928 // Handle 32 values at a time
6929 for (; c <= size - 32; c += 32) {
6930 // Read input uint8 values, cast to int16 and subtract input_zero_point
6931 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6932 const int16x8x2_t input_val_centered_0_1 =
6933 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6934 const int16x8x2_t input_val_centered_2_3 =
6935 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6936
6937 // Prepare the bit masks that we will use at the end to implement the logic
6938 // that was expressed in the scalar code with branching:
6939 // if (input_val_centered < -input_range_radius) {
6940 // output_val = 0;
6941 // } else if (input_val_centered > input_range_radius) {
6942 // output_val = 255;
6943 // } else {
6944 // ...
6945 uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6946 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6947 uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6948 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6949
6950 int16x8x4_t input_val_rescaled = SaturatingRounding(
6951 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6952 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6953 input_left_shift, input_multiplier);
6954
6955 int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
6956
6957 // Cast output values to uint8, saturating
6958 uint8x16_t output_val_u8_0_1 = vcombine_u8(
6959 vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
6960 uint8x16_t output_val_u8_2_3 = vcombine_u8(
6961 vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
6962
6963 ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
6964 ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
6965 masks_clamp_2_3);
6966 }
6967 #endif // GEMMLOWP_NEON
6968 // Leftover loop: handle one value at a time with scalar code.
6969 for (; c < size; ++c) {
6970 const uint8 input_val_u8 = input_data[c];
6971 const int16 input_val_centered =
6972 static_cast<int16>(input_val_u8) - input_zero_point;
6973 uint8 output_val;
6974 if (input_val_centered < -input_range_radius) {
6975 output_val = 0;
6976 } else if (input_val_centered > input_range_radius) {
6977 output_val = 255;
6978 } else {
6979 using gemmlowp::SaturatingRoundingDoublingHighMul;
6980 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6981 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6982 static_cast<int16>(input_multiplier));
6983 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6984 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6985 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6986 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
6987 using gemmlowp::RoundingDivideByPOT;
6988 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
6989 if (output_val_s16 == 256) {
6990 output_val_s16 = 255;
6991 }
6992 TFLITE_DCHECK_GE(output_val_s16, 0);
6993 TFLITE_DCHECK_LE(output_val_s16, 255);
6994 output_val = static_cast<uint8>(output_val_s16);
6995 }
6996 output_data[c] = output_val;
6997 }
6998 }
6999
7000 inline void Logistic16bitPrecision(const LogisticParams& params,
7001 const RuntimeShape& input_shape,
7002 const int8* input_data,
7003 const RuntimeShape& output_shape,
7004 int8* output_data) {
7005 ruy::profiler::ScopeLabel label("Logistic/Int8");
7006 const int32 input_zero_point = params.input_zero_point;
7007 const int32 input_range_radius = params.input_range_radius;
7008 const int32 input_multiplier = params.input_multiplier;
7009 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7010 const int size = MatchingFlatSize(input_shape, output_shape);
7011
7012 int c = 0;
7013 const int16 output_zero_point = 128;
7014 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7015 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7016 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7017 #ifdef GEMMLOWP_NEON
7018 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7019 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7020 const int16x8_t output_zero_point_dup = vdupq_n_s16(output_zero_point);
7021
7022 // Handle 32 values at a time
7023 for (; c <= size - 32; c += 32) {
7024 // Read input int8 values, cast to int16 and subtract input_zero_point
7025 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7026 const int16x8x2_t input_val_centered_0_1 =
7027 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7028 const int16x8x2_t input_val_centered_2_3 =
7029 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7030
7031 // Prepare the bit masks that we will use at the end to implement the logic
7032 // that was expressed in the scalar code with branching:
7033 // if (input_val_centered < -input_range_radius) {
7034 // output_val = -128;
7035 // } else if (input_val_centered > input_range_radius) {
7036 // output_val = 127;
7037 // } else {
7038 // ...
7039 uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7040 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7041 uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7042 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7043
7044 int16x8x4_t input_val_rescaled = SaturatingRounding(
7045 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7046 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7047 input_left_shift, input_multiplier);
7048
7049 int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7050
7051 // Subtract the output zero point.
7052 output_val_s16.val[0] =
7053 vsubq_s16(output_val_s16.val[0], output_zero_point_dup);
7054 output_val_s16.val[1] =
7055 vsubq_s16(output_val_s16.val[1], output_zero_point_dup);
7056 output_val_s16.val[2] =
7057 vsubq_s16(output_val_s16.val[2], output_zero_point_dup);
7058 output_val_s16.val[3] =
7059 vsubq_s16(output_val_s16.val[3], output_zero_point_dup);
7060
7061 // Cast output values to int8, saturating
7062 int8x16_t output_val_s8_0_1 = vcombine_s8(
7063 vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7064 int8x16_t output_val_s8_2_3 = vcombine_s8(
7065 vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7066
7067 ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7068 ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7069 masks_clamp_2_3);
7070 }
7071 #endif // GEMMLOWP_NEON
7072 // Leftover loop: handle one value at a time with scalar code.
7073 for (; c < size; ++c) {
7074 const int8 input_val_s8 = input_data[c];
7075 const int16 input_val_centered =
7076 static_cast<int16>(input_val_s8) - input_zero_point;
7077 int8 output_val;
7078 if (input_val_centered < -input_range_radius) {
7079 output_val = -128;
7080 } else if (input_val_centered > input_range_radius) {
7081 output_val = 127;
7082 } else {
7083 using gemmlowp::SaturatingRoundingDoublingHighMul;
7084 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7085 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7086 static_cast<int16>(input_multiplier));
7087 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7088 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7089 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7090 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7091 using gemmlowp::RoundingDivideByPOT;
7092 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7093 output_val_s16 -= output_zero_point;
7094 if (output_val_s16 == 128) {
7095 output_val_s16 = 127;
7096 }
7097 TFLITE_DCHECK_GE(output_val_s16, -128);
7098 TFLITE_DCHECK_LE(output_val_s16, 127);
7099 output_val = static_cast<int8>(output_val_s16);
7100 }
7101 output_data[c] = output_val;
7102 }
7103 }
7104
7105 // Transpose2D only deals with typical 2D matrix transpose ops.
7106 // Perform transpose by transposing 4x4 blocks of the input, proceeding from
7107 // left to right (down the rows) of the input, and then from top to bottom.
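// In scalar terms the whole op is simply
//   for (int i = 0; i < d0; ++i)
//     for (int j = 0; j < d1; ++j)
//       output_data[j * d0 + i] = input_data[i * d1 + j];
// the 4x4 blocking below only serves to keep the strided accesses reasonably
// cache- and prefetch-friendly.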
7108 template <typename T>
7109 inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
7110 const RuntimeShape& output_shape, T* output_data) {
7111 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7112 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7113
7114 const int d0 = input_shape.DimsData()[0];
7115 const int d1 = input_shape.DimsData()[1];
7116 const int kLines = 4;
7117 const int kSkipSize = (kLines - 1) * d1;
7118
7119 const T* input = input_data;
7120
7121 int i = 0;
7122 for (; i <= d0 - kLines; i += kLines) {
7123 T* output = output_data + i;
7124
7125 const T* input_ptr = input;
7126 optimized_ops_preload_l1_keep(input_ptr);
7127 input_ptr += d1;
7128 optimized_ops_preload_l1_keep(input_ptr);
7129 input_ptr += d1;
7130 optimized_ops_preload_l1_keep(input_ptr);
7131 input_ptr += d1;
7132 optimized_ops_preload_l1_keep(input_ptr);
7133
7134 int j = 0;
7135 for (; j <= d1 - kLines; j += kLines) {
7136 input_ptr = input;
7137 const T a00 = input_ptr[0];
7138 const T a01 = input_ptr[1];
7139 const T a02 = input_ptr[2];
7140 const T a03 = input_ptr[3];
7141 input_ptr += d1;
7142 const T a10 = input_ptr[0];
7143 const T a11 = input_ptr[1];
7144 const T a12 = input_ptr[2];
7145 const T a13 = input_ptr[3];
7146 input_ptr += d1;
7147 const T a20 = input_ptr[0];
7148 const T a21 = input_ptr[1];
7149 const T a22 = input_ptr[2];
7150 const T a23 = input_ptr[3];
7151 input_ptr += d1;
7152 const T a30 = input_ptr[0];
7153 const T a31 = input_ptr[1];
7154 const T a32 = input_ptr[2];
7155 const T a33 = input_ptr[3];
7156
7157 output[0] = a00;
7158 output[1] = a10;
7159 output[2] = a20;
7160 output[3] = a30;
7161 output += d0;
7162
7163 output[0] = a01;
7164 output[1] = a11;
7165 output[2] = a21;
7166 output[3] = a31;
7167 output += d0;
7168
7169 output[0] = a02;
7170 output[1] = a12;
7171 output[2] = a22;
7172 output[3] = a32;
7173 output += d0;
7174
7175 output[0] = a03;
7176 output[1] = a13;
7177 output[2] = a23;
7178 output[3] = a33;
7179 output += d0;
7180
7181 input += kLines;
7182 }
7183 if (j == d1) {
7184 input += kSkipSize;
7185 } else {
7186 for (int p = 0; p < kLines; ++p) {
7187 for (int q = 0; q < d1 - j; ++q) {
7188 *(output + q * d0 + p) = *(input + p * d1 + q);
7189 }
7190 }
7191 input += (d1 - j) + kSkipSize;
7192 }
7193 }
7194 for (; i < d0; ++i) {
7195 T* output = output_data + i;
7196 for (int j = 0; j < d1; ++j) {
7197 *output = *input;
7198 output += d0;
7199 ++input;
7200 }
7201 }
7202 }
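
// Illustrative usage sketch (not part of the original header); the buffers and
// values below are hypothetical example data for a row-major 2x3 -> 3x2
// transpose:
//
//   const float input[6] = {1, 2, 3,
//                           4, 5, 6};
//   float output[6];
//   Transpose2D(RuntimeShape({2, 3}), input, RuntimeShape({3, 2}), output);
//   // output == {1, 4, 2, 5, 3, 6}, i.e. the 3x2 transpose of the input.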
7203
7204 template <>
7205 inline void Transpose2D(const RuntimeShape& input_shape,
7206 const int32_t* input_data,
7207 const RuntimeShape& output_shape,
7208 int32_t* output_data) {
7209 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7210 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7211
7212 const int d0 = input_shape.DimsData()[0];
7213 const int d1 = input_shape.DimsData()[1];
7214 #ifdef USE_NEON
7215 const int kLines = 4;
7216 const int kSkipSize = (kLines - 1) * d1;
7217 #endif
7218
7219 const int32_t* input = input_data;
7220
7221 int i = 0;
7222 #ifdef USE_NEON
7223 for (; i <= d0 - kLines; i += kLines) {
7224 int32_t* output = output_data + i;
7225
7226 const int32_t* input_ptr = input;
7227 optimized_ops_preload_l1_keep(input_ptr);
7228 input_ptr += d1;
7229 optimized_ops_preload_l1_keep(input_ptr);
7230 input_ptr += d1;
7231 optimized_ops_preload_l1_keep(input_ptr);
7232 input_ptr += d1;
7233 optimized_ops_preload_l1_keep(input_ptr);
7234
7235 int j = 0;
7236 for (; j <= d1 - kLines; j += kLines) {
7237 input_ptr = input;
7238 int32x4_t a0 = vld1q_s32(input);
7239 input_ptr += d1;
7240 int32x4_t a1 = vld1q_s32(input_ptr);
7241 input_ptr += d1;
7242 int32x4_t a2 = vld1q_s32(input_ptr);
7243 input_ptr += d1;
7244 int32x4_t a3 = vld1q_s32(input_ptr);
7245
7246 int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
7247 int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
7248 int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
7249 int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);
7250
7251 vst1q_s32(output, tmp3.val[0]);
7252 output += d0;
7253 vst1q_s32(output, tmp4.val[0]);
7254 output += d0;
7255 vst1q_s32(output, tmp3.val[1]);
7256 output += d0;
7257 vst1q_s32(output, tmp4.val[1]);
7258 output += d0;
7259 input += kLines;
7260 }
7261 if (j == d1) {
7262 input += kSkipSize;
7263 } else {
7264 for (int p = 0; p < kLines; ++p) {
7265 for (int q = 0; q < d1 - j; ++q) {
7266 *(output + q * d0 + p) = *(input + p * d1 + q);
7267 }
7268 }
7269 input += (d1 - j) + kSkipSize;
7270 }
7271 }
7272 #endif
7273 for (; i < d0; ++i) {
7274 int32_t* output = output_data + i;
7275 for (int j = 0; j < d1; ++j) {
7276 *output = *input;
7277 output += d0;
7278 ++input;
7279 }
7280 }
7281 }
7282
7283 // TODO(b/173718660): see if we can reduce the number
7284 // of lines of code in branching without affecting latency.
7285 template <typename T>
7286 inline void Transpose3D(const TransposeParams& params,
7287 const RuntimeShape& input_shape, const T* input_data,
7288 const RuntimeShape& output_shape, T* output_data) {
7289 int s1, s2, s3;
7290 s1 = input_shape.Dims(0);
7291 s2 = input_shape.Dims(1);
7292 s3 = input_shape.Dims(2);
7293
7294 int p1, p2, p3;
7295 if (params.perm[0] == 2) {
7296 p1 = 1;
7297 } else if (params.perm[1] == 2) {
7298 p2 = 1;
7299 } else {
7300 p3 = 1;
7301 }
7302
7303 if (params.perm[0] == 1) {
7304 p1 = s3;
7305 } else if (params.perm[1] == 1) {
7306 p2 = s3;
7307 } else {
7308 p3 = s3;
7309 }
7310
7311 if (params.perm[0] == 0) {
7312 p1 = s2 * s3;
7313 } else if (params.perm[1] == 0) {
7314 p2 = s2 * s3;
7315 } else {
7316 p3 = s2 * s3;
7317 }
7318
7319 int o_s[3];
7320 o_s[0] = input_shape.Dims(params.perm[0]);
7321 o_s[1] = input_shape.Dims(params.perm[1]);
7322 o_s[2] = input_shape.Dims(params.perm[2]);
7323
7324 for (int i1 = 0; i1 < o_s[0]; ++i1) {
7325 for (int i2 = 0; i2 < o_s[1]; ++i2) {
7326 for (int i3 = 0; i3 < o_s[2]; ++i3) {
7327 const int i = i1 * p1 + i2 * p2 + i3 * p3;
7328 const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
7329 output_data[o] = input_data[i];
7330 }
7331 }
7332 }
7333 }
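
// Illustrative example of the stride bookkeeping above (hypothetical shapes):
// with input shape (s1, s2, s3) = (2, 3, 4) and perm = {2, 0, 1}, the output
// shape is (4, 2, 3) and the input strides become p1 = 1, p2 = s2 * s3 = 12,
// p3 = s3 = 4, so output[i1][i2][i3] reads input[i2][i3][i1].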
7334
7335 template <typename T, int N>
7336 void TransposeImpl(const TransposeParams& params,
7337 const RuntimeShape& input_shape, const T* input_data,
7338 const RuntimeShape& output_shape, T* output_data) {
7339 const int dims_cnt = input_shape.DimensionsCount();
7340
7341 int dim0, dim1;
7342 if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
7343 &dim1)) {
7344 Transpose2D(RuntimeShape({dim0, dim1}), input_data,
7345 RuntimeShape({dim1, dim0}), output_data);
7346 return;
7347 }
7348
7349 // TODO(b/141217325): notably Eigen is better suited for
7350 // larger inputs whereas Transpose3D is generally
7351 // better for smaller ones.
7352 //
7353 // E.g. on Nexus 5, Eigen is better for size 96^3 and up
7354 // and Transpose3D is better for 72^3 and down.
7355 //
7356 // 96^3 is not mobile-friendly for certain use cases
7357 // (e.g. the model used in beam search for seq2seq) but is in others.
7358 // Consider tradeoffs.
7359 if (dims_cnt == 3) {
7360 Transpose3D(params, input_shape, input_data, output_shape, output_data);
7361 return;
7362 }
7363
7364 // Reroute to the reference version if an optimized method for the given data
7365 // is not available.
7366 reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
7367 output_data);
7368 }
7369
7370 template <typename T, int N = 5>
7371 void Transpose(const TransposeParams& unshrinked_params,
7372 const RuntimeShape& unshrinked_input_shape, const T* input_data,
7373 const RuntimeShape& unshrinked_output_shape, T* output_data) {
7374 ruy::profiler::ScopeLabel label("Transpose");
7375
7376 const int output_size = unshrinked_output_shape.DimensionsCount();
7377 TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
7378 TFLITE_DCHECK_LE(output_size, N);
7379 TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
7380
7381 RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
7382 RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
7383 TransposeParams shrinked_params = unshrinked_params;
7384
7385 // Remove dimensions of size one. A lower-rank transpose usually performs
7386 // better since memory access patterns are improved.
7387 transpose_utils::RemoveOneSizeDimensions(
7388 &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
7389
7390 // Handle identity cases.
7391 // TODO(b/140779653): Add an optimization pass in the conversion process to
7392 // remove transpose op nodes that do nothing, like the case handled below.
7393 bool identical = true;
7394 for (int i = 0; i < shrinked_params.perm_count; ++i) {
7395 if (shrinked_params.perm[i] != i) {
7396 identical = false;
7397 break;
7398 }
7399 }
7400 if (identical) {
7401 memcpy(output_data, input_data,
7402 unshrinked_input_shape.FlatSize() * sizeof(T));
7403 return;
7404 }
7405
7406 // Reduce dimensions by flattening.
7407 if (shrinked_params.perm[0] == 0 && output_size >= 3) {
7408 RuntimeShape non_flatten_input_shape;
7409 RuntimeShape non_flatten_output_shape;
7410 TransposeParams non_flatten_params;
7411 const int total_size = shrinked_input_shape.FlatSize();
7412 const int non_flatten_size = transpose_utils::Flatten(
7413 shrinked_input_shape, shrinked_output_shape, shrinked_params,
7414 &non_flatten_input_shape, &non_flatten_output_shape,
7415 &non_flatten_params);
7416 TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
7417
7418 for (int i = 0; i < total_size; i += non_flatten_size) {
7419 TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
7420 input_data + i, non_flatten_output_shape,
7421 output_data + i);
7422 }
7423 return;
7424 }
7425
7426 // Call non-flattened case.
7427 TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
7428 shrinked_output_shape, output_data);
7429 }
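
// Illustrative usage sketch (not part of the original header); shapes, buffers
// and the permutation below are hypothetical example values for swapping the
// first two axes of a {2, 3, 4} float tensor:
//
//   TransposeParams params;
//   params.perm_count = 3;
//   params.perm[0] = 1;
//   params.perm[1] = 0;
//   params.perm[2] = 2;
//   // `input` holds 24 floats in {2, 3, 4} layout; `output` holds 24 floats.
//   Transpose<float>(params, RuntimeShape({2, 3, 4}), input,
//                    RuntimeShape({3, 2, 4}), output);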
7430
7431 // Assume input1 & input2 have the same scale & zero point.
7432 inline void MaximumElementwise(int size, const ArithmeticParams& params,
7433 const int8* input1_data, const int8* input2_data,
7434 int8* output_data) {
7435 ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit");
7436 int i = 0;
7437 #ifdef USE_NEON
7438 for (; i <= size - 16; i += 16) {
7439 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7440 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7441 const int8x16_t max_data =
7442 vmaxq_s8(input1_val_original, input2_val_original);
7443 vst1q_s8(output_data + i, max_data);
7444 }
7445 #endif // USE_NEON
7446 for (; i < size; ++i) {
7447 const int8 input1_val = input1_data[i];
7448 const int8 input2_val = input2_data[i];
7449 output_data[i] = std::max(input1_val, input2_val);
7450 }
7451 }
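
// Illustrative usage sketch (hypothetical data; both inputs are assumed to
// share the same scale and zero point, as stated above):
//
//   const int8 a[4] = {-3, 5, 0, 7};
//   const int8 b[4] = { 2, 1, 0, -8};
//   int8 out[4];
//   ArithmeticParams params;  // quantization fields are unused by this kernel
//   MaximumElementwise(4, params, a, b, out);
//   // out == {2, 5, 0, 7}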
7452
7453 inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
7454 int8 input1_data, const int8* input2_data,
7455 int8* output_data) {
7456 ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit");
7457 int i = 0;
7458
7459 #ifdef USE_NEON
7460 const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7461 for (; i <= size - 16; i += 16) {
7462 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7463 const int8x16_t max_data =
7464 vmaxq_s8(input1_val_original, input2_val_original);
7465 vst1q_s8(output_data + i, max_data);
7466 }
7467 #endif // USE_NEON
7468 for (; i < size; ++i) {
7469 const int8 input2_val = input2_data[i];
7470 output_data[i] = std::max(input1_data, input2_val);
7471 }
7472 }
7473
7474 // Assume input1 & input2 have the same scale & zero point.
7475 inline void MinimumElementwise(int size, const ArithmeticParams& params,
7476 const int8* input1_data, const int8* input2_data,
7477 int8* output_data) {
7478 ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit");
7479 int i = 0;
7480 #ifdef USE_NEON
7481 for (; i <= size - 16; i += 16) {
7482 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7483 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7484 const int8x16_t min_data =
7485 vminq_s8(input1_val_original, input2_val_original);
7486 vst1q_s8(output_data + i, min_data);
7487 }
7488 #endif // USE_NEON
7489 for (; i < size; ++i) {
7490 const int8 input1_val = input1_data[i];
7491 const int8 input2_val = input2_data[i];
7492 output_data[i] = std::min(input1_val, input2_val);
7493 }
7494 }
7495
7496 inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
7497 int8 input1_data, const int8* input2_data,
7498 int8* output_data) {
7499 ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit");
7500 int i = 0;
7501
7502 #ifdef USE_NEON
7503 const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7504 for (; i <= size - 16; i += 16) {
7505 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7506 const int8x16_t min_data =
7507 vminq_s8(input1_val_original, input2_val_original);
7508 vst1q_s8(output_data + i, min_data);
7509 }
7510 #endif // USE_NEON
7511 for (; i < size; ++i) {
7512 const int8 input2_val = input2_data[i];
7513 output_data[i] = std::min(input1_data, input2_val);
7514 }
7515 }
7516
7517 template <typename Op>
7518 inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
7519 const RuntimeShape& input1_shape,
7520 const int8* input1_data,
7521 const RuntimeShape& input2_shape,
7522 const int8* input2_data,
7523 const RuntimeShape& output_shape,
7524 int8* output_data, Op op) {
7525 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7526 return reference_ops::MaximumMinimumBroadcastSlow(
7527 input1_shape, input1_data, input2_shape, input2_data, output_shape,
7528 output_data, op);
7529 }
7530
7531 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
7532 input2_data, output_shape, output_data,
7533 MaximumElementwise, MaximumScalarBroadcast);
7534 }
7535
7536 template <typename Op>
7537 inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
7538 const RuntimeShape& input1_shape,
7539 const int8* input1_data,
7540 const RuntimeShape& input2_shape,
7541 const int8* input2_data,
7542 const RuntimeShape& output_shape,
7543 int8* output_data, Op op) {
7544 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7545 return reference_ops::MaximumMinimumBroadcastSlow(
7546 input1_shape, input1_data, input2_shape, input2_data, output_shape,
7547 output_data, op);
7548 }
7549
7550 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
7551 input2_data, output_shape, output_data,
7552 MinimumElementwise, MinimumScalarBroadcast);
7553 }
7554
7555 template <typename T>
7556 void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
7557 bool exclusive, bool reverse, T* output_data) {
7558 Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};
7559
7560 for (int i = 0; i < axis; ++i) {
7561 dims[0] *= shape.Dims(i);
7562 }
7563 dims[1] = shape.Dims(axis);
7564 for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
7565 dims[2] *= shape.Dims(i);
7566 }
7567
7568 typedef Eigen::TensorMap<
7569 Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
7570 Eigen::Aligned>
7571 ConstTensor;
7572 typedef Eigen::TensorMap<
7573 Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
7574 Tensor;
7575 ConstTensor input(input_data, dims);
7576 Tensor output(output_data, dims);
7577
7578 if (reverse) {
7579 Eigen::array<bool, 3> reverse_idx = {false, true, false};
7580 output =
7581 input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
7582 } else {
7583 output = input.cumsum(1, exclusive);
7584 }
7585 }
7586
7587 template <typename T>
7588 void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
7589 bool exclusive, bool reverse, T* output_data) {
7590 const int dim = shape.DimensionsCount();
7591 TFLITE_DCHECK_GE(dim, 1);
7592 CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
7593 }
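
// Illustrative usage sketch (hypothetical data) showing the effect of the
// `exclusive` flag for a {2, 3} input summed along axis 1:
//
//   const float in[6] = {1, 2, 3, 4, 5, 6};
//   float out[6];
//   CumSum(in, RuntimeShape({2, 3}), /*axis=*/1,
//          /*exclusive=*/false, /*reverse=*/false, out);
//   // out == {1, 3, 6, 4, 9, 15}
//   CumSum(in, RuntimeShape({2, 3}), /*axis=*/1,
//          /*exclusive=*/true, /*reverse=*/false, out);
//   // out == {0, 1, 3, 0, 4, 9}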
7594
7595 inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
7596 float alpha, const float* input_data,
7597 float* output_data) {
7598 ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
7599 int i = 0;
7600
7601 #ifdef USE_NEON
7602 const float32x4_t zero_dup = vdupq_n_f32(0.0f);
7603 const float32x4_t alpha_dup = vdupq_n_f32(alpha);
7604 for (; i <= size - 16; i += 16) {
7605 const float32x4_t input1 = vld1q_f32(input_data + i);
7606 const float32x4_t input2 = vld1q_f32(input_data + i + 4);
7607 const float32x4_t input3 = vld1q_f32(input_data + i + 8);
7608 const float32x4_t input4 = vld1q_f32(input_data + i + 12);
7609
7610 const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
7611 const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
7612 const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
7613 const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);
7614
7615 const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
7616 const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
7617 const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
7618 const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
7619
7620 const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
7621 vst1q_f32(output_data + i, result1);
7622 const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
7623 vst1q_f32(output_data + i + 4, result2);
7624 const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
7625 vst1q_f32(output_data + i + 8, result3);
7626 const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
7627 vst1q_f32(output_data + i + 12, result4);
7628 }
7629
7630 for (; i <= size - 4; i += 4) {
7631 const float32x4_t input = vld1q_f32(input_data + i);
7632 const float32x4_t temp = vmulq_f32(input, alpha_dup);
7633 const uint32x4_t mask = vcgeq_f32(input, zero_dup);
7634 const float32x4_t result = vbslq_f32(mask, input, temp);
7635 vst1q_f32(output_data + i, result);
7636 }
7637 #endif // USE_NEON
7638 for (; i < size; ++i) {
7639 const float input = input_data[i];
7640 output_data[i] = input >= 0.f ? input : input * alpha;
7641 }
7642 }
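
// Illustrative usage sketch (hypothetical data): PRelu with a single broadcast
// alpha computes f(x) = x >= 0 ? x : alpha * x.
//
//   const float in[4] = {-1.0f, 0.0f, 2.5f, -4.0f};
//   float out[4];
//   ArithmeticParams params;  // unused by the float path
//   PReluScalarBroadcast(4, params, /*alpha=*/0.2f, in, out);
//   // out == {-0.2f, 0.0f, 2.5f, -0.8f}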
7643
7644 inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
7645 const float* alpha_data, const float* input_data,
7646 float* output_data) {
7647 ruy::profiler::ScopeLabel label("PreluElementWise/float");
7648
7649 int i = 0;
7650 #ifdef USE_NEON
7651 const float32x4_t zero_dup = vdupq_n_f32(0.0f);
7652 for (; i <= flat_size - 16; i += 16) {
7653 const float32x4_t input1 = vld1q_f32(input_data + i);
7654 const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
7655 const float32x4_t input2 = vld1q_f32(input_data + i + 4);
7656 const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
7657 const float32x4_t input3 = vld1q_f32(input_data + i + 8);
7658 const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
7659 const float32x4_t input4 = vld1q_f32(input_data + i + 12);
7660 const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);
7661
7662 const float32x4_t temp1 = vmulq_f32(input1, alpha1);
7663 const float32x4_t temp2 = vmulq_f32(input2, alpha2);
7664 const float32x4_t temp3 = vmulq_f32(input3, alpha3);
7665 const float32x4_t temp4 = vmulq_f32(input4, alpha4);
7666
7667 const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
7668 const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
7669 const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
7670 const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
7671
7672 const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
7673 vst1q_f32(output_data + i, result1);
7674 const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
7675 vst1q_f32(output_data + i + 4, result2);
7676 const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
7677 vst1q_f32(output_data + i + 8, result3);
7678 const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
7679 vst1q_f32(output_data + i + 12, result4);
7680 }
7681
7682 for (; i <= flat_size - 4; i += 4) {
7683 const float32x4_t input = vld1q_f32(input_data + i);
7684 const float32x4_t alpha = vld1q_f32(alpha_data + i);
7685
7686 const float32x4_t temp = vmulq_f32(input, alpha);
7687 const uint32x4_t mask = vcgeq_f32(input, zero_dup);
7688 const float32x4_t result = vbslq_f32(mask, input, temp);
7689 vst1q_f32(output_data + i, result);
7690 }
7691 #endif // USE_NEON
7692 for (; i < flat_size; ++i) {
7693 const float input = input_data[i];
7694 const float alpha = alpha_data[i];
7695 output_data[i] = input >= 0.f ? input : input * alpha;
7696 }
7697 }
7698
7699 inline void BroadcastPReluDispatch(
7700 const ArithmeticParams& params, const RuntimeShape& input_shape,
7701 const float* input_data, const RuntimeShape& alpha_shape,
7702 const float* alpha_data, const RuntimeShape& output_shape,
7703 float* output_data, float (*func)(float, float)) {
7704 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7705 return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
7706 input_shape, input_data, alpha_shape, alpha_data, output_shape,
7707 output_data, func);
7708 }
7709
7710 BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
7711 alpha_data, output_shape, output_data,
7712 PReluElementWise, PReluScalarBroadcast);
7713 }
7714
7715 // Returns the index with minimum value within `input_data`.
7716 // If there is a tie, returns the smaller index.
7717 template <typename T>
7718 inline int ArgMinVector(const T* input_data, int size) {
7719 T min_value = input_data[0];
7720 int min_index = 0;
7721 for (int i = 1; i < size; ++i) {
7722 const T curr_value = input_data[i];
7723 if (curr_value < min_value) {
7724 min_value = curr_value;
7725 min_index = i;
7726 }
7727 }
7728 return min_index;
7729 }
7730
7731 // Returns the index with maximum value within `input_data`.
7732 // If there is a tie, returns the smaller index.
7733 template <typename T>
7734 inline int ArgMaxVector(const T* input_data, int size) {
7735 T max_value = input_data[0];
7736 int max_index = 0;
7737 for (int i = 1; i < size; ++i) {
7738 const T curr_value = input_data[i];
7739 if (curr_value > max_value) {
7740 max_value = curr_value;
7741 max_index = i;
7742 }
7743 }
7744 return max_index;
7745 }
7746
7747 template <>
7748 inline int ArgMinVector(const float* input_data, int size) {
7749 int32_t min_index = 0;
7750 float min_value = input_data[0];
7751 int32_t i = 1;
7752 #ifdef USE_NEON
7753 if (size >= 4) {
7754 float32x4_t min_value_f32x4 = vld1q_f32(input_data);
7755 const int32_t index_init[4] = {0, 1, 2, 3};
7756 int32x4_t min_index_s32x4 = vld1q_s32(index_init);
7757 int32x4_t index_s32x4 = min_index_s32x4;
7758 int32x4_t inc = vdupq_n_s32(4);
7759 for (i = 4; i <= size - 4; i += 4) {
7760 // Increase indices by 4.
7761 index_s32x4 = vaddq_s32(index_s32x4, inc);
7762 float32x4_t v = vld1q_f32(&input_data[i]);
7763 uint32x4_t mask = vcltq_f32(v, min_value_f32x4);
7764 min_value_f32x4 = vminq_f32(min_value_f32x4, v);
7765 min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4);
7766 }
7767 // Find min element within float32x4_t.
7768 #ifdef __aarch64__
7769 min_value = vminvq_f32(min_value_f32x4);
7770 #else
7771 float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4),
7772 vget_high_f32(min_value_f32x4));
7773 min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2);
7774 min_value = vget_lane_f32(min_value_f32x2, 0);
7775 #endif // __aarch64__
7776 // Mask indices of non-min values with max int32_t.
7777 float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value);
7778 uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4);
7779 int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
7780 min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set);
7781 // Find min index of min values.
7782 #ifdef __aarch64__
7783 min_index = vminvq_s32(min_index_s32x4);
7784 #else
7785 int32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4),
7786 vget_high_s32(min_index_s32x4));
7787 min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2);
7788 min_index = vget_lane_s32(min_index_s32x2, 0);
7789 #endif // __aarch64__
7790 }
7791 #endif // USE_NEON
7792 // Leftover loop.
7793 for (; i < size; ++i) {
7794 const float curr_value = input_data[i];
7795 if (curr_value < min_value) {
7796 min_value = curr_value;
7797 min_index = i;
7798 }
7799 }
7800 return min_index;
7801 }
7802
7803 template <>
7804 inline int ArgMaxVector(const float* input_data, int size) {
7805 int32_t max_index = 0;
7806 float max_value = input_data[0];
7807 int32_t i = 1;
7808 #ifdef USE_NEON
7809 if (size >= 4) {
7810 float32x4_t max_value_f32x4 = vld1q_f32(input_data);
7811 const int32_t index_init[4] = {0, 1, 2, 3};
7812 int32x4_t max_index_s32x4 = vld1q_s32(index_init);
7813 int32x4_t index_s32x4 = max_index_s32x4;
7814 int32x4_t inc = vdupq_n_s32(4);
7815 for (i = 4; i <= size - 4; i += 4) {
7816 // Increase indices by 4.
7817 index_s32x4 = vaddq_s32(index_s32x4, inc);
7818 float32x4_t v = vld1q_f32(&input_data[i]);
7819 uint32x4_t mask = vcgtq_f32(v, max_value_f32x4);
7820 max_value_f32x4 = vmaxq_f32(max_value_f32x4, v);
7821 max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4);
7822 }
7823 // Find max element within float32x4_t.
7824 #ifdef __aarch64__
7825 max_value = vmaxvq_f32(max_value_f32x4);
7826 #else
7827 float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4),
7828 vget_high_f32(max_value_f32x4));
7829 max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2);
7830 max_value = vget_lane_f32(max_value_f32x2, 0);
7831 #endif // __aarch64__
7832 // Mask indices of non-max values with max int32_t.
7833 float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value);
7834 uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4);
7835 int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
7836 max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set);
7837 // Find min index of max values.
7838 #ifdef __aarch64__
7839 max_index = vminvq_s32(max_index_s32x4);
7840 #else
7841 int32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4),
7842 vget_high_s32(max_index_s32x4));
7843 max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2);
7844 max_index = vget_lane_s32(max_index_s32x2, 0);
7845 #endif // __aarch64__
7846 }
7847 #endif // USE_NEON
7848 // Leftover loop.
7849 for (; i < size; ++i) {
7850 const float curr_value = input_data[i];
7851 if (curr_value > max_value) {
7852 max_value = curr_value;
7853 max_index = i;
7854 }
7855 }
7856 return max_index;
7857 }
7858
7859 template <>
7860 inline int ArgMaxVector(const int8_t* input_data, int size) {
7861 int32_t max_index = 0;
7862 int8_t max_value = input_data[0];
7863 int32_t i = 0;
7864 #ifdef USE_NEON
7865 constexpr int VECTOR_SIZE = 16;
7866 if (size >= VECTOR_SIZE) {
7867 int8x16_t max_value_s8x16;
7868 for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
7869 max_value_s8x16 = vld1q_s8(input_data + i);
7870 int8_t max_from_vec;
7871 #ifdef __aarch64__
7872 max_from_vec = vmaxvq_s8(max_value_s8x16);
7873 #else // 32 bit
7874 int8x8_t max_val_s8x8 =
7875 vpmax_s8(vget_low_s8(max_value_s8x16), vget_high_s8(max_value_s8x16));
7876 max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7877 max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7878 max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7879 max_from_vec = vget_lane_s8(max_val_s8x8, 0);
7880 #endif // __aarch64__
7881 if (max_from_vec > max_value) {
7882 max_value = max_from_vec;
7883 max_index = i;
7884 }
7885 }
7886 }
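  // The vectorized loop above only records the starting index of the winning
  // 16-lane block; the scan below locates the exact offset of max_value within
  // that block, keeping the smallest matching index.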
7887 for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
7888 start_idx++) {
7889 if (input_data[start_idx] == max_value) {
7890 max_index = start_idx;
7891 break;
7892 }
7893 }
7894
7895 #endif // USE_NEON
7896 // Leftover loop.
7897 for (; i < size; ++i) {
7898 const int8_t curr_value = input_data[i];
7899 if (curr_value > max_value) {
7900 max_value = curr_value;
7901 max_index = i;
7902 }
7903 }
7904
7905 return max_index;
7906 }
7907
7908 template <>
7909 inline int ArgMaxVector(const uint8_t* input_data, int size) {
7910 int32_t max_index = 0;
7911 uint8_t max_value = input_data[0];
7912 int32_t i = 0;
7913 #ifdef USE_NEON
7914 constexpr int VECTOR_SIZE = 16;
7915 if (size >= VECTOR_SIZE) {
7916 uint8x16_t max_value_u8x16;
7917 for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
7918 max_value_u8x16 = vld1q_u8(input_data + i);
7919 uint8_t max_from_vec;
7920 #ifdef __aarch64__
7921 max_from_vec = vmaxvq_u8(max_value_u8x16);
7922 #else // 32 bit
7923 uint8x8_t max_val_u8x8 =
7924 vpmax_u8(vget_low_u8(max_value_u8x16), vget_high_u8(max_value_u8x16));
7925 max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7926 max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7927 max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7928 max_from_vec = vget_lane_u8(max_val_u8x8, 0);
7929 #endif // __aarch64__
7930 if (max_from_vec > max_value) {
7931 max_value = max_from_vec;
7932 max_index = i;
7933 }
7934 }
7935 }
7936 for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
7937 start_idx++) {
7938 if (input_data[start_idx] == max_value) {
7939 max_index = start_idx;
7940 break;
7941 }
7942 }
7943
7944 #endif // USE_NEON
7945 // Leftover loop.
7946 for (; i < size; ++i) {
7947 const uint8_t curr_value = input_data[i];
7948 if (curr_value > max_value) {
7949 max_value = curr_value;
7950 max_index = i;
7951 }
7952 }
7953
7954 return max_index;
7955 }
7956
7957 // Specializes the ArgMinMax function for axis == dims - 1.
7958 // In this case, the ArgMinMax reduction is applied over contiguous memory.
7959 template <typename T1, typename T2, bool is_arg_max>
7960 inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape,
7961 const T1* input_data,
7962 const RuntimeShape& output_shape,
7963 T2* output_data) {
7964 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7965 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1);
7966 TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0));
7967
7968 int outer_size = input_shape.Dims(0);
7969 int axis_size = input_shape.Dims(1);
7970 for (int outer = 0; outer < outer_size; ++outer) {
7971 if (is_arg_max) {
7972 output_data[outer] = static_cast<T2>(
7973 ArgMaxVector<T1>(input_data + outer * axis_size, axis_size));
7974 } else {
7975 output_data[outer] = static_cast<T2>(
7976 ArgMinVector<T1>(input_data + outer * axis_size, axis_size));
7977 }
7978 }
7979 }
7980
7981 template <typename T1, typename T2, typename T3>
7982 inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
7983 const T3* input2_data, const RuntimeShape& output_shape,
7984 T2* output_data, const bool is_arg_max) {
7985 ruy::profiler::ScopeLabel label("ArgMinMax");
7986
7987 TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
7988 TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
7989 output_shape.DimensionsCount());
7990 int axis = input2_data[0];
7991 if (axis < 0) {
7992 axis += input1_shape.DimensionsCount();
7993 }
7994 const int axis_size = input1_shape.Dims(axis);
7995
7996 int outer_size = 1;
7997 for (int i = 0; i < axis; ++i) {
7998 TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
7999 outer_size *= input1_shape.Dims(i);
8000 }
8001
8002 int inner_size = 1;
8003 const int dims_count = input1_shape.DimensionsCount();
8004 for (int i = axis + 1; i < dims_count; ++i) {
8005 TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
8006 inner_size *= input1_shape.Dims(i);
8007 }
8008
8009 // Call the specialized function when axis == dims - 1. Only float32, int8 and
8010 // uint8 have optimized last-axis implementations, so reroute only for those.
8011 if (inner_size == 1 &&
8012 (std::is_same<T1, float>::value || std::is_same<T1, int8_t>::value ||
8013 std::is_same<T1, uint8_t>::value)) {
8014 if (is_arg_max) {
8015 ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/true>(
8016 {outer_size, axis_size}, input1_data, {outer_size}, output_data);
8017 } else {
8018 ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/false>(
8019 {outer_size, axis_size}, input1_data, {outer_size}, output_data);
8020 }
8021 return;
8022 }
8023
8024 reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape,
8025 output_data, is_arg_max);
8026 }
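
// Illustrative usage sketch (hypothetical data): arg-max along the last axis
// of a {2, 3} float input.
//
//   const float in[6] = {0.5f, 2.0f, 1.0f,
//                        3.0f, 3.0f, -1.0f};
//   const int32_t axis[1] = {1};
//   int32_t out[2];
//   ArgMinMax(RuntimeShape({2, 3}), in, axis, RuntimeShape({2}), out,
//             /*is_arg_max=*/true);
//   // out == {1, 0}  (ties resolve to the smaller index)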
8027
8028 template <typename T1, typename T2, typename T3>
8029 void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
8030 const T3* input2_data, const RuntimeShape& output_shape,
8031 T2* output_data) {
8032 ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
8033 /*is_arg_max=*/true);
8034 }
8035
8036 // Convenience version that allows, for example, generated-code calls to be
8037 // the same as other binary ops.
8038 // For backward compatibility, reference_ops also provides an ArgMax function.
8039 template <typename T1, typename T2, typename T3>
8040 inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
8041 const RuntimeShape& input2_shape, const T3* input2_data,
8042 const RuntimeShape& output_shape, T2* output_data) {
8043 // Drop shape of second input: not needed.
8044 ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
8045 }
8046
8047 inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
8048 const float* input_data, const RuntimeShape& filter_shape,
8049 const float* filter_data, const RuntimeShape& bias_shape,
8050 const float* bias_data, const RuntimeShape& output_shape,
8051 float* output_data, const RuntimeShape& im2col_shape,
8052 float* im2col_data,
8053 const RuntimeShape& transposed_filter_shape,
8054 float* transposed_filter_data,
8055 CpuBackendContext* cpu_backend_context) {
8056 const int stride_depth = params.stride_depth;
8057 const int stride_height = params.stride_height;
8058 const int stride_width = params.stride_width;
8059 const int dilation_depth_factor = params.dilation_depth;
8060 const int dilation_height_factor = params.dilation_height;
8061 const int dilation_width_factor = params.dilation_width;
8062 const float output_activation_min = params.float_activation_min;
8063 const float output_activation_max = params.float_activation_max;
8064 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
8065 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
8066 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
8067
8068 ruy::profiler::ScopeLabel label("Conv3D");
8069
8070 // NB: the float 0.0f value is represented by all zero bytes.
8071 const uint8 float_zero_byte = 0x00;
8072 const float* gemm_input_data = nullptr;
8073 const RuntimeShape* gemm_input_shape = nullptr;
8074 const int filter_width = filter_shape.Dims(2);
8075 const int filter_height = filter_shape.Dims(1);
8076 const int filter_depth = filter_shape.Dims(0);
8077 const bool need_dilated_im2col = dilation_width_factor != 1 ||
8078 dilation_height_factor != 1 ||
8079 dilation_depth_factor != 1;
8080 const bool need_im2col = stride_depth != 1 || stride_height != 1 ||
8081 stride_width != 1 || filter_depth != 1 ||
8082 filter_height != 1 || filter_width != 1;
8083
8084 if (need_dilated_im2col) {
8085 DilatedIm2col3D(params, filter_depth, filter_height, filter_width,
8086 float_zero_byte, input_shape, input_data, im2col_shape,
8087 im2col_data);
8088 gemm_input_data = im2col_data;
8089 gemm_input_shape = &im2col_shape;
8090 } else if (need_im2col) {
8091 TFLITE_DCHECK(im2col_data);
8092 Im2col3D(params, filter_depth, filter_height, filter_width, float_zero_byte,
8093 input_shape, input_data, im2col_shape, im2col_data);
8094 gemm_input_data = im2col_data;
8095 gemm_input_shape = &im2col_shape;
8096 } else {
8097 TFLITE_DCHECK(!im2col_data);
8098 gemm_input_data = input_data;
8099 gemm_input_shape = &input_shape;
8100 }
8101
8102 // Transpose the filter tensor.
8103 TransposeParams transpose_params;
8104 transpose_params.perm_count = 5;
8105 transpose_params.perm[0] = 4;
8106 transpose_params.perm[1] = 0;
8107 transpose_params.perm[2] = 1;
8108 transpose_params.perm[3] = 2;
8109 transpose_params.perm[4] = 3;
8110 Transpose<float, 5>(transpose_params, filter_shape, filter_data,
8111 transposed_filter_shape, transposed_filter_data);
8112
8113 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
8114 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
8115 int n = output_shape.Dims(4);
8116 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
8117
8118 cpu_backend_gemm::MatrixParams<float> lhs_params;
8119 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
8120 lhs_params.rows = n;
8121 lhs_params.cols = k;
8122 cpu_backend_gemm::MatrixParams<float> rhs_params;
8123 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
8124 rhs_params.rows = k;
8125 rhs_params.cols = m;
8126 cpu_backend_gemm::MatrixParams<float> dst_params;
8127 dst_params.order = cpu_backend_gemm::Order::kColMajor;
8128 dst_params.rows = n;
8129 dst_params.cols = m;
8130 cpu_backend_gemm::GemmParams<float, float> gemm_params;
8131 gemm_params.bias = bias_data;
8132 gemm_params.clamp_min = output_activation_min;
8133 gemm_params.clamp_max = output_activation_max;
8134 cpu_backend_gemm::Gemm(lhs_params, transposed_filter_data, rhs_params,
8135 gemm_input_data, dst_params, output_data, gemm_params,
8136 cpu_backend_context);
8137 }
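
// Illustrative walk-through of the GEMM bookkeeping above, using hypothetical
// shapes (this assumes the usual im2col layout of
// {batch, out_d, out_h, out_w, filter_d * filter_h * filter_w * in_channels}):
// an NDHWC input of 1x3x3x3x2 convolved with a DHWIO filter of 2x2x2x2x4,
// VALID padding and stride 1, yields an output of 1x2x2x2x4. The filter is
// transposed to ODHWI (4x2x2x2x2), im2col produces shape {1, 2, 2, 2, 16}, and
// the GEMM dimensions come out as m = 8 (output spatial positions), k = 16
// (filter volume times input channels), n = 4 (output channels).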
8138
8139 // Returns, in 'im_data' (assumed to be zero-initialized), the image in storage
8140 // order (planes, height, width, channel), reconstructed from patches in
8141 // 'col_data', which is required to be in storage order (out_planes * out_height
8142 // * out_width, filter_planes, filter_height, filter_width, in_channel).
8143 //
8144 // This function is copied from tensorflow/core/kernels/conv_grad_ops_3d.cc
8145 // authored by Eugene Zhulenev (ezhulenev).
8146 template <typename T>
8147 void Col2im(const T* col_data, const int channel, const int planes,
8148 const int height, const int width, const int filter_p,
8149 const int filter_h, const int filter_w, const int pad_pt,
8150 const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
8151 const int pad_r, const int stride_p, const int stride_h,
8152 const int stride_w, T* im_data) {
8153 const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
8154 const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
8155 const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
8156 int p_pad = -pad_pt;
8157 for (int p = 0; p < planes_col; ++p) {
8158 int h_pad = -pad_t;
8159 for (int h = 0; h < height_col; ++h) {
8160 int w_pad = -pad_l;
8161 for (int w = 0; w < width_col; ++w) {
8162 T* im_patch_data =
8163 im_data +
8164 (p_pad * height * width + h_pad * width + w_pad) * channel;
8165 for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
8166 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
8167 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
8168 if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
8169 iw < width) {
8170 for (int i = 0; i < channel; ++i) {
8171 im_patch_data[i] += col_data[i];
8172 }
8173 }
8174 im_patch_data += channel;
8175 col_data += channel;
8176 }
8177 // Jump over the remaining channels of this image row.
8178 im_patch_data += channel * (width - filter_w);
8179 }
8180 // Jump over the remaining (channel * width) values to the next plane.
8181 im_patch_data += (channel * width) * (height - filter_h);
8182 }
8183 w_pad += stride_w;
8184 }
8185 h_pad += stride_h;
8186 }
8187 p_pad += stride_p;
8188 }
8189 }
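
// Illustrative example of the patch-count formulas above (hypothetical sizes):
// with planes = height = width = 5, a 2x2x2 filter, no padding, and stride 2 in
// every dimension, planes_col = height_col = width_col = (5 - 2) / 2 + 1 = 2,
// so col_data holds 2 * 2 * 2 = 8 patches of 2 * 2 * 2 * channel values that
// are accumulated back into the 5x5x5xchannel image.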
8190
8191 template <typename T>
8192 void BiasAdd3D(T* im_data, const T* bias_data, const RuntimeShape& input_shape,
8193 float float_activation_min, float float_activation_max) {
8194 if (bias_data) {
8195 const int outer_size = input_shape.Dims(0) * input_shape.Dims(1) *
8196 input_shape.Dims(2) * input_shape.Dims(3);
8197 const int num_channels = input_shape.Dims(4);
8198 for (int n = 0; n < outer_size; ++n) {
8199 for (int c = 0; c < num_channels; ++c) {
8200 im_data[c] = ActivationFunctionWithMinMax(im_data[c] + bias_data[c],
8201 float_activation_min,
8202 float_activation_max);
8203 }
8204 im_data += num_channels;
8205 }
8206 } else {
8207 const int flat_size = input_shape.FlatSize();
8208 for (int i = 0; i < flat_size; ++i) {
8209 im_data[i] = ActivationFunctionWithMinMax(
8210 im_data[i], float_activation_min, float_activation_max);
8211 }
8212 }
8213 }
8214
8215 inline void Conv3DTranspose(
8216 const Conv3DTransposeParams& params, const RuntimeShape& input_shape,
8217 const float* input_data, const RuntimeShape& filter_shape,
8218 const float* filter_data, const RuntimeShape& bias_shape,
8219 const float* bias_data, const RuntimeShape& output_shape,
8220 float* const output_data, const RuntimeShape& col2im_shape,
8221 float* col2im_data, CpuBackendContext* cpu_backend_context) {
8222 ruy::profiler::ScopeLabel label("Conv3DTranspose/float");
8223 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
8224 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
8225 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
8226 TFLITE_DCHECK(col2im_data);
8227
8228 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
8229 const int input_channel = MatchingDim(input_shape, 4, filter_shape, 4);
8230 const int output_channel = MatchingDim(output_shape, 4, filter_shape, 3);
8231 const int input_spatial_size =
8232 input_shape.Dims(1) * input_shape.Dims(2) * input_shape.Dims(3);
8233 const int output_spatial_size =
8234 output_shape.Dims(1) * output_shape.Dims(2) * output_shape.Dims(3);
8235
8236 const int output_spatial_dim_1 = output_shape.Dims(1);
8237 const int output_spatial_dim_2 = output_shape.Dims(2);
8238 const int output_spatial_dim_3 = output_shape.Dims(3);
8239 const int input_offset = input_spatial_size * input_channel;
8240 const int output_offset = output_spatial_size * output_channel;
8241
8242 const int filter_spatial_dim_1 = filter_shape.Dims(0);
8243 const int filter_spatial_dim_2 = filter_shape.Dims(1);
8244 const int filter_spatial_dim_3 = filter_shape.Dims(2);
8245
8246 const int spatial_dim_1_padding_before = params.padding_values.depth;
8247 const int spatial_dim_1_padding_after =
8248 params.padding_values.depth + params.padding_values.depth_offset;
8249 const int spatial_dim_2_padding_before = params.padding_values.height;
8250 const int spatial_dim_2_padding_after =
8251 params.padding_values.height + params.padding_values.height_offset;
8252 const int spatial_dim_3_padding_before = params.padding_values.width;
8253 const int spatial_dim_3_padding_after =
8254 params.padding_values.width + params.padding_values.width_offset;
8255 const int spatial_dim_1_stride = params.stride_depth;
8256 const int spatial_dim_2_stride = params.stride_height;
8257 const int spatial_dim_3_stride = params.stride_width;
8258 const int filter_total_size = filter_spatial_dim_1 * filter_spatial_dim_2 *
8259 filter_spatial_dim_3 * output_channel;
8260
8261 cpu_backend_gemm::MatrixParams<float> lhs_params;
8262 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
8263 lhs_params.rows = filter_total_size;
8264 lhs_params.cols = input_channel;
8265 float* output_data_p = output_data;
8266 std::fill_n(output_data, output_offset * batch_size, 0.0f);
8267 for (int i = 0; i < batch_size; ++i) {
8268 cpu_backend_gemm::MatrixParams<float> rhs_params;
8269 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
8270 rhs_params.rows = input_channel;
8271 rhs_params.cols = input_spatial_size;
8272 cpu_backend_gemm::MatrixParams<float> dst_params;
8273 dst_params.order = cpu_backend_gemm::Order::kColMajor;
8274 dst_params.rows = filter_total_size;
8275 dst_params.cols = input_spatial_size;
8276 cpu_backend_gemm::GemmParams<float, float> gemm_params;
8277 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params,
8278 input_data + input_offset * i, dst_params,
8279 col2im_data, gemm_params, cpu_backend_context);
8280
8281 Col2im(col2im_data, output_channel, output_spatial_dim_1,
8282 output_spatial_dim_2, output_spatial_dim_3, filter_spatial_dim_1,
8283 filter_spatial_dim_2, filter_spatial_dim_3,
8284 spatial_dim_1_padding_before, spatial_dim_2_padding_before,
8285 spatial_dim_3_padding_before, spatial_dim_1_padding_after,
8286 spatial_dim_2_padding_after, spatial_dim_3_padding_after,
8287 spatial_dim_1_stride, spatial_dim_2_stride, spatial_dim_3_stride,
8288 output_data_p);
8289 output_data_p += output_offset;
8290 }
8291 output_data_p = output_data;
8292 BiasAdd3D(output_data_p, bias_data, output_shape, params.float_activation_min,
8293 params.float_activation_max);
8294 }
8295
8296 // Worker task that sums the inputs within a single interval, identified by
8297 // indices in [start, end).
8298 template <typename T>
8299 struct AddNWorkerTask : cpu_backend_threadpool::Task {
8300 AddNWorkerTask(const T* const* input_data, T* scratch_buffer, int start,
8301 int end, int num_elems, int split)
8302 : input_data(input_data),
8303 scratch_buffer(scratch_buffer),
8304 start(start),
8305 end(end),
8306 num_elems(num_elems),
8307 split(split) {}
8308 void Run() override {
8309 RuntimeShape shape(1);
8310 shape.SetDim(0, num_elems);
8311 ArithmeticParams params;
8312 T output_activation_min = std::numeric_limits<T>::lowest(),
8313 output_activation_max = std::numeric_limits<T>::max();
8314 SetActivationParams(output_activation_min, output_activation_max, &params);
8315 T* start_p = scratch_buffer + split * num_elems;
8316 memcpy(start_p, input_data[start], sizeof(T) * num_elems);
8317 for (int i = start + 1; i < end; i++) {
8318 Add(params, shape, start_p, shape, input_data[i], shape, start_p);
8319 }
8320 }
8321
8322 const T* const* input_data;
8323 T* scratch_buffer;
8324 int start;
8325 int end;
8326 int num_elems;
8327 int split;
8328 };
8329
8330 // T is expected to be either float or int.
8331 template <typename T>
8332 inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
8333 const T* const* input_data, T* output_data, T* scratch_buffer,
8334 CpuBackendContext* cpu_backend_context) {
8335 // All inputs and the output should have the same shape; this is checked
8336 // during the Prepare stage.
8337 const size_t num_elems = input_shape.FlatSize();
8338 const int thread_count =
8339 std::min(std::max(1, static_cast<int>(num_inputs) / 2),
8340 cpu_backend_context->max_num_threads());
8341 memset(scratch_buffer, 0, sizeof(T) * num_elems * thread_count);
8342
8343 std::vector<AddNWorkerTask<T>> tasks;
8344 tasks.reserve(thread_count);
8345 int start = 0;
8346 for (int i = 0; i < thread_count; ++i) {
8347 int end = start + (num_inputs - start) / (thread_count - i);
8348 tasks.emplace_back(AddNWorkerTask<T>(input_data, scratch_buffer, start, end,
8349 num_elems, i));
8350 start = end;
8351 }
8352 // Run all tasks on the thread pool.
8353 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
8354 cpu_backend_context);
8355 RuntimeShape shape(1);
8356 shape.SetDim(0, num_elems);
8357 ArithmeticParams params;
8358 T output_activation_min = std::numeric_limits<T>::lowest(),
8359 output_activation_max = std::numeric_limits<T>::max();
8360 SetActivationParams(output_activation_min, output_activation_max, &params);
8361 memcpy(output_data, scratch_buffer, sizeof(T) * num_elems);
8362 for (int i = 1; i < tasks.size(); i++) {
8363 Add(params, shape, output_data, shape, scratch_buffer + i * num_elems,
8364 shape, output_data);
8365 }
8366 }
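
// Illustrative walk-through of the partitioning above (hypothetical sizes):
// with num_inputs = 10 and max_num_threads >= 4, thread_count =
// min(max(1, 10 / 2), 4) = 4 and the per-task intervals come out as
// [0, 2), [2, 4), [4, 7), [7, 10). Each task accumulates its interval into its
// own num_elems-sized slice of scratch_buffer; the main thread then copies
// partial sum 0 into output_data and adds the remaining three partial sums.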
8367
8368 } // namespace optimized_ops
8369 } // namespace tflite
8370
8371 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8372 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8373 #pragma GCC diagnostic pop
8374 #endif
8375
8376 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
8377