1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_ 13 14 // Defines WEBRTC_ARCH_X86_FAMILY, used below. 15 #include "rtc_base/system/arch.h" 16 17 #if defined(WEBRTC_HAS_NEON) 18 #include <arm_neon.h> 19 #endif 20 #if defined(WEBRTC_ARCH_X86_FAMILY) 21 #include <emmintrin.h> 22 #endif 23 24 #include <numeric> 25 26 #include "api/array_view.h" 27 #include "modules/audio_processing/agc2/cpu_features.h" 28 #include "rtc_base/checks.h" 29 #include "rtc_base/numerics/safe_conversions.h" 30 #include "rtc_base/system/arch.h" 31 32 namespace webrtc { 33 namespace rnn_vad { 34 35 // Provides optimizations for mathematical operations having vectors as 36 // operand(s). 37 class VectorMath { 38 public: VectorMath(AvailableCpuFeatures cpu_features)39 explicit VectorMath(AvailableCpuFeatures cpu_features) 40 : cpu_features_(cpu_features) {} 41 42 // Computes the dot product between two equally sized vectors. DotProduct(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y)43 float DotProduct(rtc::ArrayView<const float> x, 44 rtc::ArrayView<const float> y) const { 45 RTC_DCHECK_EQ(x.size(), y.size()); 46 #if defined(WEBRTC_ARCH_X86_FAMILY) 47 if (cpu_features_.avx2) { 48 return DotProductAvx2(x, y); 49 } else if (cpu_features_.sse2) { 50 __m128 accumulator = _mm_setzero_ps(); 51 constexpr int kBlockSizeLog2 = 2; 52 constexpr int kBlockSize = 1 << kBlockSizeLog2; 53 const int incomplete_block_index = (x.size() >> kBlockSizeLog2) 54 << kBlockSizeLog2; 55 for (int i = 0; i < incomplete_block_index; i += kBlockSize) { 56 RTC_DCHECK_LE(i + kBlockSize, x.size()); 57 const __m128 x_i = _mm_loadu_ps(&x[i]); 58 const __m128 y_i = _mm_loadu_ps(&y[i]); 59 // Multiply-add. 60 const __m128 z_j = _mm_mul_ps(x_i, y_i); 61 accumulator = _mm_add_ps(accumulator, z_j); 62 } 63 // Reduce `accumulator` by addition. 64 __m128 high = _mm_movehl_ps(accumulator, accumulator); 65 accumulator = _mm_add_ps(accumulator, high); 66 high = _mm_shuffle_ps(accumulator, accumulator, 1); 67 accumulator = _mm_add_ps(accumulator, high); 68 float dot_product = _mm_cvtss_f32(accumulator); 69 // Add the result for the last block if incomplete. 70 for (int i = incomplete_block_index; 71 i < rtc::dchecked_cast<int>(x.size()); ++i) { 72 dot_product += x[i] * y[i]; 73 } 74 return dot_product; 75 } 76 #elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64) 77 if (cpu_features_.neon) { 78 float32x4_t accumulator = vdupq_n_f32(0.f); 79 constexpr int kBlockSizeLog2 = 2; 80 constexpr int kBlockSize = 1 << kBlockSizeLog2; 81 const int incomplete_block_index = (x.size() >> kBlockSizeLog2) 82 << kBlockSizeLog2; 83 for (int i = 0; i < incomplete_block_index; i += kBlockSize) { 84 RTC_DCHECK_LE(i + kBlockSize, x.size()); 85 const float32x4_t x_i = vld1q_f32(&x[i]); 86 const float32x4_t y_i = vld1q_f32(&y[i]); 87 accumulator = vfmaq_f32(accumulator, x_i, y_i); 88 } 89 // Reduce `accumulator` by addition. 90 const float32x2_t tmp = 91 vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator)); 92 float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0); 93 // Add the result for the last block if incomplete. 94 for (int i = incomplete_block_index; 95 i < rtc::dchecked_cast<int>(x.size()); ++i) { 96 dot_product += x[i] * y[i]; 97 } 98 return dot_product; 99 } 100 #endif 101 return std::inner_product(x.begin(), x.end(), y.begin(), 0.f); 102 } 103 104 private: 105 float DotProductAvx2(rtc::ArrayView<const float> x, 106 rtc::ArrayView<const float> y) const; 107 108 const AvailableCpuFeatures cpu_features_; 109 }; 110 111 } // namespace rnn_vad 112 } // namespace webrtc 113 114 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_ 115