xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/residual_echo_estimator.h"
12 
13 #include <numeric>
14 
15 #include "api/audio/echo_canceller3_config.h"
16 #include "modules/audio_processing/aec3/aec3_fft.h"
17 #include "modules/audio_processing/aec3/aec_state.h"
18 #include "modules/audio_processing/aec3/render_delay_buffer.h"
19 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
20 #include "rtc_base/random.h"
21 #include "rtc_base/strings/string_builder.h"
22 #include "test/gtest.h"
23 
24 namespace webrtc {
25 
26 namespace {
27 constexpr int kSampleRateHz = 48000;
28 constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
29 constexpr float kEpsilon = 1e-4f;
30 }  // namespace
31 
32 class ResidualEchoEstimatorTest {
33  public:
ResidualEchoEstimatorTest(size_t num_render_channels,size_t num_capture_channels,const EchoCanceller3Config & config)34   ResidualEchoEstimatorTest(size_t num_render_channels,
35                             size_t num_capture_channels,
36                             const EchoCanceller3Config& config)
37       : num_render_channels_(num_render_channels),
38         num_capture_channels_(num_capture_channels),
39         config_(config),
40         estimator_(config_, num_render_channels_),
41         aec_state_(config_, num_capture_channels_),
42         render_delay_buffer_(RenderDelayBuffer::Create(config_,
43                                                        kSampleRateHz,
44                                                        num_render_channels_)),
45         E2_refined_(num_capture_channels_),
46         S2_linear_(num_capture_channels_),
47         Y2_(num_capture_channels_),
48         R2_(num_capture_channels_),
49         R2_unbounded_(num_capture_channels_),
50         x_(kNumBands, num_render_channels_),
51         H2_(num_capture_channels_,
52             std::vector<std::array<float, kFftLengthBy2Plus1>>(10)),
53         h_(num_capture_channels_,
54            std::vector<float>(
55                GetTimeDomainLength(config_.filter.refined.length_blocks),
56                0.0f)),
57         random_generator_(42U),
58         output_(num_capture_channels_) {
59     for (auto& H2_ch : H2_) {
60       for (auto& H2_k : H2_ch) {
61         H2_k.fill(0.01f);
62       }
63       H2_ch[2].fill(10.f);
64       H2_ch[2][0] = 0.1f;
65     }
66 
67     for (auto& subtractor_output : output_) {
68       subtractor_output.Reset();
69       subtractor_output.s_refined.fill(100.f);
70     }
71     y_.fill(0.f);
72 
73     constexpr float kLevel = 10.f;
74     for (auto& E2_refined_ch : E2_refined_) {
75       E2_refined_ch.fill(kLevel);
76     }
77     S2_linear_[0].fill(kLevel);
78     for (auto& Y2_ch : Y2_) {
79       Y2_ch.fill(kLevel);
80     }
81   }
82 
RunOneFrame(bool dominant_nearend)83   void RunOneFrame(bool dominant_nearend) {
84     RandomizeSampleVector(&random_generator_,
85                           x_.View(/*band=*/0, /*channel=*/0));
86     render_delay_buffer_->Insert(x_);
87     if (first_frame_) {
88       render_delay_buffer_->Reset();
89       first_frame_ = false;
90     }
91     render_delay_buffer_->PrepareCaptureProcessing();
92 
93     aec_state_.Update(delay_estimate_, H2_, h_,
94                       *render_delay_buffer_->GetRenderBuffer(), E2_refined_,
95                       Y2_, output_);
96 
97     estimator_.Estimate(aec_state_, *render_delay_buffer_->GetRenderBuffer(),
98                         S2_linear_, Y2_, dominant_nearend, R2_, R2_unbounded_);
99   }
100 
R2() const101   rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> R2() const {
102     return R2_;
103   }
104 
105  private:
106   const size_t num_render_channels_;
107   const size_t num_capture_channels_;
108   const EchoCanceller3Config& config_;
109   ResidualEchoEstimator estimator_;
110   AecState aec_state_;
111   std::unique_ptr<RenderDelayBuffer> render_delay_buffer_;
112   std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined_;
113   std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear_;
114   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_;
115   std::vector<std::array<float, kFftLengthBy2Plus1>> R2_;
116   std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded_;
117   Block x_;
118   std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_;
119   std::vector<std::vector<float>> h_;
120   Random random_generator_;
121   std::vector<SubtractorOutput> output_;
122   std::array<float, kBlockSize> y_;
123   absl::optional<DelayEstimate> delay_estimate_;
124   bool first_frame_ = true;
125 };
126 
127 class ResidualEchoEstimatorMultiChannel
128     : public ::testing::Test,
129       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
130 
131 INSTANTIATE_TEST_SUITE_P(MultiChannel,
132                          ResidualEchoEstimatorMultiChannel,
133                          ::testing::Combine(::testing::Values(1, 2, 4),
134                                             ::testing::Values(1, 2, 4)));
135 
// Smoke test: runs the estimator with a default configuration for a fixed
// number of blocks with a non-dominant near-end. No output is checked; the test
// passes as long as processing completes for the given channel layout.
TEST_P(ResidualEchoEstimatorMultiChannel, BasicTest) {
  constexpr int kNumFramesToProcess = 1993;
  const size_t render_channels = std::get<0>(GetParam());
  const size_t capture_channels = std::get<1>(GetParam());

  EchoCanceller3Config default_config;
  ResidualEchoEstimatorTest harness(render_channels, capture_channels,
                                    default_config);

  int processed = 0;
  while (processed < kNumFramesToProcess) {
    harness.RunOneFrame(/*dominant_nearend=*/false);
    ++processed;
  }
}
147 
// Compares two estimators whose configs differ only in
// `ep_strength.nearend_len`. Up to and including the midpoint frame the
// near-end is flagged as non-dominant and both must yield the same R2 energy.
// After the midpoint the near-end is flagged as dominant, and the estimator
// using the smaller `nearend_len` must not yield more energy per frame, and
// strictly less energy accumulated by the last frame.
TEST(ResidualEchoEstimatorMultiChannel, ReverbTest) {
  constexpr size_t kNumRenderChannels = 1;
  constexpr size_t kNumCaptureChannels = 1;
  constexpr size_t kNumFrames = 100;

  EchoCanceller3Config reference_config;
  reference_config.ep_strength.default_len = 0.95f;
  reference_config.ep_strength.nearend_len = 0.95f;
  // Copy of the reference config (so default_len stays 0.95) with only the
  // near-end length lowered.
  EchoCanceller3Config config_use_nearend_len = reference_config;
  config_use_nearend_len.ep_strength.nearend_len = 0.83f;

  ResidualEchoEstimatorTest reference_residual_echo_estimator_test(
      kNumRenderChannels, kNumCaptureChannels, reference_config);
  ResidualEchoEstimatorTest use_nearend_len_residual_echo_estimator_test(
      kNumRenderChannels, kNumCaptureChannels, config_use_nearend_len);

  std::vector<float> accum_energy_reference_R2(kNumCaptureChannels, 0.0f);
  std::vector<float> accum_energy_R2(kNumCaptureChannels, 0.0f);
  for (size_t frame = 0; frame < kNumFrames; ++frame) {
    // The near-end becomes dominant strictly after the midpoint frame.
    const bool dominant_nearend = frame > kNumFrames / 2;
    reference_residual_echo_estimator_test.RunOneFrame(dominant_nearend);
    use_nearend_len_residual_echo_estimator_test.RunOneFrame(dominant_nearend);
    const auto reference_R2 = reference_residual_echo_estimator_test.R2();
    const auto R2 = use_nearend_len_residual_echo_estimator_test.R2();
    ASSERT_EQ(reference_R2.size(), R2.size());

    // Per-frame comparison of the total residual echo energy.
    for (size_t ch = 0; ch < reference_R2.size(); ++ch) {
      const float energy_reference_R2 = std::accumulate(
          reference_R2[ch].cbegin(), reference_R2[ch].cend(), 0.0f);
      const float energy_R2 =
          std::accumulate(R2[ch].cbegin(), R2[ch].cend(), 0.0f);
      if (dominant_nearend) {
        EXPECT_GE(energy_reference_R2, energy_R2);
      } else {
        EXPECT_NEAR(energy_reference_R2, energy_R2, kEpsilon);
      }
      accum_energy_reference_R2[ch] += energy_reference_R2;
      accum_energy_R2[ch] += energy_R2;
    }

    // Accumulated comparison at the end of each phase: the last non-dominant
    // frame (the midpoint) and the last frame overall.
    if (frame == kNumFrames / 2 || frame == kNumFrames - 1) {
      for (size_t ch = 0; ch < accum_energy_reference_R2.size(); ++ch) {
        if (dominant_nearend) {
          EXPECT_GT(accum_energy_reference_R2[ch], accum_energy_R2[ch]);
        } else {
          EXPECT_NEAR(accum_energy_reference_R2[ch], accum_energy_R2[ch],
                      kEpsilon);
        }
      }
    }
  }
}
198 
199 }  // namespace webrtc
200