1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h"
12 
13 #include <algorithm>
14 #include <iostream>
15 #include <string>
16 
17 #include "api/audio/echo_canceller3_config.h"
18 #include "modules/audio_processing/aec3/render_buffer.h"
19 #include "modules/audio_processing/aec3/render_delay_buffer.h"
20 #include "rtc_base/strings/string_builder.h"
21 #include "test/gtest.h"
22 
23 namespace webrtc {
24 
25 namespace {
26 
GetActiveFrame(Block * x)27 void GetActiveFrame(Block* x) {
28   const std::array<float, kBlockSize> frame = {
29       7459.88, 17209.6, 17383,   20768.9, 16816.7, 18386.3, 4492.83, 9675.85,
30       6665.52, 14808.6, 9342.3,  7483.28, 19261.7, 4145.98, 1622.18, 13475.2,
31       7166.32, 6856.61, 21937,   7263.14, 9569.07, 14919,   8413.32, 7551.89,
32       7848.65, 6011.27, 13080.6, 15865.2, 12656,   17459.6, 4263.93, 4503.03,
33       9311.79, 21095.8, 12657.9, 13906.6, 19267.2, 11338.1, 16828.9, 11501.6,
34       11405,   15031.4, 14541.6, 19765.5, 18346.3, 19350.2, 3157.47, 18095.8,
35       1743.68, 21328.2, 19727.5, 7295.16, 10332.4, 11055.5, 20107.4, 14708.4,
36       12416.2, 16434,   2454.69, 9840.8,  6867.23, 1615.75, 6059.9,  8394.19};
37   for (int band = 0; band < x->NumBands(); ++band) {
38     for (int channel = 0; channel < x->NumChannels(); ++channel) {
39       RTC_DCHECK_GE(kBlockSize, frame.size());
40       std::copy(frame.begin(), frame.end(), x->begin(band, channel));
41     }
42   }
43 }
44 
45 class TestInputs {
46  public:
47   TestInputs(const EchoCanceller3Config& cfg,
48              size_t num_render_channels,
49              size_t num_capture_channels);
50   ~TestInputs();
GetRenderBuffer()51   const RenderBuffer& GetRenderBuffer() { return *render_buffer_; }
GetX2()52   rtc::ArrayView<const float, kFftLengthBy2Plus1> GetX2() { return X2_; }
GetY2() const53   rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetY2() const {
54     return Y2_;
55   }
GetE2() const56   rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetE2() const {
57     return E2_;
58   }
59   rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
GetH2() const60   GetH2() const {
61     return H2_;
62   }
GetConvergedFilters() const63   const std::vector<bool>& GetConvergedFilters() const {
64     return converged_filters_;
65   }
66   void Update();
67 
68  private:
69   void UpdateCurrentPowerSpectra();
70   int n_ = 0;
71   std::unique_ptr<RenderDelayBuffer> render_delay_buffer_;
72   RenderBuffer* render_buffer_;
73   std::array<float, kFftLengthBy2Plus1> X2_;
74   std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_;
75   std::vector<std::array<float, kFftLengthBy2Plus1>> E2_;
76   std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_;
77   Block x_;
78   std::vector<bool> converged_filters_;
79 };
80 
TestInputs(const EchoCanceller3Config & cfg,size_t num_render_channels,size_t num_capture_channels)81 TestInputs::TestInputs(const EchoCanceller3Config& cfg,
82                        size_t num_render_channels,
83                        size_t num_capture_channels)
84     : render_delay_buffer_(
85           RenderDelayBuffer::Create(cfg, 16000, num_render_channels)),
86       Y2_(num_capture_channels),
87       E2_(num_capture_channels),
88       H2_(num_capture_channels,
89           std::vector<std::array<float, kFftLengthBy2Plus1>>(
90               cfg.filter.refined.length_blocks)),
91       x_(1, num_render_channels),
92       converged_filters_(num_capture_channels, true) {
93   render_delay_buffer_->AlignFromDelay(4);
94   render_buffer_ = render_delay_buffer_->GetRenderBuffer();
95   for (auto& H2_ch : H2_) {
96     for (auto& H2_p : H2_ch) {
97       H2_p.fill(0.f);
98     }
99   }
100   for (auto& H2_p : H2_[0]) {
101     H2_p.fill(1.f);
102   }
103 }
104 
105 TestInputs::~TestInputs() = default;
106 
Update()107 void TestInputs::Update() {
108   if (n_ % 2 == 0) {
109     std::fill(x_.begin(/*band=*/0, /*channel=*/0),
110               x_.end(/*band=*/0, /*channel=*/0), 0.f);
111   } else {
112     GetActiveFrame(&x_);
113   }
114 
115   render_delay_buffer_->Insert(x_);
116   render_delay_buffer_->PrepareCaptureProcessing();
117   UpdateCurrentPowerSpectra();
118   ++n_;
119 }
120 
UpdateCurrentPowerSpectra()121 void TestInputs::UpdateCurrentPowerSpectra() {
122   const SpectrumBuffer& spectrum_render_buffer =
123       render_buffer_->GetSpectrumBuffer();
124   size_t idx = render_buffer_->Position();
125   size_t prev_idx = spectrum_render_buffer.OffsetIndex(idx, 1);
126   auto& X2 = spectrum_render_buffer.buffer[idx][/*channel=*/0];
127   auto& X2_prev = spectrum_render_buffer.buffer[prev_idx][/*channel=*/0];
128   std::copy(X2.begin(), X2.end(), X2_.begin());
129   for (size_t ch = 0; ch < Y2_.size(); ++ch) {
130     RTC_DCHECK_EQ(X2.size(), Y2_[ch].size());
131     for (size_t k = 0; k < X2.size(); ++k) {
132       E2_[ch][k] = 0.01f * X2_prev[k];
133       Y2_[ch][k] = X2[k] + E2_[ch][k];
134     }
135   }
136 }
137 
138 }  // namespace
139 
140 class SignalDependentErleEstimatorMultiChannel
141     : public ::testing::Test,
142       public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
143 
144 INSTANTIATE_TEST_SUITE_P(MultiChannel,
145                          SignalDependentErleEstimatorMultiChannel,
146                          ::testing::Combine(::testing::Values(1, 2, 4),
147                                             ::testing::Values(1, 2, 4)));
148 
// Sweeps filter length, delay headroom and number of ERLE sections, and for
// every combination that EchoCanceller3Config::Validate() accepts runs the
// estimator for a few blocks. There are no expectations on the output; the
// test only requires that the updates run to completion.
TEST_P(SignalDependentErleEstimatorMultiChannel, SweepSettings) {
  const auto& params = GetParam();
  const size_t num_render_channels = std::get<0>(params);
  const size_t num_capture_channels = std::get<1>(params);

  constexpr size_t kMaxLengthBlocks = 50;
  constexpr size_t kNumUpdates = 10;

  // The same config object is reused (and re-validated) across iterations.
  EchoCanceller3Config cfg;
  for (size_t blocks = 1; blocks < kMaxLengthBlocks; blocks += 10) {
    for (size_t headroom_blocks = 0; headroom_blocks < 5; ++headroom_blocks) {
      for (size_t sections = 2; sections < kMaxLengthBlocks; ++sections) {
        cfg.filter.refined.length_blocks = blocks;
        // Keep the initial filter no longer than the refined filter.
        cfg.filter.refined_initial.length_blocks =
            std::min(cfg.filter.refined_initial.length_blocks, blocks);
        cfg.delay.delay_headroom_samples = headroom_blocks * kBlockSize;
        cfg.erle.num_sections = sections;

        // Combinations rejected by the validator are skipped.
        if (!EchoCanceller3Config::Validate(&cfg)) {
          continue;
        }

        SignalDependentErleEstimator estimator(cfg, num_capture_channels);

        // Every bin of every capture channel starts at cfg.erle.max_l.
        std::array<float, kFftLengthBy2Plus1> initial_erle;
        initial_erle.fill(cfg.erle.max_l);
        std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle(
            num_capture_channels, initial_erle);

        TestInputs inputs(cfg, num_render_channels, num_capture_channels);
        for (size_t update = 0; update < kNumUpdates; ++update) {
          inputs.Update();
          // The same buffer is passed for both ERLE arguments.
          estimator.Update(inputs.GetRenderBuffer(), inputs.GetH2(),
                           inputs.GetX2(), inputs.GetY2(), inputs.GetE2(),
                           average_erle, average_erle,
                           inputs.GetConvergedFilters());
        }
      }
    }
  }
}
182 
// Runs the estimator for 200 blocks with one small, fixed configuration.
// Apart from the config validation there are no expectations on the output;
// the test only requires that the updates run to completion.
TEST_P(SignalDependentErleEstimatorMultiChannel, LongerRun) {
  const auto& params = GetParam();
  const size_t num_render_channels = std::get<0>(params);
  const size_t num_capture_channels = std::get<1>(params);
  constexpr size_t kNumUpdates = 200;

  EchoCanceller3Config cfg;
  cfg.filter.refined.length_blocks = 2;
  cfg.filter.refined_initial.length_blocks = 1;
  cfg.delay.delay_headroom_samples = 0;
  cfg.delay.hysteresis_limit_blocks = 0;
  cfg.erle.num_sections = 2;
  EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true);

  // Every bin of every capture channel starts at cfg.erle.max_l.
  std::array<float, kFftLengthBy2Plus1> initial_erle;
  initial_erle.fill(cfg.erle.max_l);
  std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle(
      num_capture_channels, initial_erle);

  SignalDependentErleEstimator estimator(cfg, num_capture_channels);
  TestInputs inputs(cfg, num_render_channels, num_capture_channels);
  for (size_t update = 0; update < kNumUpdates; ++update) {
    inputs.Update();
    // The same buffer is passed for both ERLE arguments.
    estimator.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(),
                     inputs.GetY2(), inputs.GetE2(), average_erle,
                     average_erle, inputs.GetConvergedFilters());
  }
}
207 
208 }  // namespace webrtc
209