/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"

#ifdef __SSSE3__

#include <emmintrin.h>  // SSE2
#include <tmmintrin.h>  // SSSE3
#ifdef __SSE4_1__
#include <smmintrin.h>  // SSE4.1
#endif
#ifdef __AVX2__
#include <immintrin.h>
#endif

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {
namespace tensor_utils {
namespace {

#if defined(__SSE2__)
// Note: this part is copied from XNNPACK/src/xnnpack/intrinsics-polyfill.h
// w.r.t. the definition of the '_mm_loadu_si32' intrinsic.
// GCC any, Clang pre-8, Android NDK Clang pre-8.0.7, Apple Clang pre-11, and
// ICC pre-16
#if (defined(__GNUC__) && !defined(__clang__) &&                             \
     !defined(__INTEL_COMPILER)) ||                                          \
    (defined(__clang__) && !defined(__apple_build_version__) &&              \
     (__clang_major__ < 8)) ||                                               \
    (defined(__clang__) && defined(__ANDROID__) && (__clang_major__ == 8) && \
     (__clang_minor__ == 0) && (__clang_patchlevel__ < 7)) ||                \
    (defined(__clang__) && defined(__apple_build_version__) &&               \
     (__apple_build_version__ < 11000000)) ||                                \
    (defined(__INTEL_COMPILER) && (__INTEL_COMPILER < 1600))

static inline __m128i _mm_loadu_si32(const void* address) {
  return _mm_cvtsi32_si128(*((const int*)address));
}
#endif  // GCC any, Clang pre-8, Android NDK Clang pre-8.0.7, Apple Clang pre-11
        // and ICC pre-16
#endif  // __SSE2__

// Dot product of four int8 vectors of 4 elements packed into an XMM register.
// Result is four int32 scalars packed into an XMM register.
// int8x4x4 · int8x4x4 => int32x4
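// For example, if the first four lanes of 'a' are [1, -2, 3, -4] and the first
// four lanes of 'b' are [5, 6, 7, 8], then the first output lane is
// 1*5 + (-2)*6 + 3*7 + (-4)*8 = -18.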
static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
  // Transfer sign from 'a' to 'b', as _mm_maddubs_epi16 treats 'a' unsigned.
  b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
  a_8x16 = _mm_abs_epi8(a_8x16);
  // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..7)
  __m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
  // sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..3)
  return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
}

// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
static inline int32_t ReduceInt32x4(__m128i acc) {
  // Shuffle to contain high half of acc (both in high and low halves).
  __m128i shuffle = _mm_unpackhi_epi64(acc, acc);
  // Add shuffle and acc; low half is sums of twos (high half is ignored).
  acc = _mm_add_epi32(acc, shuffle);
  // Shuffle the two elements in low half (ignore high half).
  shuffle = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
  // Add shuffle and acc; lowest element is sum of all 4 inputs.
  acc = _mm_add_epi32(acc, shuffle);
  // Return lowest element as int32_t.
  return _mm_cvtsi128_si32(acc);
}

#ifdef __AVX2__
// Horizontally add 4 float values stored in a single XMM register to float.
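// This is the usual SSE3 horizontal-add idiom: movehdup copies the odd lanes
// onto the even lanes, the first add leaves acc[0]+acc[1] in lane 0 and
// acc[2]+acc[3] in lane 2, movehl brings lane 2 down to lane 0, and the final
// scalar add leaves the total in lane 0.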
static inline float ReduceFloat32x4(__m128 acc) {
  __m128 shuffle = _mm_movehdup_ps(acc);
  acc = _mm_add_ps(acc, shuffle);
  shuffle = _mm_movehl_ps(shuffle, acc);
  acc = _mm_add_ss(acc, shuffle);
  return _mm_cvtss_f32(acc);
}

// Horizontally add 8 float values stored in a single YMM register to float.
static inline float ReduceFloat32x8(__m256 acc) {
  __m128 low = _mm256_extractf128_ps(acc, 0);
  __m128 high = _mm256_extractf128_ps(acc, 1);
  return ReduceFloat32x4(_mm_add_ps(low, high));
}

// Dot product of eight int8 vectors of 4 elements packed into a YMM register.
// Result is eight int32 scalars packed into a YMM register.
// int8x4x8 · int8x4x8 => int32x8
static inline __m256i DotProdInt8x4x8(__m256i a_16x16, __m256i b_16x16) {
  // Transfer sign from 'a' to 'b', as _mm256_maddubs_epi16 treats 'a' unsigned.
  b_16x16 = _mm256_sign_epi8(b_16x16, a_16x16);
  a_16x16 = _mm256_abs_epi8(a_16x16);
  // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..15)
  __m256i sumprod_16x16 = _mm256_maddubs_epi16(a_16x16, b_16x16);
  // sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..7)
  return _mm256_madd_epi16(sumprod_16x16, _mm256_set1_epi16(1));
}
#endif  // __AVX2__

// Horizontally add each of 4 XMM registers with 4 int32 values, pack result
// into a single XMM register. Similar to ReduceInt32x4, but with 4x inputs.
static inline __m128i ReduceInt32x4x4(__m128i a, __m128i b, __m128i c,
                                      __m128i d) {
  // Assuming x = [x0, x1, x2, x3]
  const __m128i a_b_lo_half = _mm_unpacklo_epi32(a, b);  // [a0, b0, a1, b1]
  const __m128i a_b_hi_half = _mm_unpackhi_epi32(a, b);  // [a2, b2, a3, b3]
  const __m128i a_plus_b =
      _mm_add_epi32(a_b_lo_half, a_b_hi_half);  // [a0+a2, b0+b2, a1+a3, b1+b3]
  const __m128i c_d_lo_half = _mm_unpacklo_epi32(c, d);  // [c0, d0, c1, d1]
  const __m128i c_d_hi_half = _mm_unpackhi_epi32(c, d);  // [c2, d2, c3, d3]
  const __m128i c_plus_d =
      _mm_add_epi32(c_d_lo_half, c_d_hi_half);  // [c0+c2, d0+d2, c1+c3, d1+d3]
  const __m128i all_evns =
      _mm_unpacklo_epi64(a_plus_b, c_plus_d);  // [a02, b02, c02, d02]
  const __m128i all_odds =
      _mm_unpackhi_epi64(a_plus_b, c_plus_d);  // [a13, b13, c13, d13]
  return _mm_add_epi32(all_evns, all_odds);    // [a0123, b0123, c0123, d0123]
}

// Returns the ith element of an XMM register holding float numbers.
template <int i>
float GetFloatVectorElement(__m128 v) {
  static_assert(i >= 0 && i < 4, "The index must be 0 <= i < 4.");
  // Note, _mm_extract_ps returns int, so we can't use it here.
  // These lines will be optimized to extractps anyway.
  v = _mm_shuffle_ps(v, v, _MM_SHUFFLE(i, i, i, i));
  return _mm_cvtss_f32(v);
}

}  // namespace

#ifdef __AVX2__
constexpr int kFloatValuesPerAvx2Vector = 8;
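// Rounds 'size' down to a multiple of PerVectorSize, which is assumed to be a
// power of two (e.g. RoundDownVectors<8>(13) == 8).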
template <int PerVectorSize>
inline int RoundDownVectors(int size) {
  return size & ~(PerVectorSize - 1);
}

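// Computes result[b] += matrix * vector[b] for each of the n_batch input
// vectors, where 'matrix' is row-major with m_rows rows and m_cols columns.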
void Avx2MatrixBatchVectorMultiplyAccumulateImpl(
    const float* __restrict__ matrix, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  // If m_cols is not divisible by the vector size, then we need to process the
  // final few elements sequentially. postamble_start shows the start index
  // where this should happen.
  const int postamble_start =
      RoundDownVectors<kFloatValuesPerAvx2Vector>(m_cols);

  for (int b = 0; b < n_batch; ++b) {
    float* result_in_batch = result + b * m_rows;
    const float* vector_in_batch = vector + b * m_cols;
    const float* matrix_row = matrix;

    // Main matrix-by-vector multiplication loop.
    for (int r = 0; r < m_rows; ++r) {
      __m256 acc_32x8 = _mm256_setzero_ps();
      int c = 0;
      for (; c < postamble_start; c += kFloatValuesPerAvx2Vector) {
        // Load 8 float values from vector and matrix row.
        __m256 vector_f32x8 = _mm256_loadu_ps(vector_in_batch + c);
        __m256 matrix_f32x8 = _mm256_loadu_ps(matrix_row + c);

        // Multiply the vector and matrix row and add to accumulator.
        __m256 res = _mm256_mul_ps(vector_f32x8, matrix_f32x8);
        acc_32x8 = _mm256_add_ps(acc_32x8, res);
      }
      // Add the 8 intermediate sum values to get the final dot-prod value for
      // this row.
      float sum = ReduceFloat32x8(acc_32x8);
      for (; (c < m_cols); c++) {
        sum += matrix_row[c] * vector_in_batch[c];
      }
      *result_in_batch += sum;
      ++result_in_batch;
      matrix_row += m_cols;
    }
  }
}

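// Hybrid (int8 weights and inputs, float output) matrix-batch-vector product.
// For every batch and every matrix row an int32 dot product is accumulated,
// batch_offset * row_sums[row] is subtracted when both an input offset and
// precomputed row sums are provided, and the result is scaled by
// scaling_factors[batch] (and per_channel_scale[row] if given).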
void Avx2MatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, const int32_t* row_sums) {
  for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
    const float batch_scaling_factor = scaling_factors[batch];
    const int32_t batch_offset = input_offset ? input_offset[batch] : 0;
    // Compute dot-product for every row.
    for (std::intptr_t row = 0; row < m_rows; ++row) {
      // Get the address of the first element of the row.
      const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
      const float row_scale =
          per_channel_scale ? per_channel_scale[row] * batch_scaling_factor
                            : batch_scaling_factor;
      const int32_t row_offset =
          row_sums && batch_offset ? batch_offset * row_sums[row] : 0;
      // Initialize the dot product sum for the row to 0.
      __m256i dotprod_32x8 = _mm256_setzero_si256();
      std::intptr_t col = 0;
      // For every block of 32x 8-bit inputs.
      while (col < (m_cols & ~31)) {
        const __m256i vec_16x16 =
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(vectors + col));
        const __m256i row_16x16 =
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x8 = _mm256_add_epi32(dotprod_32x8,
                                        DotProdInt8x4x8(vec_16x16, row_16x16));
        col += 32;
      }
      // Sum lower and upper halves of the 32x8 vector into a 32x4 vector.
      __m128i low = _mm256_extracti128_si256(dotprod_32x8, 0);
      __m128i high = _mm256_extracti128_si256(dotprod_32x8, 1);
      __m128i dotprod_32x4 = _mm_add_epi32(low, high);
      // Postamble for 16x 8-bit inputs.
      if (col < (m_cols & ~15)) {
        const __m128i vec_16x8 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
        const __m128i row_16x8 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_16x8, row_16x8));
        col += 16;
      }
      // Postamble for 8x 8-bit inputs.
      if (col < (m_cols & ~7)) {
        const __m128i vec_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_madd_epi16(vec_16x8, row_16x8));
        col += 8;
      }
      // Postamble for 4x 8-bit inputs.
      if (col < (m_cols & ~3)) {
        const __m128i vec_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_mullo_epi32(vec_32x4, row_32x4));
        col += 4;
      }

      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t sum = ReduceInt32x4(dotprod_32x4);

#pragma clang loop unroll(disable) vectorize(disable)
      // Postamble loop for <4x remaining 8-bit inputs.
      for (; col < m_cols; ++col) {
        sum += row_ptr[col] * vectors[col];
      }  // for col
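      // When the input was asymmetrically quantized, the accumulator contains
      // an extra input_offset * sum(row) term; subtract it here.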
      if (row_offset) {
        sum -= row_offset;
      }
      *result += sum * row_scale;
      ++result;
    }  // for row

    vectors += m_cols;
  }  // for batch
}

#endif  // __AVX2__

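// Dispatches to the AVX2 implementation when this file is compiled with AVX2
// support; otherwise runs the SSSE3/SSE4.1 fallback below.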
void SseMatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, const int32_t* row_sums) {
#ifdef __AVX2__
  Avx2MatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, row_sums);
  return;
#else
  for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
    const float batch_scaling_factor = scaling_factors[batch];
    const int32_t batch_offset = input_offset ? input_offset[batch] : 0;
    // Compute dot-product for every row.
    for (std::intptr_t row = 0; row < m_rows; ++row) {
      // Get the address of the first element of the row.
      const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
      const float row_scale =
          per_channel_scale ? per_channel_scale[row] * batch_scaling_factor
                            : batch_scaling_factor;
      const int32_t row_offset =
          row_sums && batch_offset ? batch_offset * row_sums[row] : 0;
      // Initialize the dot product sum for the row to 0.
      __m128i dotprod_32x4 = _mm_setzero_si128();
      std::intptr_t col = 0;
      // For every block of 16x 8-bit inputs.
      while (col < (m_cols & ~15)) {
        const __m128i vec_8x16 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
        const __m128i row_8x16 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
        col += 16;
      }
#ifdef __SSE4_1__
      // Postamble for 8x 8-bit inputs.
      if (col < (m_cols & ~7)) {
        const __m128i vec_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_madd_epi16(vec_16x8, row_16x8));
        col += 8;
      }
      // Postamble for 4x 8-bit inputs.
      if (col < (m_cols & ~3)) {
        const __m128i vec_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_mullo_epi32(vec_32x4, row_32x4));
        col += 4;
      }
#endif

      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t sum = ReduceInt32x4(dotprod_32x4);

#if defined(__SSE4_1__) && defined(__clang__)
      // SSE 4.1: Don't try to unroll and vectorize this, already done above.
#pragma clang loop unroll(disable) vectorize(disable)
#endif
      // Postamble loop for <4x (<16x without SSE 4.1) remaining 8-bit inputs.
      for (; col < m_cols; ++col) {
        sum += row_ptr[col] * vectors[col];
      }  // for col
      if (row_offset) {
        sum -= row_offset;
      }
      *result += sum * row_scale;
      ++result;
    }  // for row

    vectors += m_cols;
  }  // for batch
#endif  // ifdef __AVX2__
}

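// Runs the int8 GEMM scratch = input_to_gate_weights * input (plus bias when
// provided) through cpu_backend_gemm, producing int32 accumulators laid out
// with n_output rows and n_batch columns.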
void SseCpuBackendGemm(const int8_t* input, const int32_t* bias,
                       const int8_t* input_to_gate_weights, int32_t n_batch,
                       int32_t n_input, int32_t n_output, int32_t output_zp,
                       int32_t* scratch, CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = n_output;
  lhs_params.cols = n_input;
  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = n_input;
  rhs_params.cols = n_batch;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = n_output;
  dst_params.cols = n_batch;

  GemmParams<int32, int32> gemm_params;
  if (bias) {
    gemm_params.bias = bias;
  }
  cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
                         dst_params, scratch, gemm_params, context);
}

void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result) {
  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*row_sums=*/nullptr);
}

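// Hybrid path with an int32 scratch buffer: when m_rows is a multiple of 4 and
// the context does not prefer gemmlowp on x86, the dot products are computed
// with cpu_backend_gemm and then rescaled to float; otherwise this falls back
// to the intrinsic implementation above.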
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch, int32_t* scratch,
    float* __restrict__ result, CpuBackendContext* context) {
  // TODO(b/183178387): Use a proper query to detect AVX/optimized paths.
  if (m_rows % 4 == 0 && !context->PreferGemmlowpOnX86()) {
    const int32_t* bias = static_cast<const int32_t*>(nullptr);
    SseCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
                      /*output_zp=*/0, scratch, context);

    {
      ruy::profiler::ScopeLabel label("HybridMultiplyScalingFactor");
      // Multiply by the float scaling factors and write to result.
      const int total_size = n_batch * m_rows;
      int i = 0;
      for (; i <= total_size - 8; i += 8, result += 8) {
        const float batch_scaling_factor0 = scaling_factors[i / m_rows];
        const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
        const __m128 scaling_factor0 = _mm_set1_ps(batch_scaling_factor0);
        const __m128 scaling_factor1 = _mm_set1_ps(batch_scaling_factor1);
        const __m128i scratch_val0 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i));
        const __m128i scratch_val1 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i + 4));
        const __m128 float_val0 = _mm_cvtepi32_ps(scratch_val0);
        const __m128 float_val1 = _mm_cvtepi32_ps(scratch_val1);
        const __m128 prod0 = _mm_mul_ps(float_val0, scaling_factor0);
        const __m128 result0 = _mm_add_ps(_mm_loadu_ps(result), prod0);
        const __m128 prod1 = _mm_mul_ps(float_val1, scaling_factor1);
        const __m128 result1 = _mm_add_ps(_mm_loadu_ps(result + 4), prod1);
        _mm_storeu_ps(result, result0);
        _mm_storeu_ps(result + 4, result1);
      }
      scratch += i;
      for (; i < total_size; i++) {
        const float batch_scaling_factor = scaling_factors[i / m_rows];
        int32_t x = *(scratch++);
        *result += x * batch_scaling_factor;
        ++result;
      }
    }
    return;
  }

  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*row_sums=*/nullptr);
}

void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  if ((input_offset != nullptr) && (!compute_row_sums || *compute_row_sums)) {
    SseReductionSumVector(matrix, row_sums, m_rows, m_cols);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }
  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, row_sums);
}

namespace {

// Implements sparse-matrix - vector multiply-accumulate.
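// The ledger encodes, for each row, the number of non-zero 16-byte blocks
// followed by the block indices (each index is multiplied by kBlockSize to get
// the column offset).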
inline void SseSparseMatrixVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vector,
    const float scaling_factor, float* __restrict__ result) {
  static const std::intptr_t kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
  const uint8_t* __restrict__ ledger_ptr = ledger;
  for (std::intptr_t row = 0; row < m_rows; ++row) {
    // Initialize the dot product sum for the row to 0.
    __m128i dotprod_32x4 = _mm_setzero_si128();
    std::intptr_t num_nonzero_blocks = *ledger_ptr++;
    for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) {
      const std::intptr_t col_index = *ledger_ptr++ * kBlockSize;
      const __m128i vec_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(vector + col_index));
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(matrix));
      // dotprod += vec · row
      dotprod_32x4 =
          _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
      matrix += kBlockSize;
    }  // for col
    // Horizontally add the 4 intermediate sum values to get the final
    // dot-prod value for this row.
    int32_t dotprod = ReduceInt32x4(dotprod_32x4);

    result[row] += dotprod * scaling_factor;
  }  // for row
}

// Implements sparse-matrix - batch-of-4-vectors multiply-accumulate.
// Processes four consecutive batches at once: the stride between input vectors
// is m_cols and the stride between result vectors is m_rows.
inline void SseSparseMatrix4VectorsMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols,
    const int8_t* __restrict__ const vectors, const __m128 scaling_factors_fx4,
    float* __restrict__ const results) {
  static const std::intptr_t kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);

  const int8_t* __restrict__ vector0 = vectors + 0 * m_cols;
  const int8_t* __restrict__ vector1 = vectors + 1 * m_cols;
  const int8_t* __restrict__ vector2 = vectors + 2 * m_cols;
  const int8_t* __restrict__ vector3 = vectors + 3 * m_cols;
  float* __restrict__ result0 = results + 0 * m_rows;
  float* __restrict__ result1 = results + 1 * m_rows;
  float* __restrict__ result2 = results + 2 * m_rows;
  float* __restrict__ result3 = results + 3 * m_rows;

  for (std::intptr_t row = 0; row < m_rows; ++row) {
    // Initialize the dot product sums for the row to 0.
    __m128i dp0_32x4 = _mm_setzero_si128();
    __m128i dp1_32x4 = _mm_setzero_si128();
    __m128i dp2_32x4 = _mm_setzero_si128();
    __m128i dp3_32x4 = _mm_setzero_si128();

    std::intptr_t num_nonzero_blocks = *ledger++;
    for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) {
      const std::intptr_t col_index = *ledger++ * kBlockSize;
      // vecN are for different batches.
      const __m128i vec0_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector0 + col_index));
      const __m128i vec1_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector1 + col_index));
      const __m128i vec2_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector2 + col_index));
      const __m128i vec3_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector3 + col_index));
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(matrix));
      // dp += vec · row
      // dpN are for different batches.
      dp0_32x4 = _mm_add_epi32(dp0_32x4, DotProdInt8x4x4(vec0_8x16, row_8x16));
      dp1_32x4 = _mm_add_epi32(dp1_32x4, DotProdInt8x4x4(vec1_8x16, row_8x16));
      dp2_32x4 = _mm_add_epi32(dp2_32x4, DotProdInt8x4x4(vec2_8x16, row_8x16));
      dp3_32x4 = _mm_add_epi32(dp3_32x4, DotProdInt8x4x4(vec3_8x16, row_8x16));
      matrix += kBlockSize;
    }  // for col

    // Horizontally add the 4 intermediate values.
    const __m128i dp_32x4 =
        ReduceInt32x4x4(dp0_32x4, dp1_32x4, dp2_32x4, dp3_32x4);
    // Convert to float.
    const __m128 dp_fx4 = _mm_cvtepi32_ps(dp_32x4);
    // Load the current results (this is an accumulate function).
    __m128 result_fx4 =
        _mm_set_ps(result3[row], result2[row], result1[row], result0[row]);
    // result += dp .* scaling
    result_fx4 =
        _mm_add_ps(result_fx4, _mm_mul_ps(dp_fx4, scaling_factors_fx4));
    // Save the results.
    result0[row] = GetFloatVectorElement<0>(result_fx4);
    result1[row] = GetFloatVectorElement<1>(result_fx4);
    result2[row] = GetFloatVectorElement<2>(result_fx4);
    result3[row] = GetFloatVectorElement<3>(result_fx4);
  }  // for row
}

}  // namespace

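// Public entry point: handles batches in groups of four so each non-zero
// 16-byte block of the sparse matrix is loaded once per group, then falls back
// to the single-vector kernel for the remaining batches.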
void SseSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ results) {
  int batch = 0;
  const int kBatchSize4 = 4;
  const int n_batch_rounddown_to_batchsize_4 = n_batch & ~(kBatchSize4 - 1);
  while (batch < n_batch_rounddown_to_batchsize_4) {
    const __m128 scaling_factors_fx4 = _mm_loadu_ps(scaling_factors + batch);
    SseSparseMatrix4VectorsMultiplyAccumulate(
        matrix, ledger, m_rows, m_cols, vectors, scaling_factors_fx4, results);
    batch += kBatchSize4;
    vectors += kBatchSize4 * m_cols;
    results += kBatchSize4 * m_rows;
  }  // for batch
  while (batch < n_batch) {
    SseSparseMatrixVectorMultiplyAccumulate(matrix, ledger, m_rows, m_cols,
                                            vectors, scaling_factors[batch],
                                            results);
    ++batch;
    vectors += m_cols;
    results += m_rows;
  }  // for batch
}

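// Computes output_vector[row] = sum of the int8 values in row 'row' of
// input_vector. Used to precompute the row sums needed for the input-offset
// correction in the hybrid kernels above.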
void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                           const int output_size, const int reduction_size) {
  static constexpr std::intptr_t kBlockSize = 16;
  for (std::intptr_t row = 0; row < output_size; ++row) {
    const int8_t* __restrict__ row_ptr = input_vector + row * reduction_size;
    __m128i row_sum_16x8 = _mm_setzero_si128();
    std::intptr_t col = 0;
    for (; col < (reduction_size & ~(kBlockSize - 1)); col += kBlockSize) {
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
      const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16);
      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
    }  // for col
#ifdef __SSE4_1__
    // Postamble for 8x 8-bit inputs.
    if (col < (reduction_size & ~7)) {
      // _mm_loadu_si64 is not supported in gcc versions < 9 (it broke the
      // kokoro build), so _mm_loadl_epi64 is used instead.
      const __m128i row_16x8 = _mm_cvtepi8_epi16(
          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
      // row_sum += row
      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
      col += 8;
    }
#endif
    const __m128i row_sum_32x4 =
        _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1));
    int32_t row_sum = ReduceInt32x4(row_sum_32x4);
#if defined(__SSE4_1__) && defined(__clang__)
    // SSE 4.1: Don't try to unroll and vectorize this, already done above.
#pragma clang loop unroll(disable) vectorize(disable)
#endif
    for (; col < reduction_size; col++) {
      row_sum += row_ptr[col];
    }
    output_vector[row] = row_sum;
  }
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // __SSSE3__