xref: /aosp_15_r20/external/webrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/transient/transient_suppressor.h"
12 
13 #include <vector>
14 
15 #include "absl/types/optional.h"
16 #include "modules/audio_processing/transient/common.h"
17 #include "modules/audio_processing/transient/transient_suppressor_impl.h"
18 #include "test/gtest.h"
19 
20 namespace webrtc {
21 namespace {
// Value passed as the trailing constructor argument of
// `TransientSuppressorImpl` and as the third argument of `Suppress()` by the
// tests in this file; presumably the number of audio channels (one, i.e. mono).
constexpr int kMono = 1;
23 
24 // Returns the index of the first non-zero sample in `samples` or an unspecified
25 // value if no value is zero.
FindFirstNonZeroSample(const std::vector<float> & samples)26 absl::optional<int> FindFirstNonZeroSample(const std::vector<float>& samples) {
27   for (size_t i = 0; i < samples.size(); ++i) {
28     if (samples[i] != 0.0f) {
29       return i;
30     }
31   }
32   return absl::nullopt;
33 }
34 
35 }  // namespace
36 
37 class TransientSuppressorVadModeParametrization
38     : public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
39 
TEST_P(TransientSuppressorVadModeParametrization,TypingDetectionLogicWorksAsExpectedForMono)40 TEST_P(TransientSuppressorVadModeParametrization,
41        TypingDetectionLogicWorksAsExpectedForMono) {
42   TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz,
43                              ts::kSampleRate16kHz, kMono);
44 
45   // Each key-press enables detection.
46   EXPECT_FALSE(ts.detection_enabled_);
47   ts.UpdateKeypress(true);
48   EXPECT_TRUE(ts.detection_enabled_);
49 
50   // It takes four seconds without any key-press to disable the detection
51   for (int time_ms = 0; time_ms < 3990; time_ms += ts::kChunkSizeMs) {
52     ts.UpdateKeypress(false);
53     EXPECT_TRUE(ts.detection_enabled_);
54   }
55   ts.UpdateKeypress(false);
56   EXPECT_FALSE(ts.detection_enabled_);
57 
58   // Key-presses that are more than a second apart from each other don't enable
59   // suppression.
60   for (int i = 0; i < 100; ++i) {
61     EXPECT_FALSE(ts.suppression_enabled_);
62     ts.UpdateKeypress(true);
63     EXPECT_TRUE(ts.detection_enabled_);
64     EXPECT_FALSE(ts.suppression_enabled_);
65     for (int time_ms = 0; time_ms < 990; time_ms += ts::kChunkSizeMs) {
66       ts.UpdateKeypress(false);
67       EXPECT_TRUE(ts.detection_enabled_);
68       EXPECT_FALSE(ts.suppression_enabled_);
69     }
70     ts.UpdateKeypress(false);
71   }
72 
73   // Two consecutive key-presses is enough to enable the suppression.
74   ts.UpdateKeypress(true);
75   EXPECT_FALSE(ts.suppression_enabled_);
76   ts.UpdateKeypress(true);
77   EXPECT_TRUE(ts.suppression_enabled_);
78 
79   // Key-presses that are less than a second apart from each other don't disable
80   // detection nor suppression.
81   for (int i = 0; i < 100; ++i) {
82     for (int time_ms = 0; time_ms < 1000; time_ms += ts::kChunkSizeMs) {
83       ts.UpdateKeypress(false);
84       EXPECT_TRUE(ts.detection_enabled_);
85       EXPECT_TRUE(ts.suppression_enabled_);
86     }
87     ts.UpdateKeypress(true);
88     EXPECT_TRUE(ts.detection_enabled_);
89     EXPECT_TRUE(ts.suppression_enabled_);
90   }
91 
92   // It takes four seconds without any key-press to disable the detection and
93   // suppression.
94   for (int time_ms = 0; time_ms < 3990; time_ms += ts::kChunkSizeMs) {
95     ts.UpdateKeypress(false);
96     EXPECT_TRUE(ts.detection_enabled_);
97     EXPECT_TRUE(ts.suppression_enabled_);
98   }
99   for (int time_ms = 0; time_ms < 1000; time_ms += ts::kChunkSizeMs) {
100     ts.UpdateKeypress(false);
101     EXPECT_FALSE(ts.detection_enabled_);
102     EXPECT_FALSE(ts.suppression_enabled_);
103   }
104 }
105 
106 INSTANTIATE_TEST_SUITE_P(
107     TransientSuppressorImplTest,
108     TransientSuppressorVadModeParametrization,
109     ::testing::Values(TransientSuppressor::VadMode::kDefault,
110                       TransientSuppressor::VadMode::kRnnVad,
111                       TransientSuppressor::VadMode::kNoVad));
112 
113 class TransientSuppressorSampleRateParametrization
114     : public ::testing::TestWithParam<int> {};
115 
116 // Checks that voice probability and processed audio data are temporally aligned
117 // after `Suppress()` is called.
TEST_P(TransientSuppressorSampleRateParametrization,CheckAudioAndVoiceProbabilityTemporallyAligned)118 TEST_P(TransientSuppressorSampleRateParametrization,
119        CheckAudioAndVoiceProbabilityTemporallyAligned) {
120   const int sample_rate_hz = GetParam();
121   TransientSuppressorImpl ts(TransientSuppressor::VadMode::kDefault,
122                              sample_rate_hz,
123                              /*detection_rate_hz=*/sample_rate_hz, kMono);
124 
125   const int frame_size = sample_rate_hz * ts::kChunkSizeMs / 1000;
126   std::vector<float> frame(frame_size);
127 
128   constexpr int kMaxAttempts = 3;
129   for (int i = 0; i < kMaxAttempts; ++i) {
130     SCOPED_TRACE(i);
131 
132     // Call `Suppress()` on frames of non-zero audio samples.
133     std::fill(frame.begin(), frame.end(), 1000.0f);
134     float delayed_voice_probability = ts.Suppress(
135         frame.data(), frame.size(), kMono, /*detection_data=*/nullptr,
136         /*detection_length=*/frame_size, /*reference_data=*/nullptr,
137         /*reference_length=*/frame_size, /*voice_probability=*/1.0f,
138         /*key_pressed=*/false);
139 
140     // Detect the algorithmic delay of `TransientSuppressorImpl`.
141     absl::optional<int> frame_delay = FindFirstNonZeroSample(frame);
142 
143     // Check that the delayed voice probability is delayed according to the
144     // measured delay.
145     if (frame_delay.has_value()) {
146       if (*frame_delay == 0) {
147         // When the delay is a multiple integer of the frame duration,
148         // `Suppress()` returns a copy of a previously observed voice
149         // probability value.
150         EXPECT_EQ(delayed_voice_probability, 1.0f);
151       } else {
152         // Instead, when the delay is fractional, `Suppress()` returns an
153         // interpolated value. Since the exact value depends on the
154         // interpolation method, we only check that the delayed voice
155         // probability is not zero as it must converge towards the previoulsy
156         // observed value.
157         EXPECT_GT(delayed_voice_probability, 0.0f);
158       }
159       break;
160     } else {
161       // The algorithmic delay is longer than the duration of a single frame.
162       // Until the delay is detected, the delayed voice probability is zero.
163       EXPECT_EQ(delayed_voice_probability, 0.0f);
164     }
165   }
166 }
167 
168 INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest,
169                          TransientSuppressorSampleRateParametrization,
170                          ::testing::Values(ts::kSampleRate8kHz,
171                                            ts::kSampleRate16kHz,
172                                            ts::kSampleRate32kHz,
173                                            ts::kSampleRate48kHz));
174 
175 }  // namespace webrtc
176