1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <vector>

#include "common_audio/resampler/push_sinc_resampler.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "modules/audio_processing/test/performance_timer.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "test/gtest.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
27
28 namespace webrtc {
29 namespace rnn_vad {
30 namespace {
31
// Number of samples in one 10 ms frame at 48 kHz (48000 samples/s / 100).
constexpr int kFrameSize10ms48kHz = 48000 / 100;
33
DumpPerfStats(int num_samples,int sample_rate,double average_us,double standard_deviation)34 void DumpPerfStats(int num_samples,
35 int sample_rate,
36 double average_us,
37 double standard_deviation) {
38 float audio_track_length_ms =
39 1e3f * static_cast<float>(num_samples) / static_cast<float>(sample_rate);
40 float average_ms = static_cast<float>(average_us) / 1e3f;
41 float speed = audio_track_length_ms / average_ms;
42 RTC_LOG(LS_INFO) << "track duration (ms): " << audio_track_length_ms;
43 RTC_LOG(LS_INFO) << "average processing time (ms): " << average_ms << " +/- "
44 << (standard_deviation / 1e3);
45 RTC_LOG(LS_INFO) << "speed: " << speed << "x";
46 }
47
48 // When the RNN VAD model is updated and the expected output changes, set the
49 // constant below to true in order to write new expected output binary files.
50 constexpr bool kWriteComputedOutputToFile = false;
51
52 // Avoids that one forgets to set `kWriteComputedOutputToFile` back to false
53 // when the expected output files are re-exported.
TEST(RnnVadTest,CheckWriteComputedOutputIsFalse)54 TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
55 ASSERT_FALSE(kWriteComputedOutputToFile)
56 << "Cannot land if kWriteComputedOutput is true.";
57 }
58
59 class RnnVadProbabilityParametrization
60 : public ::testing::TestWithParam<AvailableCpuFeatures> {};
61
62 // Checks that the computed VAD probability for a test input sequence sampled at
63 // 48 kHz is within tolerance.
TEST_P(RnnVadProbabilityParametrization,RnnVadProbabilityWithinTolerance)64 TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
65 // Init resampler, feature extractor and RNN.
66 PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
67 const AvailableCpuFeatures cpu_features = GetParam();
68 FeaturesExtractor features_extractor(cpu_features);
69 RnnVad rnn_vad(cpu_features);
70
71 // Init input samples and expected output readers.
72 std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
73 std::unique_ptr<FileReader> expected_vad_prob_reader = CreateVadProbsReader();
74
75 // Input length. The last incomplete frame is ignored.
76 const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
77
78 // Init buffers.
79 std::vector<float> samples_48k(kFrameSize10ms48kHz);
80 std::vector<float> samples_24k(kFrameSize10ms24kHz);
81 std::vector<float> feature_vector(kFeatureVectorSize);
82 std::vector<float> computed_vad_prob(num_frames);
83 std::vector<float> expected_vad_prob(num_frames);
84
85 // Read expected output.
86 ASSERT_TRUE(expected_vad_prob_reader->ReadChunk(expected_vad_prob));
87
88 // Compute VAD probabilities on the downsampled input.
89 float cumulative_error = 0.f;
90 for (int i = 0; i < num_frames; ++i) {
91 ASSERT_TRUE(samples_reader->ReadChunk(samples_48k));
92 decimator.Resample(samples_48k.data(), samples_48k.size(),
93 samples_24k.data(), samples_24k.size());
94 bool is_silence = features_extractor.CheckSilenceComputeFeatures(
95 {samples_24k.data(), kFrameSize10ms24kHz},
96 {feature_vector.data(), kFeatureVectorSize});
97 computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
98 {feature_vector.data(), kFeatureVectorSize}, is_silence);
99 EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
100 cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
101 }
102 // Check average error.
103 EXPECT_LT(cumulative_error / num_frames, 1e-4f);
104
105 if (kWriteComputedOutputToFile) {
106 FileWriter vad_prob_writer("new_vad_prob.dat");
107 vad_prob_writer.WriteChunk(computed_vad_prob);
108 }
109 }
110
111 // Performance test for the RNN VAD (pre-fetching and downsampling are
112 // excluded). Keep disabled and only enable locally to measure performance as
113 // follows:
114 // - on desktop: run the this unit test adding "--logs";
115 // - on android: run the this unit test adding "--logcat-output-file".
TEST_P(RnnVadProbabilityParametrization,DISABLED_RnnVadPerformance)116 TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
117 // PCM samples reader and buffers.
118 std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
119 // The last incomplete frame is ignored.
120 const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
121 std::array<float, kFrameSize10ms48kHz> samples;
122 // Pre-fetch and decimate samples.
123 PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
124 std::vector<float> prefetched_decimated_samples;
125 prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
126 for (int i = 0; i < num_frames; ++i) {
127 ASSERT_TRUE(samples_reader->ReadChunk(samples));
128 decimator.Resample(samples.data(), samples.size(),
129 &prefetched_decimated_samples[i * kFrameSize10ms24kHz],
130 kFrameSize10ms24kHz);
131 }
132 // Initialize.
133 const AvailableCpuFeatures cpu_features = GetParam();
134 FeaturesExtractor features_extractor(cpu_features);
135 std::array<float, kFeatureVectorSize> feature_vector;
136 RnnVad rnn_vad(cpu_features);
137 constexpr int number_of_tests = 100;
138 ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
139 for (int k = 0; k < number_of_tests; ++k) {
140 features_extractor.Reset();
141 rnn_vad.Reset();
142 // Process frames.
143 perf_timer.StartTimer();
144 for (int i = 0; i < num_frames; ++i) {
145 bool is_silence = features_extractor.CheckSilenceComputeFeatures(
146 {&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
147 kFrameSize10ms24kHz},
148 feature_vector);
149 rnn_vad.ComputeVadProbability(feature_vector, is_silence);
150 }
151 perf_timer.StopTimer();
152 }
153 DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
154 perf_timer.GetDurationAverage(),
155 perf_timer.GetDurationStandardDeviation());
156 }
157
158 // Finds the relevant CPU features combinations to test.
GetCpuFeaturesToTest()159 std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
160 std::vector<AvailableCpuFeatures> v;
161 v.push_back(NoAvailableCpuFeatures());
162 AvailableCpuFeatures available = GetAvailableCpuFeatures();
163 if (available.avx2 && available.sse2) {
164 v.push_back({/*sse2=*/true, /*avx2=*/true, /*neon=*/false});
165 }
166 if (available.sse2) {
167 v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
168 }
169 if (available.neon) {
170 v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/true});
171 }
172 return v;
173 }
174
175 INSTANTIATE_TEST_SUITE_P(
176 RnnVadTest,
177 RnnVadProbabilityParametrization,
178 ::testing::ValuesIn(GetCpuFeaturesToTest()),
__anon7c6585200202(const ::testing::TestParamInfo<AvailableCpuFeatures>& info) 179 [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
180 return info.param.ToString();
181 });
182
183 } // namespace
184 } // namespace rnn_vad
185 } // namespace webrtc
186