xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/suppression_gain.h"
12 
13 #include "modules/audio_processing/aec3/aec_state.h"
14 #include "modules/audio_processing/aec3/render_delay_buffer.h"
15 #include "modules/audio_processing/aec3/subtractor.h"
16 #include "modules/audio_processing/aec3/subtractor_output.h"
17 #include "modules/audio_processing/logging/apm_data_dumper.h"
18 #include "rtc_base/checks.h"
19 #include "system_wrappers/include/cpu_features_wrapper.h"
20 #include "test/gtest.h"
21 
22 namespace webrtc {
23 namespace aec3 {
24 
25 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
26 
27 // Verifies that the check for non-null output gains works.
TEST(SuppressionGainDeathTest,NullOutputGains)28 TEST(SuppressionGainDeathTest, NullOutputGains) {
29   std::vector<std::array<float, kFftLengthBy2Plus1>> E2(1, {0.0f});
30   std::vector<std::array<float, kFftLengthBy2Plus1>> R2(1, {0.0f});
31   std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded(1, {0.0f});
32   std::vector<std::array<float, kFftLengthBy2Plus1>> S2(1);
33   std::vector<std::array<float, kFftLengthBy2Plus1>> N2(1, {0.0f});
34   for (auto& S2_k : S2) {
35     S2_k.fill(0.1f);
36   }
37   FftData E;
38   FftData Y;
39   E.re.fill(0.0f);
40   E.im.fill(0.0f);
41   Y.re.fill(0.0f);
42   Y.im.fill(0.0f);
43 
44   float high_bands_gain;
45   AecState aec_state(EchoCanceller3Config{}, 1);
46   EXPECT_DEATH(
47       SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000, 1)
48           .GetGain(E2, S2, R2, R2_unbounded, N2,
49                    RenderSignalAnalyzer((EchoCanceller3Config{})), aec_state,
50                    Block(3, 1), false, &high_bands_gain, nullptr),
51       "");
52 }
53 
54 #endif
55 
56 // Does a sanity check that the gains are correctly computed.
TEST(SuppressionGain,BasicGainComputation)57 TEST(SuppressionGain, BasicGainComputation) {
58   constexpr size_t kNumRenderChannels = 1;
59   constexpr size_t kNumCaptureChannels = 2;
60   constexpr int kSampleRateHz = 16000;
61   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
62   SuppressionGain suppression_gain(EchoCanceller3Config(), DetectOptimization(),
63                                    kSampleRateHz, kNumCaptureChannels);
64   RenderSignalAnalyzer analyzer(EchoCanceller3Config{});
65   float high_bands_gain;
66   std::vector<std::array<float, kFftLengthBy2Plus1>> E2(kNumCaptureChannels);
67   std::vector<std::array<float, kFftLengthBy2Plus1>> S2(kNumCaptureChannels,
68                                                         {0.0f});
69   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
70   std::vector<std::array<float, kFftLengthBy2Plus1>> R2(kNumCaptureChannels);
71   std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded(
72       kNumCaptureChannels);
73   std::vector<std::array<float, kFftLengthBy2Plus1>> N2(kNumCaptureChannels);
74   std::array<float, kFftLengthBy2Plus1> g;
75   std::vector<SubtractorOutput> output(kNumCaptureChannels);
76   Block x(kNumBands, kNumRenderChannels);
77   EchoCanceller3Config config;
78   AecState aec_state(config, kNumCaptureChannels);
79   ApmDataDumper data_dumper(42);
80   Subtractor subtractor(config, kNumRenderChannels, kNumCaptureChannels,
81                         &data_dumper, DetectOptimization());
82   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
83       RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
84   absl::optional<DelayEstimate> delay_estimate;
85 
86   // Ensure that a strong noise is detected to mask any echoes.
87   for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) {
88     E2[ch].fill(10.f);
89     Y2[ch].fill(10.f);
90     R2[ch].fill(0.1f);
91     R2_unbounded[ch].fill(0.1f);
92     N2[ch].fill(100.0f);
93   }
94   for (auto& subtractor_output : output) {
95     subtractor_output.Reset();
96   }
97 
98   // Ensure that the gain is no longer forced to zero.
99   for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) {
100     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
101                      subtractor.FilterImpulseResponses(),
102                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
103   }
104 
105   for (int k = 0; k < 100; ++k) {
106     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
107                      subtractor.FilterImpulseResponses(),
108                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
109     suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state,
110                              x, false, &high_bands_gain, &g);
111   }
112   std::for_each(g.begin(), g.end(),
113                 [](float a) { EXPECT_NEAR(1.0f, a, 0.001f); });
114 
115   // Ensure that a strong nearend is detected to mask any echoes.
116   for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) {
117     E2[ch].fill(100.f);
118     Y2[ch].fill(100.f);
119     R2[ch].fill(0.1f);
120     R2_unbounded[ch].fill(0.1f);
121     S2[ch].fill(0.1f);
122     N2[ch].fill(0.f);
123   }
124 
125   for (int k = 0; k < 100; ++k) {
126     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
127                      subtractor.FilterImpulseResponses(),
128                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
129     suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state,
130                              x, false, &high_bands_gain, &g);
131   }
132   std::for_each(g.begin(), g.end(),
133                 [](float a) { EXPECT_NEAR(1.0f, a, 0.001f); });
134 
135   // Add a strong echo to one of the channels and ensure that it is suppressed.
136   E2[1].fill(1000000000.0f);
137   R2[1].fill(10000000000000.0f);
138   R2_unbounded[1].fill(10000000000000.0f);
139 
140   for (int k = 0; k < 10; ++k) {
141     suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state,
142                              x, false, &high_bands_gain, &g);
143   }
144   std::for_each(g.begin(), g.end(),
145                 [](float a) { EXPECT_NEAR(0.0f, a, 0.001f); });
146 }
147 
148 }  // namespace aec3
149 }  // namespace webrtc
150