/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"

namespace tflite {
namespace tensor_utils {

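// Dispatch note: NEON_OR_PORTABLE(Func, ...) (from neon_check.h) expands to
// NeonFunc(...) when a NEON (or NEON-emulating) build path is enabled and to
// PortableFunc(...) otherwise. Wrappers below that call a Portable* function
// directly always take the reference path from portable_tensor_utils_impl.h.
//
// Illustrative sketch (not part of this header's API): the float overload
// that follows multiplies a [m_rows x m_cols] matrix by n_batch vectors of
// m_cols elements each and accumulates into `result`, which must hold
// n_batch * m_rows floats, e.g.:
//
//   std::vector<float> result(n_batch * m_rows, 0.0f);
//   tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
//       matrix, m_rows, m_cols, vector, n_batch, result.data());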
void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                         int m_cols, const float* vector,
                                         int n_batch, float* result) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                   vector, n_batch, result);
}

void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
                                         const int m_rows, const int m_cols,
                                         const int8_t* __restrict__ vectors,
                                         const float* scaling_factors,
                                         int n_batch,
                                         float* __restrict__ result) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                   vectors, scaling_factors, n_batch, result);
}

void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
                                         const int m_rows, const int m_cols,
                                         const int8_t* __restrict__ vectors,
                                         const float* scaling_factors,
                                         int n_batch, int32_t* scratch,
                                         float* __restrict__ result,
                                         CpuBackendContext* context) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                   vectors, scaling_factors, n_batch, scratch, result, context);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                   vectors, scaling_factors, n_batch, result, per_channel_scale,
                   input_offset, scratch, row_sums, compute_row_sums, context);
}

void SparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
                   segments, indices, m_rows, m_cols, vector, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
                   m_rows, m_cols, vector, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate1x16(
    const int8_t* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const int8_t* __restrict__ vector, const int32_t* __restrict__ bias_vector,
    int n_batch, const int32_t input_offset, const int32_t output_multiplier,
    const int32_t output_shift, const int32_t output_offset,
    const int32_t output_activation_min, const int32_t output_activation_max,
    int8_t* __restrict__ result) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x16, matrix,
                   segments, indices, m_rows, m_cols, vector, bias_vector,
                   n_batch, input_offset, output_multiplier, output_shift,
                   output_offset, output_activation_min, output_activation_max,
                   result);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,
    const float* scaling_factors, int n_batch, float* __restrict__ result) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
                   m_rows, m_cols, vectors, scaling_factors, n_batch, result);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int16_t* output, CpuBackendContext* context) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, input, bias,
                   input_to_gate_weights, multiplier, shift, n_batch, n_input,
                   n_output, output_zp, scratch, output, context);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* bias,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int8_t* output, CpuBackendContext* context) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, input, bias,
                   input_to_gate_weights, multiplier, shift, n_batch, n_input,
                   n_output, output_zp, scratch, output, context);
}

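// No NEON dispatch for this overload; it forwards directly to the portable
// implementation.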
void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
                               const int8_t* input_to_gate_weights,
                               int32_t input_to_gate_effective_scale_a,
                               int32_t input_to_gate_effective_scale_b,
                               int32_t n_batch, int32_t n_input, int32_t n_cell,
                               int8_t* gate_output, int8_t gate_output_zp) {
  PortableMatrixBatchVectorMultiply(
      input, input_zeropoint, input_to_gate_weights,
      input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
      n_input, n_cell, gate_output, gate_output_zp);
}

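// Likewise forwards directly to the portable implementation.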
void MatrixBatchVectorMultiply(const int16_t* hidden,
                               const int8_t* hidden_to_output_weights,
                               int32_t proj_effective_scale_a,
                               int32_t proj_effective_scale_b,
                               const int32_t* gate_bias, int32_t n_batch,
                               int32_t n_hidden, int32_t n_output,
                               int32_t output_zp, int8_t* proj_output) {
  PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
                                    proj_effective_scale_a,
                                    proj_effective_scale_b, gate_bias, n_batch,
                                    n_hidden, n_output, output_zp, proj_output);
}

void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
                                    int32_t n_row, int32_t n_col,
                                    int32_t* output) {
  NEON_OR_PORTABLE(MatrixScalarMultiplyAccumulate, matrix, scalar, n_row, n_col,
                   output);
}

void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
                    const int32_t* bias, int32_t layer_norm_scale_a,
                    int32_t layer_norm_scale_b, int32_t variance_limit,
                    int n_batch, int n_input, int16_t* output) {
  NEON_OR_PORTABLE(ApplyLayerNorm, input, layer_norm_weights, bias,
                   layer_norm_scale_a, layer_norm_scale_b, variance_limit,
                   n_batch, n_input, output);
}

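// Currently forwards to the portable reference implementation.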
void ApplyLayerNormFloat(const int16_t* input,
                         const int16_t* layer_norm_weights,
                         int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
                         const int32_t* bias, int n_batch, int n_input,
                         int16_t* output) {
  PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
                              layer_norm_scale_b, bias, n_batch, n_input,
                              output);
}

void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
                  int16_t* output) {
  NEON_OR_PORTABLE(ApplySigmoid, input, n_batch, n_input, output);
}

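// Currently forwards to the portable reference implementation.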
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                       int16_t* output) {
  PortableApplySigmoidFloat(input, n_batch, n_input, output);
}

void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
               int32_t n_input, int16_t* output) {
  NEON_OR_PORTABLE(ApplyTanh, integer_bits, input, n_batch, n_input, output);
}

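// Currently forwards to the portable reference implementation.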
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                    int32_t integer_bits, int16_t* output) {
  PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}

void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int shift, int16_t* output) {
  NEON_OR_PORTABLE(CwiseMul, input_1, input_2, n_batch, n_input, shift, output);
}

void CwiseMul(const int16_t* input_1, const int16_t* input_2,
              int32_t multiplier, int shift, int n_batch, int n_input,
              int32_t output_zp, int8_t* output) {
  NEON_OR_PORTABLE(CwiseMul, input_1, input_2, multiplier, shift, n_batch,
                   n_input, output_zp, output);
}

void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int16_t* output) {
  NEON_OR_PORTABLE(CwiseAdd, input_1, input_2, n_batch, n_input, output);
}

void CwiseClipping(float* vector, const int v_size,
                   const float clipping_value) {
  NEON_OR_PORTABLE(CwiseClipping, vector, v_size, clipping_value);
}
void CwiseClipping(int16_t* vector, const int v_size,
                   const int16_t clipping_value) {
  NEON_OR_PORTABLE(CwiseClipping, vector, v_size, clipping_value);
}
void CwiseClipping(int8_t* vector, const int v_size,
                   const int8_t clipping_value) {
  NEON_OR_PORTABLE(CwiseClipping, vector, v_size, clipping_value);
}

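// Currently forwards to the portable reference implementation.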
void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
                                      const int16_t* vector2, int v_size,
                                      int n_batch, int32_t* result) {
  PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
                                           result);
}

void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
                                             const int16_t* batch_vector,
                                             int n_batch, int32_t multiplier,
                                             int shift, int16_t* result) {
  NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size,
                   batch_vector, n_batch, multiplier, shift, result);
}

float VectorVectorDotProduct(const float* vector1, const float* vector2,
                             int v_size) {
  return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
}

void Sub1Vector(const float* vector, int v_size, float* result) {
  NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
}

void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
  NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
}

// Checks if all entries of a float vector are zero.
bool IsZeroVector(const float* vector, int v_size) {
  return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
}

// Checks if all entries of an int8 vector are zero.
bool IsZeroVector(const int8_t* vector, int v_size) {
  return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
}

void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
                          float* result) {
  NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float* min_value,
                             float* max_value, float* scaling_factor) {
  NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
                   min_value, max_value, scaling_factor);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float min_value,
                             float max_value, float* scaling_factor) {
  NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
                   min_value, max_value, scaling_factor);
}

void AsymmetricQuantizeFloats(const float* values, const int size,
                              int8_t* quantized_values, float* scaling_factor,
                              int32_t* offset) {
  NEON_OR_PORTABLE(AsymmetricQuantizeFloats, values, size, quantized_values,
                   scaling_factor, offset);
}

void ReductionSumVector(const float* input_vector, float* output_vector,
                        int output_size, int reduction_size) {
  NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
                   reduction_size);
}

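// The int32 variant currently forwards to the portable reference
// implementation.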
void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size) {
  PortableReductionSumVector(input_vector, output_vector, output_size,
                             reduction_size);
}

void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size) {
  NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
                   reduction_size);
}

void MeanStddevNormalization(const float* __restrict__ input_vector,
                             float* __restrict__ output_vector, int v_size,
                             int n_batch) {
  NEON_OR_PORTABLE(MeanStddevNormalization, input_vector, output_vector, v_size,
                   n_batch);
}

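// Currently forwards to the portable reference implementation.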
void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
                          const int8_t* recurrent, int8_t recurrent_zp,
                          int32_t input_effective_scale_a,
                          int32_t input_effective_scale_b,
                          int32_t recurrent_effective_scale_a,
                          int32_t recurrent_effective_scale_b, int32_t n_batch,
                          int32_t n_cell, int16_t* output) {
  PortableTwoGateSaturatingAdd(
      input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
      input_effective_scale_b, recurrent_effective_scale_a,
      recurrent_effective_scale_b, n_batch, n_cell, output);
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_