1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/suppression_gain.h"
12
13 #include "modules/audio_processing/aec3/aec_state.h"
14 #include "modules/audio_processing/aec3/render_delay_buffer.h"
15 #include "modules/audio_processing/aec3/subtractor.h"
16 #include "modules/audio_processing/aec3/subtractor_output.h"
17 #include "modules/audio_processing/logging/apm_data_dumper.h"
18 #include "rtc_base/checks.h"
19 #include "system_wrappers/include/cpu_features_wrapper.h"
20 #include "test/gtest.h"
21
22 namespace webrtc {
23 namespace aec3 {
24
25 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
26
27 // Verifies that the check for non-null output gains works.
TEST(SuppressionGainDeathTest, NullOutputGains) {
  // Single-channel input spectra. Everything is zero except S2, which is set
  // to a small constant value in every bin.
  std::vector<std::array<float, kFftLengthBy2Plus1>> E2(1, {0.0f});
  std::vector<std::array<float, kFftLengthBy2Plus1>> R2(1, {0.0f});
  std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded(1, {0.0f});
  std::vector<std::array<float, kFftLengthBy2Plus1>> S2(1);
  std::vector<std::array<float, kFftLengthBy2Plus1>> N2(1, {0.0f});
  for (auto& S2_k : S2) {
    S2_k.fill(0.1f);
  }

  // A valid destination is supplied for the high-band gain so that the only
  // invalid argument is the null pointer passed in the last position.
  float high_bands_gain = 0.0f;
  AecState aec_state(EchoCanceller3Config{}, 1);

  // The empty matcher accepts any death message; the test only requires that
  // the call terminates the process when the gain output pointer is null.
  EXPECT_DEATH(
      SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000, 1)
          .GetGain(E2, S2, R2, R2_unbounded, N2,
                   RenderSignalAnalyzer((EchoCanceller3Config{})), aec_state,
                   Block(3, 1), false, &high_bands_gain, nullptr),
      "");
}
53
54 #endif
55
56 // Does a sanity check that the gains are correctly computed.
TEST(SuppressionGain, BasicGainComputation) {
  constexpr size_t kNumRenderChannels = 1;
  constexpr size_t kNumCaptureChannels = 2;
  constexpr int kSampleRateHz = 16000;
  constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
  // Maximum allowed deviation of each per-bin gain from its expected value.
  constexpr float kGainTolerance = 0.001f;

  // A single default configuration is shared by all components so that they
  // are guaranteed to agree with each other.
  EchoCanceller3Config config;
  SuppressionGain suppression_gain(config, DetectOptimization(), kSampleRateHz,
                                   kNumCaptureChannels);
  RenderSignalAnalyzer analyzer(config);
  float high_bands_gain = 0.0f;
  std::vector<std::array<float, kFftLengthBy2Plus1>> E2(kNumCaptureChannels);
  std::vector<std::array<float, kFftLengthBy2Plus1>> S2(kNumCaptureChannels,
                                                        {0.0f});
  std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
  std::vector<std::array<float, kFftLengthBy2Plus1>> R2(kNumCaptureChannels);
  std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded(
      kNumCaptureChannels);
  std::vector<std::array<float, kFftLengthBy2Plus1>> N2(kNumCaptureChannels);
  // Receives the per-bin gains from every GetGain() call below.
  std::array<float, kFftLengthBy2Plus1> g{};
  std::vector<SubtractorOutput> output(kNumCaptureChannels);
  Block x(kNumBands, kNumRenderChannels);
  AecState aec_state(config, kNumCaptureChannels);
  ApmDataDumper data_dumper(42);
  Subtractor subtractor(config, kNumRenderChannels, kNumCaptureChannels,
                        &data_dumper, DetectOptimization());
  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
      RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
  absl::optional<DelayEstimate> delay_estimate;

  // Ensure that a strong noise is detected to mask any echoes.
  for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) {
    E2[ch].fill(10.f);
    Y2[ch].fill(10.f);
    R2[ch].fill(0.1f);
    R2_unbounded[ch].fill(0.1f);
    N2[ch].fill(100.0f);
  }
  for (auto& subtractor_output : output) {
    subtractor_output.Reset();
  }

  // Ensure that the gain is no longer forced to zero.
  for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) {
    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
                     subtractor.FilterImpulseResponses(),
                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
  }

  for (int k = 0; k < 100; ++k) {
    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
                     subtractor.FilterImpulseResponses(),
                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
    suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state,
                             x, false, &high_bands_gain, &g);
  }
  // With the noise dominating, every bin is expected to pass through.
  for (const float gain : g) {
    EXPECT_NEAR(1.0f, gain, kGainTolerance);
  }

  // Ensure that a strong nearend is detected to mask any echoes.
  for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) {
    E2[ch].fill(100.f);
    Y2[ch].fill(100.f);
    R2[ch].fill(0.1f);
    R2_unbounded[ch].fill(0.1f);
    S2[ch].fill(0.1f);
    N2[ch].fill(0.f);
  }

  for (int k = 0; k < 100; ++k) {
    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(),
                     subtractor.FilterImpulseResponses(),
                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
    suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state,
                             x, false, &high_bands_gain, &g);
  }
  // With the nearend dominating, every bin is again expected to pass through.
  for (const float gain : g) {
    EXPECT_NEAR(1.0f, gain, kGainTolerance);
  }

  // Add a strong echo to one of the channels and ensure that it is suppressed.
  E2[1].fill(1000000000.0f);
  R2[1].fill(10000000000000.0f);
  R2_unbounded[1].fill(10000000000000.0f);

  // Only GetGain() is repeated here; aec_state is deliberately left as it was
  // after the previous phase.
  for (int k = 0; k < 10; ++k) {
    suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state,
                             x, false, &high_bands_gain, &g);
  }
  for (const float gain : g) {
    EXPECT_NEAR(0.0f, gain, kGainTolerance);
  }
}
147
148 } // namespace aec3
149 } // namespace webrtc
150