xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
12 
13 #include <cmath>
14 #include <vector>
15 
16 #include "modules/audio_processing/agc2/cpu_features.h"
17 #include "rtc_base/numerics/safe_compare.h"
18 #include "rtc_base/numerics/safe_conversions.h"
19 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
20 // #include "test/fpe_observer.h"
21 #include "test/gtest.h"
22 
23 namespace webrtc {
24 namespace rnn_vad {
25 namespace {
26 
// Integer division of `n` by `m` rounded up, for non-negative `n` and positive
// `m`. Usable in constant expressions.
constexpr int ceil(int n, int m) {
  // Adding `m - 1` before dividing bumps any non-exact quotient to the next
  // integer while leaving exact quotients unchanged.
  const int biased_numerator = n + m - 1;
  return biased_numerator / m;
}
30 
31 // Number of 10 ms frames required to fill a pitch buffer having size
32 // `kBufSize24kHz`.
33 constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
34 // Number of samples for the test data.
35 constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
36 
37 // Verifies that the pitch in Hz is in the detectable range.
PitchIsValid(float pitch_hz)38 bool PitchIsValid(float pitch_hz) {
39   const int pitch_period = static_cast<float>(kSampleRate24kHz) / pitch_hz;
40   return kInitialMinPitch24kHz <= pitch_period &&
41          pitch_period <= kMaxPitch24kHz;
42 }
43 
CreatePureTone(float amplitude,float freq_hz,rtc::ArrayView<float> dst)44 void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
45   for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
46     dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
47   }
48 }
49 
50 // Feeds `features_extractor` with `samples` splitting it in 10 ms frames.
51 // For every frame, the output is written into `feature_vector`. Returns true
52 // if silence is detected in the last frame.
FeedTestData(FeaturesExtractor & features_extractor,rtc::ArrayView<const float> samples,rtc::ArrayView<float,kFeatureVectorSize> feature_vector)53 bool FeedTestData(FeaturesExtractor& features_extractor,
54                   rtc::ArrayView<const float> samples,
55                   rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
56   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
57   // FloatingPointExceptionObserver fpe_observer;
58   bool is_silence = true;
59   const int num_frames = samples.size() / kFrameSize10ms24kHz;
60   for (int i = 0; i < num_frames; ++i) {
61     is_silence = features_extractor.CheckSilenceComputeFeatures(
62         {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
63         feature_vector);
64   }
65   return is_silence;
66 }
67 
68 // Extracts the features for two pure tones and verifies that the pitch field
69 // values reflect the known tone frequencies.
TEST(RnnVadTest,FeatureExtractionLowHighPitch)70 TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
71   constexpr float amplitude = 1000.f;
72   constexpr float low_pitch_hz = 150.f;
73   constexpr float high_pitch_hz = 250.f;
74   ASSERT_TRUE(PitchIsValid(low_pitch_hz));
75   ASSERT_TRUE(PitchIsValid(high_pitch_hz));
76 
77   const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
78   FeaturesExtractor features_extractor(cpu_features);
79   std::vector<float> samples(kNumTestDataSize);
80   std::vector<float> feature_vector(kFeatureVectorSize);
81   ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(feature_vector.size()));
82   rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
83       feature_vector.data(), kFeatureVectorSize);
84 
85   // Extract the normalized scalar feature that is proportional to the estimated
86   // pitch period.
87   constexpr int pitch_feature_index = kFeatureVectorSize - 2;
88   // Low frequency tone - i.e., high period.
89   CreatePureTone(amplitude, low_pitch_hz, samples);
90   ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
91   float high_pitch_period = feature_vector_view[pitch_feature_index];
92   // High frequency tone - i.e., low period.
93   features_extractor.Reset();
94   CreatePureTone(amplitude, high_pitch_hz, samples);
95   ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
96   float low_pitch_period = feature_vector_view[pitch_feature_index];
97   // Check.
98   EXPECT_LT(low_pitch_period, high_pitch_period);
99 }
100 
101 }  // namespace
102 }  // namespace rnn_vad
103 }  // namespace webrtc
104