1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
12
13 #include <cmath>
14 #include <vector>
15
16 #include "modules/audio_processing/agc2/cpu_features.h"
17 #include "rtc_base/numerics/safe_compare.h"
18 #include "rtc_base/numerics/safe_conversions.h"
19 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
20 // #include "test/fpe_observer.h"
21 #include "test/gtest.h"
22
23 namespace webrtc {
24 namespace rnn_vad {
25 namespace {
26
// Integer division of `n` by `m` rounded up, computed by adding `m - 1` to the
// numerator before the truncating division.
constexpr int ceil(int n, int m) {
  const int padded_numerator = n + m - 1;
  return padded_numerator / m;
}
30
31 // Number of 10 ms frames required to fill a pitch buffer having size
32 // `kBufSize24kHz`.
33 constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
34 // Number of samples for the test data.
35 constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
36
37 // Verifies that the pitch in Hz is in the detectable range.
PitchIsValid(float pitch_hz)38 bool PitchIsValid(float pitch_hz) {
39 const int pitch_period = static_cast<float>(kSampleRate24kHz) / pitch_hz;
40 return kInitialMinPitch24kHz <= pitch_period &&
41 pitch_period <= kMaxPitch24kHz;
42 }
43
CreatePureTone(float amplitude,float freq_hz,rtc::ArrayView<float> dst)44 void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
45 for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
46 dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
47 }
48 }
49
50 // Feeds `features_extractor` with `samples` splitting it in 10 ms frames.
51 // For every frame, the output is written into `feature_vector`. Returns true
52 // if silence is detected in the last frame.
FeedTestData(FeaturesExtractor & features_extractor,rtc::ArrayView<const float> samples,rtc::ArrayView<float,kFeatureVectorSize> feature_vector)53 bool FeedTestData(FeaturesExtractor& features_extractor,
54 rtc::ArrayView<const float> samples,
55 rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
56 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
57 // FloatingPointExceptionObserver fpe_observer;
58 bool is_silence = true;
59 const int num_frames = samples.size() / kFrameSize10ms24kHz;
60 for (int i = 0; i < num_frames; ++i) {
61 is_silence = features_extractor.CheckSilenceComputeFeatures(
62 {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
63 feature_vector);
64 }
65 return is_silence;
66 }
67
68 // Extracts the features for two pure tones and verifies that the pitch field
69 // values reflect the known tone frequencies.
TEST(RnnVadTest,FeatureExtractionLowHighPitch)70 TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
71 constexpr float amplitude = 1000.f;
72 constexpr float low_pitch_hz = 150.f;
73 constexpr float high_pitch_hz = 250.f;
74 ASSERT_TRUE(PitchIsValid(low_pitch_hz));
75 ASSERT_TRUE(PitchIsValid(high_pitch_hz));
76
77 const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
78 FeaturesExtractor features_extractor(cpu_features);
79 std::vector<float> samples(kNumTestDataSize);
80 std::vector<float> feature_vector(kFeatureVectorSize);
81 ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(feature_vector.size()));
82 rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
83 feature_vector.data(), kFeatureVectorSize);
84
85 // Extract the normalized scalar feature that is proportional to the estimated
86 // pitch period.
87 constexpr int pitch_feature_index = kFeatureVectorSize - 2;
88 // Low frequency tone - i.e., high period.
89 CreatePureTone(amplitude, low_pitch_hz, samples);
90 ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
91 float high_pitch_period = feature_vector_view[pitch_feature_index];
92 // High frequency tone - i.e., low period.
93 features_extractor.Reset();
94 CreatePureTone(amplitude, high_pitch_hz, samples);
95 ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
96 float low_pitch_period = feature_vector_view[pitch_feature_index];
97 // Check.
98 EXPECT_LT(low_pitch_period, high_pitch_period);
99 }
100
101 } // namespace
102 } // namespace rnn_vad
103 } // namespace webrtc
104