1 /*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include "modules/audio_processing/transient/transient_suppressor.h"

#include <algorithm>
#include <vector>

#include "absl/types/optional.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"
19
20 namespace webrtc {
21 namespace {
// Channel count (mono) handed to every suppressor instance and `Suppress()`
// call in this file.
constexpr int kMono{1};
23
24 // Returns the index of the first non-zero sample in `samples` or an unspecified
25 // value if no value is zero.
FindFirstNonZeroSample(const std::vector<float> & samples)26 absl::optional<int> FindFirstNonZeroSample(const std::vector<float>& samples) {
27 for (size_t i = 0; i < samples.size(); ++i) {
28 if (samples[i] != 0.0f) {
29 return i;
30 }
31 }
32 return absl::nullopt;
33 }
34
35 } // namespace
36
37 class TransientSuppressorVadModeParametrization
38 : public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
39
// Drives `UpdateKeypress()` on a mono suppressor (constructed with 16 kHz for
// both rate arguments) and checks how key-press activity switches the
// `detection_enabled_` and `suppression_enabled_` flags on and off.
TEST_P(TransientSuppressorVadModeParametrization,
       TypingDetectionLogicWorksAsExpectedForMono) {
  constexpr int kRepetitions = 100;
  constexpr int kJustUnderOneSecondMs = 990;
  constexpr int kOneSecondMs = 1000;
  constexpr int kJustUnderFourSecondsMs = 3990;

  TransientSuppressorImpl suppressor(GetParam(), ts::kSampleRate16kHz,
                                     ts::kSampleRate16kHz, kMono);

  // Detection starts out disabled and a single key-press turns it on.
  EXPECT_FALSE(suppressor.detection_enabled_);
  suppressor.UpdateKeypress(true);
  EXPECT_TRUE(suppressor.detection_enabled_);

  // Detection survives just under four seconds without a key-press...
  for (int elapsed_ms = 0; elapsed_ms < kJustUnderFourSecondsMs;
       elapsed_ms += ts::kChunkSizeMs) {
    suppressor.UpdateKeypress(false);
    EXPECT_TRUE(suppressor.detection_enabled_);
  }
  // ...and the next update without a key-press disables it.
  suppressor.UpdateKeypress(false);
  EXPECT_FALSE(suppressor.detection_enabled_);

  // Isolated key-presses, each followed by a full second of silence, keep
  // detection enabled but never enable suppression.
  for (int repetition = 0; repetition < kRepetitions; ++repetition) {
    EXPECT_FALSE(suppressor.suppression_enabled_);
    suppressor.UpdateKeypress(true);
    EXPECT_TRUE(suppressor.detection_enabled_);
    EXPECT_FALSE(suppressor.suppression_enabled_);
    for (int elapsed_ms = 0; elapsed_ms < kJustUnderOneSecondMs;
         elapsed_ms += ts::kChunkSizeMs) {
      suppressor.UpdateKeypress(false);
      EXPECT_TRUE(suppressor.detection_enabled_);
      EXPECT_FALSE(suppressor.suppression_enabled_);
    }
    // Final update that completes the gap; no expectation is attached to it.
    suppressor.UpdateKeypress(false);
  }

  // Two back-to-back key-presses are sufficient to enable suppression.
  suppressor.UpdateKeypress(true);
  EXPECT_FALSE(suppressor.suppression_enabled_);
  suppressor.UpdateKeypress(true);
  EXPECT_TRUE(suppressor.suppression_enabled_);

  // While key-presses keep arriving less than a second apart, both detection
  // and suppression remain enabled.
  for (int repetition = 0; repetition < kRepetitions; ++repetition) {
    for (int elapsed_ms = 0; elapsed_ms < kOneSecondMs;
         elapsed_ms += ts::kChunkSizeMs) {
      suppressor.UpdateKeypress(false);
      EXPECT_TRUE(suppressor.detection_enabled_);
      EXPECT_TRUE(suppressor.suppression_enabled_);
    }
    suppressor.UpdateKeypress(true);
    EXPECT_TRUE(suppressor.detection_enabled_);
    EXPECT_TRUE(suppressor.suppression_enabled_);
  }

  // Both flags stay set for just under four seconds without a key-press...
  for (int elapsed_ms = 0; elapsed_ms < kJustUnderFourSecondsMs;
       elapsed_ms += ts::kChunkSizeMs) {
    suppressor.UpdateKeypress(false);
    EXPECT_TRUE(suppressor.detection_enabled_);
    EXPECT_TRUE(suppressor.suppression_enabled_);
  }
  // ...after which both are cleared and stay cleared.
  for (int elapsed_ms = 0; elapsed_ms < kOneSecondMs;
       elapsed_ms += ts::kChunkSizeMs) {
    suppressor.UpdateKeypress(false);
    EXPECT_FALSE(suppressor.detection_enabled_);
    EXPECT_FALSE(suppressor.suppression_enabled_);
  }
}
105
106 INSTANTIATE_TEST_SUITE_P(
107 TransientSuppressorImplTest,
108 TransientSuppressorVadModeParametrization,
109 ::testing::Values(TransientSuppressor::VadMode::kDefault,
110 TransientSuppressor::VadMode::kRnnVad,
111 TransientSuppressor::VadMode::kNoVad));
112
113 class TransientSuppressorSampleRateParametrization
114 : public ::testing::TestWithParam<int> {};
115
116 // Checks that voice probability and processed audio data are temporally aligned
117 // after `Suppress()` is called.
TEST_P(TransientSuppressorSampleRateParametrization,CheckAudioAndVoiceProbabilityTemporallyAligned)118 TEST_P(TransientSuppressorSampleRateParametrization,
119 CheckAudioAndVoiceProbabilityTemporallyAligned) {
120 const int sample_rate_hz = GetParam();
121 TransientSuppressorImpl ts(TransientSuppressor::VadMode::kDefault,
122 sample_rate_hz,
123 /*detection_rate_hz=*/sample_rate_hz, kMono);
124
125 const int frame_size = sample_rate_hz * ts::kChunkSizeMs / 1000;
126 std::vector<float> frame(frame_size);
127
128 constexpr int kMaxAttempts = 3;
129 for (int i = 0; i < kMaxAttempts; ++i) {
130 SCOPED_TRACE(i);
131
132 // Call `Suppress()` on frames of non-zero audio samples.
133 std::fill(frame.begin(), frame.end(), 1000.0f);
134 float delayed_voice_probability = ts.Suppress(
135 frame.data(), frame.size(), kMono, /*detection_data=*/nullptr,
136 /*detection_length=*/frame_size, /*reference_data=*/nullptr,
137 /*reference_length=*/frame_size, /*voice_probability=*/1.0f,
138 /*key_pressed=*/false);
139
140 // Detect the algorithmic delay of `TransientSuppressorImpl`.
141 absl::optional<int> frame_delay = FindFirstNonZeroSample(frame);
142
143 // Check that the delayed voice probability is delayed according to the
144 // measured delay.
145 if (frame_delay.has_value()) {
146 if (*frame_delay == 0) {
147 // When the delay is a multiple integer of the frame duration,
148 // `Suppress()` returns a copy of a previously observed voice
149 // probability value.
150 EXPECT_EQ(delayed_voice_probability, 1.0f);
151 } else {
152 // Instead, when the delay is fractional, `Suppress()` returns an
153 // interpolated value. Since the exact value depends on the
154 // interpolation method, we only check that the delayed voice
155 // probability is not zero as it must converge towards the previoulsy
156 // observed value.
157 EXPECT_GT(delayed_voice_probability, 0.0f);
158 }
159 break;
160 } else {
161 // The algorithmic delay is longer than the duration of a single frame.
162 // Until the delay is detected, the delayed voice probability is zero.
163 EXPECT_EQ(delayed_voice_probability, 0.0f);
164 }
165 }
166 }
167
168 INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest,
169 TransientSuppressorSampleRateParametrization,
170 ::testing::Values(ts::kSampleRate8kHz,
171 ts::kSampleRate16kHz,
172 ts::kSampleRate32kHz,
173 ts::kSampleRate48kHz));
174
175 } // namespace webrtc
176