/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"

#ifdef __SSSE3__

#include <emmintrin.h>  // SSE2
#include <tmmintrin.h>  // SSSE3
#ifdef __SSE4_1__
#include <smmintrin.h>  // SSE4.1
#endif
#ifdef __AVX2__
#include <immintrin.h>
#endif

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {
namespace tensor_utils {
namespace {

#if defined(__SSE2__)
// Note: this part is copied from XNNPACK/src/xnnpack/intrinsics-polyfill.h
// w.r.t. the definition of the '_mm_loadu_si32' intrinsic.
// GCC any, Clang pre-8, Android NDK Clang pre-8.0.7, Apple Clang pre-11, and
// ICC pre-16
#if (defined(__GNUC__) && !defined(__clang__) &&                             \
     !defined(__INTEL_COMPILER)) ||                                          \
    (defined(__clang__) && !defined(__apple_build_version__) &&              \
     (__clang_major__ < 8)) ||                                               \
    (defined(__clang__) && defined(__ANDROID__) && (__clang_major__ == 8) && \
     (__clang_minor__ == 0) && (__clang_patchlevel__ < 7)) ||                \
    (defined(__clang__) && defined(__apple_build_version__) &&               \
     (__apple_build_version__ < 11000000)) ||                                \
    (defined(__INTEL_COMPILER) && (__INTEL_COMPILER < 1600))

static inline __m128i _mm_loadu_si32(const void* address) {
  return _mm_cvtsi32_si128(*((const int*)address));
}
#endif  // GCC any, Clang pre-8, Android NDK Clang pre-8.0.7, Apple Clang pre-11
        // and ICC pre-16
#endif  // __SSE2__

// Dot product of four int8 vectors of 4 elements packed into a XMM register.
// Result is four int32 scalars packed into a XMM register.
// int8x4x4 · int8x4x4 => int32x4
static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
  // Transfer sign from 'a' to 'b', as _mm_maddubs_epi16 treats 'a' unsigned.
  b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
  a_8x16 = _mm_abs_epi8(a_8x16);
  // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..7)
  __m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
  // sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..3)
  return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
}

// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
static inline int32_t ReduceInt32x4(__m128i acc) {
  // Shuffle to contain the high half of acc (in both the high and low halves).
  __m128i shuffle = _mm_unpackhi_epi64(acc, acc);
  // Add shuffle and acc; the low half now holds the pairwise sums (the high
  // half is ignored).
  acc = _mm_add_epi32(acc, shuffle);
  // Shuffle the two elements in the low half (ignore the high half).
  shuffle = _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1));
  // Add shuffle and acc; the lowest element is the sum of all 4 inputs.
  acc = _mm_add_epi32(acc, shuffle);
  // Return the lowest element as int32_t.
  return _mm_cvtsi128_si32(acc);
}

#ifdef __AVX2__
// Horizontally add 4 float values stored in a single XMM register to float.
static inline float ReduceFloat32x4(__m128 acc) {
  __m128 shuffle = _mm_movehdup_ps(acc);
  acc = _mm_add_ps(acc, shuffle);
  shuffle = _mm_movehl_ps(shuffle, acc);
  acc = _mm_add_ss(acc, shuffle);
  return _mm_cvtss_f32(acc);
}

// Horizontally add 8 float values stored in a single YMM register to float.
static inline float ReduceFloat32x8(__m256 acc) {
  __m128 low = _mm256_extractf128_ps(acc, 0);
  __m128 high = _mm256_extractf128_ps(acc, 1);
  return ReduceFloat32x4(_mm_add_ps(low, high));
}

// Dot product of eight int8 vectors of 4 elements packed into a YMM register.
// Result is eight int32 scalars packed into a YMM register.
// int8x4x8 · int8x4x8 => int32x8
static inline __m256i DotProdInt8x4x8(__m256i a_16x16, __m256i b_16x16) {
  // Transfer sign from 'a' to 'b', as _mm256_maddubs_epi16 treats 'a' unsigned.
  b_16x16 = _mm256_sign_epi8(b_16x16, a_16x16);
  a_16x16 = _mm256_abs_epi8(a_16x16);
  // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..15)
  __m256i sumprod_16x16 = _mm256_maddubs_epi16(a_16x16, b_16x16);
  // sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..7)
  return _mm256_madd_epi16(sumprod_16x16, _mm256_set1_epi16(1));
}
#endif  // __AVX2__

// Horizontally add each of 4 XMM registers with 4 int32 values, pack result
// into a single XMM register. Similar to ReduceInt32x4, but with 4x inputs.
static inline __m128i ReduceInt32x4x4(__m128i a, __m128i b, __m128i c,
                                      __m128i d) {
  // Assuming x = [x0, x1, x2, x3]
  const __m128i a_b_lo_half = _mm_unpacklo_epi32(a, b);  // [a0, b0, a1, b1]
  const __m128i a_b_hi_half = _mm_unpackhi_epi32(a, b);  // [a2, b2, a3, b3]
  const __m128i a_plus_b =
      _mm_add_epi32(a_b_lo_half, a_b_hi_half);  // [a0+a2, b0+b2, a1+a3, b1+b3]
  const __m128i c_d_lo_half = _mm_unpacklo_epi32(c, d);  // [c0, d0, c1, d1]
  const __m128i c_d_hi_half = _mm_unpackhi_epi32(c, d);  // [c2, d2, c3, d3]
  const __m128i c_plus_d =
      _mm_add_epi32(c_d_lo_half, c_d_hi_half);  // [c0+c2, d0+d2, c1+c3, d1+d3]
  const __m128i all_evns =
      _mm_unpacklo_epi64(a_plus_b, c_plus_d);  // [a02, b02, c02, d02]
  const __m128i all_odds =
      _mm_unpackhi_epi64(a_plus_b, c_plus_d);  // [a13, b13, c13, d13]
  return _mm_add_epi32(all_evns, all_odds);  // [a0123, b0123, c0123, d0123]
}

// Returns the ith element of a XMM register holding float numbers.
template <int i>
float GetFloatVectorElement(__m128 v) {
  static_assert(i >= 0 && i < 4, "The index must be 0 <= i < 4.");
  // Note, _mm_extract_ps returns int, so we can't use it here.
  // These lines will be optimized to extractps anyway.
  v = _mm_shuffle_ps(v, v, _MM_SHUFFLE(i, i, i, i));
  return _mm_cvtss_f32(v);
}

}  // namespace

#ifdef __AVX2__
constexpr int kFloatValuesPerAvx2Vector = 8;
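// Rounds 'size' down to the nearest multiple of PerVectorSize.
// PerVectorSize must be a power of two.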
template <int PerVectorSize>
inline int RoundDownVectors(int size) {
  return size & ~(PerVectorSize - 1);
}

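// Matrix-by-batched-vector multiply-accumulate for float inputs using AVX2.
// For each batch, each row of 'matrix' is dotted with the batch's input vector
// 8 floats at a time, with a scalar loop for the remaining columns, and the
// result is accumulated into 'result'.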
void Avx2MatrixBatchVectorMultiplyAccumulateImpl(
    const float* __restrict__ matrix, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  // If m_cols is not divisible by the vector size, then we need to process the
  // final few elements sequentially. postamble_start shows the start index
  // where this should happen.
  const int postamble_start =
      RoundDownVectors<kFloatValuesPerAvx2Vector>(m_cols);

  for (int b = 0; b < n_batch; ++b) {
    float* result_in_batch = result + b * m_rows;
    const float* vector_in_batch = vector + b * m_cols;
    const float* matrix_row = matrix;

    // Main matrix-by-vector multiplication loop.
    for (int r = 0; r < m_rows; ++r) {
      __m256 acc_32x8 = _mm256_setzero_ps();
      int c = 0;
      for (; c < postamble_start; c += kFloatValuesPerAvx2Vector) {
        // Load 8 float values from the vector and the matrix row.
        __m256 vector_f32x8 = _mm256_loadu_ps(vector_in_batch + c);
        __m256 matrix_f32x8 = _mm256_loadu_ps(matrix_row + c);

        // Multiply the vector and matrix row and add to the accumulator.
        __m256 res = _mm256_mul_ps(vector_f32x8, matrix_f32x8);
        acc_32x8 = _mm256_add_ps(acc_32x8, res);
      }
      // Add the 8 intermediate sum values to get the final dot-prod value for
      // this row.
      float sum = ReduceFloat32x8(acc_32x8);
      for (; c < m_cols; c++) {
        sum += matrix_row[c] * vector_in_batch[c];
      }
      *result_in_batch += sum;
      ++result_in_batch;
      matrix_row += m_cols;
    }
  }
}

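// Hybrid (int8 weights and inputs, float scaling) matrix-by-batched-vector
// multiply-accumulate using AVX2. Processes 32 columns per iteration, then
// 16-, 8- and 4-wide postambles plus a scalar tail, and finally applies the
// per-batch/per-channel scale and the optional row_sum * input_offset
// correction before accumulating into 'result'.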
void Avx2MatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, const int32_t* row_sums) {
  for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
    const float batch_scaling_factor = scaling_factors[batch];
    const int32_t batch_offset = input_offset ? input_offset[batch] : 0;
    // Compute a dot product for every row of the matrix.
    for (std::intptr_t row = 0; row < m_rows; ++row) {
      // Get the address of the first element of the row.
      const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
      const float row_scale =
          per_channel_scale ? per_channel_scale[row] * batch_scaling_factor
                            : batch_scaling_factor;
      const int32_t row_offset =
          row_sums && batch_offset ? batch_offset * row_sums[row] : 0;
      // Initialize the dot product sum for the row to 0.
      __m256i dotprod_32x8 = _mm256_setzero_si256();
      std::intptr_t col = 0;
      // For every block of 32x 8-bit inputs.
      while (col < (m_cols & ~31)) {
        const __m256i vec_16x16 =
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(vectors + col));
        const __m256i row_16x16 =
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x8 = _mm256_add_epi32(dotprod_32x8,
                                        DotProdInt8x4x8(vec_16x16, row_16x16));
        col += 32;
      }
      // Sum lower and upper halves of 32x8 vector into 32x4 vector.
      __m128i low = _mm256_extracti128_si256(dotprod_32x8, 0);
      __m128i high = _mm256_extracti128_si256(dotprod_32x8, 1);
      __m128i dotprod_32x4 = _mm_add_epi32(low, high);
      // Postamble for 16x 8-bit inputs.
      if (col < (m_cols & ~15)) {
        const __m128i vec_16x8 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
        const __m128i row_16x8 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_16x8, row_16x8));
        col += 16;
      }
      // Postamble for 8x 8-bit inputs.
      if (col < (m_cols & ~7)) {
        const __m128i vec_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_madd_epi16(vec_16x8, row_16x8));
        col += 8;
      }
      // Postamble for 4x 8-bit inputs.
      if (col < (m_cols & ~3)) {
        const __m128i vec_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_mullo_epi32(vec_32x4, row_32x4));
        col += 4;
      }

      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t sum = ReduceInt32x4(dotprod_32x4);

#pragma clang loop unroll(disable) vectorize(disable)
      // Postamble loop for <4x remaining 8-bit inputs.
      for (; col < m_cols; ++col) {
        sum += row_ptr[col] * vectors[col];
      }  // for col
      if (row_offset) {
        sum -= row_offset;
      }
      *result += sum * row_scale;
      ++result;
    }  // for row

    vectors += m_cols;
  }  // for batch
}

#endif  // __AVX2__

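// Implementation shared by the hybrid SseMatrixBatchVectorMultiplyAccumulate
// entry points below. Dispatches to the AVX2 variant above when compiled with
// AVX2; otherwise processes 16 columns per iteration with SSSE3, with 8- and
// 4-wide SSE4.1 postambles and a scalar tail.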
void SseMatrixBatchVectorMultiplyAccumulateImpl(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, const int32_t* row_sums) {
#ifdef __AVX2__
  Avx2MatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, row_sums);
  return;
#else
  for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
    const float batch_scaling_factor = scaling_factors[batch];
    const int32_t batch_offset = input_offset ? input_offset[batch] : 0;
    // Compute a dot product for every row of the matrix.
    for (std::intptr_t row = 0; row < m_rows; ++row) {
      // Get the address of the first element of the row.
      const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
      const float row_scale =
          per_channel_scale ? per_channel_scale[row] * batch_scaling_factor
                            : batch_scaling_factor;
      const int32_t row_offset =
          row_sums && batch_offset ? batch_offset * row_sums[row] : 0;
      // Initialize the dot product sum for the row to 0.
      __m128i dotprod_32x4 = _mm_setzero_si128();
      std::intptr_t col = 0;
      // For every block of 16x 8-bit inputs.
      while (col < (m_cols & ~15)) {
        const __m128i vec_8x16 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
        const __m128i row_8x16 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
        col += 16;
      }
#ifdef __SSE4_1__
      // Postamble for 8x 8-bit inputs.
      if (col < (m_cols & ~7)) {
        const __m128i vec_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_16x8 = _mm_cvtepi8_epi16(
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_madd_epi16(vec_16x8, row_16x8));
        col += 8;
      }
      // Postamble for 4x 8-bit inputs.
      if (col < (m_cols & ~3)) {
        const __m128i vec_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(vectors + col)));
        const __m128i row_32x4 = _mm_cvtepi8_epi32(
            _mm_loadu_si32(reinterpret_cast<const __m128i*>(row_ptr + col)));
        // dotprod += vec · row
        dotprod_32x4 =
            _mm_add_epi32(dotprod_32x4, _mm_mullo_epi32(vec_32x4, row_32x4));
        col += 4;
      }
#endif

      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t sum = ReduceInt32x4(dotprod_32x4);

#if defined(__SSE4_1__) && defined(__clang__)
      // SSE 4.1: Don't try to unroll and vectorize this, already done above.
#pragma clang loop unroll(disable) vectorize(disable)
#endif
      // Postamble loop for <4x (<16x without SSE 4.1) remaining 8-bit inputs.
      for (; col < m_cols; ++col) {
        sum += row_ptr[col] * vectors[col];
      }  // for col
      if (row_offset) {
        sum -= row_offset;
      }
      *result += sum * row_scale;
      ++result;
    }  // for row

    vectors += m_cols;
  }  // for batch
#endif  // ifdef __AVX2__
}

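// Performs an int8 x int8 -> int32 GEMM through cpu_backend_gemm (ruy):
// 'input_to_gate_weights' is the row-major LHS (n_output x n_input), 'input'
// is the column-major RHS (n_input x n_batch), and the int32 accumulators are
// written to 'scratch'. If 'bias' is non-null, it is applied by the GEMM as a
// per-output-row bias.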
void SseCpuBackendGemm(const int8_t* input, const int32_t* bias,
                       const int8_t* input_to_gate_weights, int32_t n_batch,
                       int32_t n_input, int32_t n_output, int32_t output_zp,
                       int32_t* scratch, CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = n_output;
  lhs_params.cols = n_input;
  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = n_input;
  rhs_params.cols = n_batch;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = n_output;
  dst_params.cols = n_batch;

  GemmParams<int32, int32> gemm_params;
  if (bias) {
    gemm_params.bias = bias;
  }
  cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
                         dst_params, scratch, gemm_params, context);
}

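// Hybrid (int8 x int8 -> float) matrix-batch-vector multiply-accumulate
// without per-channel scales, input offsets or row sums.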
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result) {
  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*row_sums=*/nullptr);
}

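// As above, but can route through SseCpuBackendGemm when m_rows is a multiple
// of 4 and the context allows it: the int32 accumulators land in 'scratch' and
// are then scaled by the per-batch scaling factors and accumulated into
// 'result'.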
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch, int32_t* scratch,
    float* __restrict__ result, CpuBackendContext* context) {
  // TODO(b/183178387): Use a proper query to detect AVX/optimized paths.
  if (m_rows % 4 == 0 && !context->PreferGemmlowpOnX86()) {
    const int32_t* bias = static_cast<const int32_t*>(nullptr);
    SseCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
                      /*output_zp=*/0, scratch, context);

    {
      ruy::profiler::ScopeLabel label("HybridMultiplyScalingFactor");
      // Multiply by float scaling factors and write to result
      const int total_size = n_batch * m_rows;
      int i = 0;
      for (; i <= total_size - 8; i += 8, result += 8) {
        const float batch_scaling_factor0 = scaling_factors[i / m_rows];
        const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
        const __m128 scaling_factor0 = _mm_set1_ps(batch_scaling_factor0);
        const __m128 scaling_factor1 = _mm_set1_ps(batch_scaling_factor1);
        const __m128i scratch_val0 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i));
        const __m128i scratch_val1 =
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i + 4));
        const __m128 float_val0 = _mm_cvtepi32_ps(scratch_val0);
        const __m128 float_val1 = _mm_cvtepi32_ps(scratch_val1);
        const __m128 prod0 = _mm_mul_ps(float_val0, scaling_factor0);
        const __m128 result0 = _mm_add_ps(_mm_loadu_ps(result), prod0);
        const __m128 prod1 = _mm_mul_ps(float_val1, scaling_factor1);
        const __m128 result1 = _mm_add_ps(_mm_loadu_ps(result + 4), prod1);
        _mm_storeu_ps(result, result0);
        _mm_storeu_ps(result + 4, result1);
      }
      scratch += i;
      for (; i < total_size; i++) {
        const float batch_scaling_factor = scaling_factors[i / m_rows];
        int32_t x = *(scratch++);
        *result += x * batch_scaling_factor;
        ++result;
      }
    }
    return;
  }

  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
      /*row_sums=*/nullptr);
}

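// Hybrid multiply-accumulate with optional per-channel scales and an
// asymmetric input offset. Row sums (needed to correct for the input offset)
// are recomputed only when requested via 'compute_row_sums'.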
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  if ((input_offset != nullptr) && (!compute_row_sums || *compute_row_sums)) {
    SseReductionSumVector(matrix, row_sums, m_rows, m_cols);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }
  SseMatrixBatchVectorMultiplyAccumulateImpl(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, row_sums);
}

namespace {

// Implements sparse-matrix - vector multiply-accumulate.
inline void SseSparseMatrixVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vector,
    const float scaling_factor, float* __restrict__ result) {
  static const std::intptr_t kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
  const uint8_t* __restrict__ ledger_ptr = ledger;
  for (std::intptr_t row = 0; row < m_rows; ++row) {
    // Initialize the dot product sum for the row to 0.
    __m128i dotprod_32x4 = _mm_setzero_si128();
    std::intptr_t num_nonzero_blocks = *ledger_ptr++;
    for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) {
      const std::intptr_t col_index = *ledger_ptr++ * kBlockSize;
      const __m128i vec_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(vector + col_index));
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(matrix));
      // dotprod += vec · row
      dotprod_32x4 =
          _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
      matrix += kBlockSize;
    }  // for col
    // Horizontally add the 4 intermediate sum values to get the final
    // dot-prod value for this row.
    int32_t dotprod = ReduceInt32x4(dotprod_32x4);

    result[row] += dotprod * scaling_factor;
  }  // for row
}

// Implements sparse-matrix - batch-of-4-vectors multiply-accumulate.
// 'vectors' and 'results' point at the data of 4 consecutive batches: the
// stride between consecutive input vectors is m_cols and the stride between
// consecutive results is m_rows.
inline void SseSparseMatrix4VectorsMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols,
    const int8_t* __restrict__ const vectors, const __m128 scaling_factors_fx4,
    float* __restrict__ const results) {
  static const std::intptr_t kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);

  const int8_t* __restrict__ vector0 = vectors + 0 * m_cols;
  const int8_t* __restrict__ vector1 = vectors + 1 * m_cols;
  const int8_t* __restrict__ vector2 = vectors + 2 * m_cols;
  const int8_t* __restrict__ vector3 = vectors + 3 * m_cols;
  float* __restrict__ result0 = results + 0 * m_rows;
  float* __restrict__ result1 = results + 1 * m_rows;
  float* __restrict__ result2 = results + 2 * m_rows;
  float* __restrict__ result3 = results + 3 * m_rows;

  for (std::intptr_t row = 0; row < m_rows; ++row) {
    // Initialize the dot product sums for the row to 0.
    __m128i dp0_32x4 = _mm_setzero_si128();
    __m128i dp1_32x4 = _mm_setzero_si128();
    __m128i dp2_32x4 = _mm_setzero_si128();
    __m128i dp3_32x4 = _mm_setzero_si128();

    std::intptr_t num_nonzero_blocks = *ledger++;
    for (std::intptr_t i = 0; i < num_nonzero_blocks; i++) {
      const std::intptr_t col_index = *ledger++ * kBlockSize;
      // vecN are for different batches.
      const __m128i vec0_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector0 + col_index));
      const __m128i vec1_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector1 + col_index));
      const __m128i vec2_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector2 + col_index));
      const __m128i vec3_8x16 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(vector3 + col_index));
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(matrix));
      // dp += vec · row
      // dpN are for different batches.
      dp0_32x4 = _mm_add_epi32(dp0_32x4, DotProdInt8x4x4(vec0_8x16, row_8x16));
      dp1_32x4 = _mm_add_epi32(dp1_32x4, DotProdInt8x4x4(vec1_8x16, row_8x16));
      dp2_32x4 = _mm_add_epi32(dp2_32x4, DotProdInt8x4x4(vec2_8x16, row_8x16));
      dp3_32x4 = _mm_add_epi32(dp3_32x4, DotProdInt8x4x4(vec3_8x16, row_8x16));
      matrix += kBlockSize;
    }  // for col

    // Horizontally add the 4 intermediate values.
    const __m128i dp_32x4 =
        ReduceInt32x4x4(dp0_32x4, dp1_32x4, dp2_32x4, dp3_32x4);
    // Convert to float.
    const __m128 dp_fx4 = _mm_cvtepi32_ps(dp_32x4);
    // Load the current results (this is an accumulate function).
    __m128 result_fx4 =
        _mm_set_ps(result3[row], result2[row], result1[row], result0[row]);
    // result += dp .* scaling
    result_fx4 =
        _mm_add_ps(result_fx4, _mm_mul_ps(dp_fx4, scaling_factors_fx4));
    // Save the results.
    result0[row] = GetFloatVectorElement<0>(result_fx4);
    result1[row] = GetFloatVectorElement<1>(result_fx4);
    result2[row] = GetFloatVectorElement<2>(result_fx4);
    result3[row] = GetFloatVectorElement<3>(result_fx4);
  }  // for row
}

}  // namespace

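// Sparse hybrid multiply-accumulate over a block-compressed matrix (16-column
// blocks, described by 'ledger'): batches are processed four at a time, with a
// single-batch fallback for the remainder.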
void SseSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ results) {
  int batch = 0;
  const int kBatchSize4 = 4;
  const int n_batch_rounddown_to_batchsize_4 = n_batch & ~(kBatchSize4 - 1);
  while (batch < n_batch_rounddown_to_batchsize_4) {
    const __m128 scaling_factors_fx4 = _mm_loadu_ps(scaling_factors + batch);
    SseSparseMatrix4VectorsMultiplyAccumulate(
        matrix, ledger, m_rows, m_cols, vectors, scaling_factors_fx4, results);
    batch += kBatchSize4;
    vectors += kBatchSize4 * m_cols;
    results += kBatchSize4 * m_rows;
  }  // for batch
  while (batch < n_batch) {
    SseSparseMatrixVectorMultiplyAccumulate(matrix, ledger, m_rows, m_cols,
                                            vectors, scaling_factors[batch],
                                            results);
    ++batch;
    vectors += m_cols;
    results += m_rows;
  }  // for batch
}

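// Computes the sum of each row of an int8 matrix ('input_vector',
// output_size x reduction_size, row-major) and writes one int32 per row to
// 'output_vector'. Used by SseMatrixBatchVectorMultiplyAccumulate above to
// build the row sums for input-offset correction.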
void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                           const int output_size, const int reduction_size) {
  static constexpr std::intptr_t kBlockSize = 16;
  for (std::intptr_t row = 0; row < output_size; ++row) {
    const int8_t* __restrict__ row_ptr = input_vector + row * reduction_size;
    __m128i row_sum_16x8 = _mm_setzero_si128();
    std::intptr_t col = 0;
    for (; col < (reduction_size & ~(kBlockSize - 1)); col += kBlockSize) {
      const __m128i row_8x16 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
      const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16);
      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
    }  // for col
#ifdef __SSE4_1__
    // Postamble for 8x 8-bit inputs.
    if (col < (reduction_size & ~7)) {
      // _mm_loadu_si64 not supported in gcc versions < 9, breaks kokoro build.
      const __m128i row_16x8 = _mm_cvtepi8_epi16(
          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
      // row_sum += row
      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
      col += 8;
    }
#endif
    const __m128i row_sum_32x4 =
        _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1));
    int32_t row_sum = ReduceInt32x4(row_sum_32x4);
#if defined(__SSE4_1__) && defined(__clang__)
    // SSE 4.1: Don't try to unroll and vectorize this, already done above.
#pragma clang loop unroll(disable) vectorize(disable)
#endif
    for (; col < reduction_size; col++) {
      row_sum += row_ptr[col];
    }
    output_vector[row] = row_sum;
  }
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // __SSSE3__