xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <vector>
15 
16 #include "common_audio/resampler/push_sinc_resampler.h"
17 #include "modules/audio_processing/agc2/cpu_features.h"
18 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
19 #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
20 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
21 #include "modules/audio_processing/test/performance_timer.h"
22 #include "rtc_base/checks.h"
23 #include "rtc_base/logging.h"
24 #include "test/gtest.h"
25 #include "third_party/rnnoise/src/rnn_activations.h"
26 #include "third_party/rnnoise/src/rnn_vad_weights.h"
27 
28 namespace webrtc {
29 namespace rnn_vad {
30 namespace {
31 
// Number of samples in a 10 ms frame at a 48 kHz sample rate.
constexpr int kFrameSize10ms48kHz = 48000 / 100;
33 
DumpPerfStats(int num_samples,int sample_rate,double average_us,double standard_deviation)34 void DumpPerfStats(int num_samples,
35                    int sample_rate,
36                    double average_us,
37                    double standard_deviation) {
38   float audio_track_length_ms =
39       1e3f * static_cast<float>(num_samples) / static_cast<float>(sample_rate);
40   float average_ms = static_cast<float>(average_us) / 1e3f;
41   float speed = audio_track_length_ms / average_ms;
42   RTC_LOG(LS_INFO) << "track duration (ms): " << audio_track_length_ms;
43   RTC_LOG(LS_INFO) << "average processing time (ms): " << average_ms << " +/- "
44                    << (standard_deviation / 1e3);
45   RTC_LOG(LS_INFO) << "speed: " << speed << "x";
46 }
47 
// Switch for re-exporting the expected output. Set it to true when the RNN VAD
// model is updated and the expected output changes, so that the computed
// output is written to new expected output binary files.
constexpr bool kWriteComputedOutputToFile{false};
51 
52 // Avoids that one forgets to set `kWriteComputedOutputToFile` back to false
53 // when the expected output files are re-exported.
TEST(RnnVadTest,CheckWriteComputedOutputIsFalse)54 TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
55   ASSERT_FALSE(kWriteComputedOutputToFile)
56       << "Cannot land if kWriteComputedOutput is true.";
57 }
58 
// Test fixture parametrized by the CPU features (`AvailableCpuFeatures`) that
// each test instance retrieves via `GetParam()` and passes to the components
// under test.
class RnnVadProbabilityParametrization
    : public ::testing::TestWithParam<AvailableCpuFeatures> {};
61 
62 // Checks that the computed VAD probability for a test input sequence sampled at
63 // 48 kHz is within tolerance.
TEST_P(RnnVadProbabilityParametrization,RnnVadProbabilityWithinTolerance)64 TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
65   // Init resampler, feature extractor and RNN.
66   PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
67   const AvailableCpuFeatures cpu_features = GetParam();
68   FeaturesExtractor features_extractor(cpu_features);
69   RnnVad rnn_vad(cpu_features);
70 
71   // Init input samples and expected output readers.
72   std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
73   std::unique_ptr<FileReader> expected_vad_prob_reader = CreateVadProbsReader();
74 
75   // Input length. The last incomplete frame is ignored.
76   const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
77 
78   // Init buffers.
79   std::vector<float> samples_48k(kFrameSize10ms48kHz);
80   std::vector<float> samples_24k(kFrameSize10ms24kHz);
81   std::vector<float> feature_vector(kFeatureVectorSize);
82   std::vector<float> computed_vad_prob(num_frames);
83   std::vector<float> expected_vad_prob(num_frames);
84 
85   // Read expected output.
86   ASSERT_TRUE(expected_vad_prob_reader->ReadChunk(expected_vad_prob));
87 
88   // Compute VAD probabilities on the downsampled input.
89   float cumulative_error = 0.f;
90   for (int i = 0; i < num_frames; ++i) {
91     ASSERT_TRUE(samples_reader->ReadChunk(samples_48k));
92     decimator.Resample(samples_48k.data(), samples_48k.size(),
93                        samples_24k.data(), samples_24k.size());
94     bool is_silence = features_extractor.CheckSilenceComputeFeatures(
95         {samples_24k.data(), kFrameSize10ms24kHz},
96         {feature_vector.data(), kFeatureVectorSize});
97     computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
98         {feature_vector.data(), kFeatureVectorSize}, is_silence);
99     EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
100     cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
101   }
102   // Check average error.
103   EXPECT_LT(cumulative_error / num_frames, 1e-4f);
104 
105   if (kWriteComputedOutputToFile) {
106     FileWriter vad_prob_writer("new_vad_prob.dat");
107     vad_prob_writer.WriteChunk(computed_vad_prob);
108   }
109 }
110 
111 // Performance test for the RNN VAD (pre-fetching and downsampling are
112 // excluded). Keep disabled and only enable locally to measure performance as
113 // follows:
// - on desktop: run this unit test adding "--logs";
// - on android: run this unit test adding "--logcat-output-file".
TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
  // Reader for the PCM samples; the trailing incomplete frame is dropped.
  std::unique_ptr<FileReader> pcm_reader = CreatePcmSamplesReader();
  const int num_frames = pcm_reader->size() / kFrameSize10ms48kHz;
  std::array<float, kFrameSize10ms48kHz> frame_48k;
  // Read and decimate all the frames up front so that neither file reads nor
  // resampling are included in the measured time.
  PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
  std::vector<float> decimated_samples(num_frames * kFrameSize10ms24kHz);
  for (int frame = 0; frame < num_frames; ++frame) {
    ASSERT_TRUE(pcm_reader->ReadChunk(frame_48k));
    decimator.Resample(frame_48k.data(), frame_48k.size(),
                       decimated_samples.data() + frame * kFrameSize10ms24kHz,
                       kFrameSize10ms24kHz);
  }
  // Set up the components under test and the timer.
  const AvailableCpuFeatures cpu_features = GetParam();
  FeaturesExtractor features_extractor(cpu_features);
  std::array<float, kFeatureVectorSize> feature_vector;
  RnnVad rnn_vad(cpu_features);
  constexpr int kNumberOfRuns = 100;
  ::webrtc::test::PerformanceTimer perf_timer(kNumberOfRuns);
  for (int run = 0; run < kNumberOfRuns; ++run) {
    // Reset the state so that every run processes the track from scratch.
    features_extractor.Reset();
    rnn_vad.Reset();
    // Time feature extraction and VAD inference over the whole track.
    perf_timer.StartTimer();
    for (int frame = 0; frame < num_frames; ++frame) {
      const bool is_silence = features_extractor.CheckSilenceComputeFeatures(
          {decimated_samples.data() + frame * kFrameSize10ms24kHz,
           kFrameSize10ms24kHz},
          feature_vector);
      rnn_vad.ComputeVadProbability(feature_vector, is_silence);
    }
    perf_timer.StopTimer();
  }
  DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
                perf_timer.GetDurationAverage(),
                perf_timer.GetDurationStandardDeviation());
}
157 
158 // Finds the relevant CPU features combinations to test.
GetCpuFeaturesToTest()159 std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
160   std::vector<AvailableCpuFeatures> v;
161   v.push_back(NoAvailableCpuFeatures());
162   AvailableCpuFeatures available = GetAvailableCpuFeatures();
163   if (available.avx2 && available.sse2) {
164     v.push_back({/*sse2=*/true, /*avx2=*/true, /*neon=*/false});
165   }
166   if (available.sse2) {
167     v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
168   }
169   if (available.neon) {
170     v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/true});
171   }
172   return v;
173 }
174 
175 INSTANTIATE_TEST_SUITE_P(
176     RnnVadTest,
177     RnnVadProbabilityParametrization,
178     ::testing::ValuesIn(GetCpuFeaturesToTest()),
__anon7c6585200202(const ::testing::TestParamInfo<AvailableCpuFeatures>& info) 179     [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
180       return info.param.ToString();
181     });
182 
183 }  // namespace
184 }  // namespace rnn_vad
185 }  // namespace webrtc
186