xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/vector_math.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
13 
14 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
15 #include "rtc_base/system/arch.h"
16 
17 #if defined(WEBRTC_HAS_NEON)
18 #include <arm_neon.h>
19 #endif
20 #if defined(WEBRTC_ARCH_X86_FAMILY)
21 #include <emmintrin.h>
22 #endif
23 
24 #include <numeric>
25 
26 #include "api/array_view.h"
27 #include "modules/audio_processing/agc2/cpu_features.h"
28 #include "rtc_base/checks.h"
29 #include "rtc_base/numerics/safe_conversions.h"
30 #include "rtc_base/system/arch.h"
31 
32 namespace webrtc {
33 namespace rnn_vad {
34 
35 // Provides optimizations for mathematical operations having vectors as
36 // operand(s).
37 class VectorMath {
38  public:
VectorMath(AvailableCpuFeatures cpu_features)39   explicit VectorMath(AvailableCpuFeatures cpu_features)
40       : cpu_features_(cpu_features) {}
41 
42   // Computes the dot product between two equally sized vectors.
DotProduct(rtc::ArrayView<const float> x,rtc::ArrayView<const float> y)43   float DotProduct(rtc::ArrayView<const float> x,
44                    rtc::ArrayView<const float> y) const {
45     RTC_DCHECK_EQ(x.size(), y.size());
46 #if defined(WEBRTC_ARCH_X86_FAMILY)
47     if (cpu_features_.avx2) {
48       return DotProductAvx2(x, y);
49     } else if (cpu_features_.sse2) {
50       __m128 accumulator = _mm_setzero_ps();
51       constexpr int kBlockSizeLog2 = 2;
52       constexpr int kBlockSize = 1 << kBlockSizeLog2;
53       const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
54                                          << kBlockSizeLog2;
55       for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
56         RTC_DCHECK_LE(i + kBlockSize, x.size());
57         const __m128 x_i = _mm_loadu_ps(&x[i]);
58         const __m128 y_i = _mm_loadu_ps(&y[i]);
59         // Multiply-add.
60         const __m128 z_j = _mm_mul_ps(x_i, y_i);
61         accumulator = _mm_add_ps(accumulator, z_j);
62       }
63       // Reduce `accumulator` by addition.
64       __m128 high = _mm_movehl_ps(accumulator, accumulator);
65       accumulator = _mm_add_ps(accumulator, high);
66       high = _mm_shuffle_ps(accumulator, accumulator, 1);
67       accumulator = _mm_add_ps(accumulator, high);
68       float dot_product = _mm_cvtss_f32(accumulator);
69       // Add the result for the last block if incomplete.
70       for (int i = incomplete_block_index;
71            i < rtc::dchecked_cast<int>(x.size()); ++i) {
72         dot_product += x[i] * y[i];
73       }
74       return dot_product;
75     }
76 #elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64)
77     if (cpu_features_.neon) {
78       float32x4_t accumulator = vdupq_n_f32(0.f);
79       constexpr int kBlockSizeLog2 = 2;
80       constexpr int kBlockSize = 1 << kBlockSizeLog2;
81       const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
82                                          << kBlockSizeLog2;
83       for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
84         RTC_DCHECK_LE(i + kBlockSize, x.size());
85         const float32x4_t x_i = vld1q_f32(&x[i]);
86         const float32x4_t y_i = vld1q_f32(&y[i]);
87         accumulator = vfmaq_f32(accumulator, x_i, y_i);
88       }
89       // Reduce `accumulator` by addition.
90       const float32x2_t tmp =
91           vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator));
92       float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0);
93       // Add the result for the last block if incomplete.
94       for (int i = incomplete_block_index;
95            i < rtc::dchecked_cast<int>(x.size()); ++i) {
96         dot_product += x[i] * y[i];
97       }
98       return dot_product;
99     }
100 #endif
101     return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
102   }
103 
104  private:
105   float DotProductAvx2(rtc::ArrayView<const float> x,
106                        rtc::ArrayView<const float> y) const;
107 
108   const AvailableCpuFeatures cpu_features_;
109 };
110 
111 }  // namespace rnn_vad
112 }  // namespace webrtc
113 
114 #endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
115