// xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/subtractor.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #include "modules/audio_processing/aec3/subtractor.h"
12 
13 #include <algorithm>
14 #include <utility>
15 
16 #include "api/array_view.h"
17 #include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
18 #include "modules/audio_processing/aec3/fft_data.h"
19 #include "modules/audio_processing/logging/apm_data_dumper.h"
20 #include "rtc_base/checks.h"
21 #include "rtc_base/numerics/safe_minmax.h"
22 #include "system_wrappers/include/field_trial.h"
23 
24 namespace webrtc {
25 
26 namespace {
27 
UseCoarseFilterResetHangover()28 bool UseCoarseFilterResetHangover() {
29   return !field_trial::IsEnabled(
30       "WebRTC-Aec3CoarseFilterResetHangoverKillSwitch");
31 }
32 
PredictionError(const Aec3Fft & fft,const FftData & S,rtc::ArrayView<const float> y,std::array<float,kBlockSize> * e,std::array<float,kBlockSize> * s)33 void PredictionError(const Aec3Fft& fft,
34                      const FftData& S,
35                      rtc::ArrayView<const float> y,
36                      std::array<float, kBlockSize>* e,
37                      std::array<float, kBlockSize>* s) {
38   std::array<float, kFftLength> tmp;
39   fft.Ifft(S, &tmp);
40   constexpr float kScale = 1.0f / kFftLengthBy2;
41   std::transform(y.begin(), y.end(), tmp.begin() + kFftLengthBy2, e->begin(),
42                  [&](float a, float b) { return a - b * kScale; });
43 
44   if (s) {
45     for (size_t k = 0; k < s->size(); ++k) {
46       (*s)[k] = kScale * tmp[k + kFftLengthBy2];
47     }
48   }
49 }
50 
ScaleFilterOutput(rtc::ArrayView<const float> y,float factor,rtc::ArrayView<float> e,rtc::ArrayView<float> s)51 void ScaleFilterOutput(rtc::ArrayView<const float> y,
52                        float factor,
53                        rtc::ArrayView<float> e,
54                        rtc::ArrayView<float> s) {
55   RTC_DCHECK_EQ(y.size(), e.size());
56   RTC_DCHECK_EQ(y.size(), s.size());
57   for (size_t k = 0; k < y.size(); ++k) {
58     s[k] *= factor;
59     e[k] = y[k] - s[k];
60   }
61 }
62 
63 }  // namespace
64 
Subtractor(const EchoCanceller3Config & config,size_t num_render_channels,size_t num_capture_channels,ApmDataDumper * data_dumper,Aec3Optimization optimization)65 Subtractor::Subtractor(const EchoCanceller3Config& config,
66                        size_t num_render_channels,
67                        size_t num_capture_channels,
68                        ApmDataDumper* data_dumper,
69                        Aec3Optimization optimization)
70     : fft_(),
71       data_dumper_(data_dumper),
72       optimization_(optimization),
73       config_(config),
74       num_capture_channels_(num_capture_channels),
75       use_coarse_filter_reset_hangover_(UseCoarseFilterResetHangover()),
76       refined_filters_(num_capture_channels_),
77       coarse_filter_(num_capture_channels_),
78       refined_gains_(num_capture_channels_),
79       coarse_gains_(num_capture_channels_),
80       filter_misadjustment_estimators_(num_capture_channels_),
81       poor_coarse_filter_counters_(num_capture_channels_, 0),
82       coarse_filter_reset_hangover_(num_capture_channels_, 0),
83       refined_frequency_responses_(
84           num_capture_channels_,
85           std::vector<std::array<float, kFftLengthBy2Plus1>>(
86               std::max(config_.filter.refined_initial.length_blocks,
87                        config_.filter.refined.length_blocks),
88               std::array<float, kFftLengthBy2Plus1>())),
89       refined_impulse_responses_(
90           num_capture_channels_,
91           std::vector<float>(GetTimeDomainLength(std::max(
92                                  config_.filter.refined_initial.length_blocks,
93                                  config_.filter.refined.length_blocks)),
94                              0.f)),
95       coarse_impulse_responses_(0) {
96   // Set up the storing of coarse impulse responses if data dumping is
97   // available.
98   if (ApmDataDumper::IsAvailable()) {
99     coarse_impulse_responses_.resize(num_capture_channels_);
100     const size_t filter_size = GetTimeDomainLength(
101         std::max(config_.filter.coarse_initial.length_blocks,
102                  config_.filter.coarse.length_blocks));
103     for (std::vector<float>& impulse_response : coarse_impulse_responses_) {
104       impulse_response.resize(filter_size, 0.f);
105     }
106   }
107 
108   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
109     refined_filters_[ch] = std::make_unique<AdaptiveFirFilter>(
110         config_.filter.refined.length_blocks,
111         config_.filter.refined_initial.length_blocks,
112         config.filter.config_change_duration_blocks, num_render_channels,
113         optimization, data_dumper_);
114 
115     coarse_filter_[ch] = std::make_unique<AdaptiveFirFilter>(
116         config_.filter.coarse.length_blocks,
117         config_.filter.coarse_initial.length_blocks,
118         config.filter.config_change_duration_blocks, num_render_channels,
119         optimization, data_dumper_);
120     refined_gains_[ch] = std::make_unique<RefinedFilterUpdateGain>(
121         config_.filter.refined_initial,
122         config_.filter.config_change_duration_blocks);
123     coarse_gains_[ch] = std::make_unique<CoarseFilterUpdateGain>(
124         config_.filter.coarse_initial,
125         config.filter.config_change_duration_blocks);
126   }
127 
128   RTC_DCHECK(data_dumper_);
129   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
130     for (auto& H2_k : refined_frequency_responses_[ch]) {
131       H2_k.fill(0.f);
132     }
133   }
134 }
135 
136 Subtractor::~Subtractor() = default;
137 
HandleEchoPathChange(const EchoPathVariability & echo_path_variability)138 void Subtractor::HandleEchoPathChange(
139     const EchoPathVariability& echo_path_variability) {
140   const auto full_reset = [&]() {
141     for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
142       refined_filters_[ch]->HandleEchoPathChange();
143       coarse_filter_[ch]->HandleEchoPathChange();
144       refined_gains_[ch]->HandleEchoPathChange(echo_path_variability);
145       coarse_gains_[ch]->HandleEchoPathChange();
146       refined_gains_[ch]->SetConfig(config_.filter.refined_initial, true);
147       coarse_gains_[ch]->SetConfig(config_.filter.coarse_initial, true);
148       refined_filters_[ch]->SetSizePartitions(
149           config_.filter.refined_initial.length_blocks, true);
150       coarse_filter_[ch]->SetSizePartitions(
151           config_.filter.coarse_initial.length_blocks, true);
152     }
153   };
154 
155   if (echo_path_variability.delay_change !=
156       EchoPathVariability::DelayAdjustment::kNone) {
157     full_reset();
158   }
159 
160   if (echo_path_variability.gain_change) {
161     for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
162       refined_gains_[ch]->HandleEchoPathChange(echo_path_variability);
163     }
164   }
165 }
166 
ExitInitialState()167 void Subtractor::ExitInitialState() {
168   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
169     refined_gains_[ch]->SetConfig(config_.filter.refined, false);
170     coarse_gains_[ch]->SetConfig(config_.filter.coarse, false);
171     refined_filters_[ch]->SetSizePartitions(
172         config_.filter.refined.length_blocks, false);
173     coarse_filter_[ch]->SetSizePartitions(config_.filter.coarse.length_blocks,
174                                           false);
175   }
176 }
177 
// Runs one block of echo subtraction for every capture channel.
//
// For each channel:
//   1. The refined and coarse adaptive filters are applied to the render
//      signal. Their time-domain prediction errors and filter outputs are
//      written to `outputs[ch]`.
//   2. The refined filter is rescaled if the misadjustment estimator
//      requests it.
//   3. Both errors are transformed to the frequency domain.
//   4. Both filters are adapted.
//
// Signals for channel 0 are additionally dumped through `data_dumper_`.
//
// The render power spectra (X2) are computed once, from channel 0's filter
// sizes, and shared by all channels.
void Subtractor::Process(const RenderBuffer& render_buffer,
                         const Block& capture,
                         const RenderSignalAnalyzer& render_signal_analyzer,
                         const AecState& aec_state,
                         rtc::ArrayView<SubtractorOutput> outputs) {
  RTC_DCHECK_EQ(num_capture_channels_, capture.NumChannels());

  // Compute the render powers.
  // When both filters span the same number of partitions, a single spectral
  // sum is computed and X2_coarse aliases X2_refined. Otherwise both sums are
  // obtained in one SpectralSums() call, with the shorter length passed first.
  const bool same_filter_sizes = refined_filters_[0]->SizePartitions() ==
                                 coarse_filter_[0]->SizePartitions();
  std::array<float, kFftLengthBy2Plus1> X2_refined;
  std::array<float, kFftLengthBy2Plus1> X2_coarse_data;
  auto& X2_coarse = same_filter_sizes ? X2_refined : X2_coarse_data;
  if (same_filter_sizes) {
    render_buffer.SpectralSum(refined_filters_[0]->SizePartitions(),
                              &X2_refined);
  } else if (refined_filters_[0]->SizePartitions() >
             coarse_filter_[0]->SizePartitions()) {
    render_buffer.SpectralSums(coarse_filter_[0]->SizePartitions(),
                               refined_filters_[0]->SizePartitions(),
                               &X2_coarse, &X2_refined);
  } else {
    render_buffer.SpectralSums(refined_filters_[0]->SizePartitions(),
                               coarse_filter_[0]->SizePartitions(), &X2_refined,
                               &X2_coarse);
  }

  // Process all capture channels.
  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    SubtractorOutput& output = outputs[ch];
    rtc::ArrayView<const float> y = capture.View(/*band=*/0, ch);
    FftData& E_refined = output.E_refined;
    FftData E_coarse;
    std::array<float, kBlockSize>& e_refined = output.e_refined;
    std::array<float, kBlockSize>& e_coarse = output.e_coarse;

    // S holds a filter's frequency-domain output. Once both prediction errors
    // have been formed it is no longer needed, so the same storage is reused
    // (as G) for the filter update gains below.
    FftData S;
    FftData& G = S;

    // Form the outputs of the refined and coarse filters.
    refined_filters_[ch]->Filter(render_buffer, &S);
    PredictionError(fft_, S, y, &e_refined, &output.s_refined);

    coarse_filter_[ch]->Filter(render_buffer, &S);
    PredictionError(fft_, S, y, &e_coarse, &output.s_coarse);

    // Compute the signal powers in the subtractor output.
    output.ComputeMetrics(y);

    // Adjust the filter if needed.
    // When the misadjustment estimator requests an adjustment, the refined
    // filter, its stored impulse response and the already computed refined
    // output/error are all rescaled by the same factor, and the estimator is
    // reset. For this block the refined filter is then adapted with a zero
    // gain (see below).
    bool refined_filters_adjusted = false;
    filter_misadjustment_estimators_[ch].Update(output);
    if (filter_misadjustment_estimators_[ch].IsAdjustmentNeeded()) {
      float scale = filter_misadjustment_estimators_[ch].GetMisadjustment();
      refined_filters_[ch]->ScaleFilter(scale);
      for (auto& h_k : refined_impulse_responses_[ch]) {
        h_k *= scale;
      }
      ScaleFilterOutput(y, scale, e_refined, output.s_refined);
      filter_misadjustment_estimators_[ch].Reset();
      refined_filters_adjusted = true;
    }

    // Compute the FFTs of the refined and coarse filter outputs.
    fft_.ZeroPaddedFft(e_refined, Aec3Fft::Window::kHanning, &E_refined);
    fft_.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kHanning, &E_coarse);

    // Compute spectra for future use.
    E_coarse.Spectrum(optimization_, output.E2_coarse);
    E_refined.Spectrum(optimization_, output.E2_refined);

    // Update the refined filter.
    if (!refined_filters_adjusted) {
      // Do not allow the performance of the coarse filter to affect the
      // adaptation speed of the refined filter just after the coarse filter has
      // been reset.
      const bool disallow_leakage_diverged =
          coarse_filter_reset_hangover_[ch] > 0 &&
          use_coarse_filter_reset_hangover_;

      std::array<float, kFftLengthBy2Plus1> erl;
      ComputeErl(optimization_, refined_frequency_responses_[ch], erl);
      refined_gains_[ch]->Compute(X2_refined, render_signal_analyzer, output,
                                  erl, refined_filters_[ch]->SizePartitions(),
                                  aec_state.SaturatedCapture(),
                                  disallow_leakage_diverged, &G);
    } else {
      // The filter was just rescaled: adapt with a zero gain for this block.
      G.re.fill(0.f);
      G.im.fill(0.f);
    }
    refined_filters_[ch]->Adapt(render_buffer, G,
                                &refined_impulse_responses_[ch]);
    refined_filters_[ch]->ComputeFrequencyResponse(
        &refined_frequency_responses_[ch]);

    if (ch == 0) {
      // The real and imaginary parts are dumped consecutively under one name.
      data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.re);
      data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.im);
    }

    // Update the coarse filter.
    // Count consecutive blocks in which the refined filter has lower error
    // energy than the coarse filter. While fewer than 5 such blocks have been
    // seen, the coarse filter adapts on its own error and the reset hangover
    // counts down toward 0.
    //
    // Otherwise the counter is cleared, the coarse filter is overwritten with
    // the refined filter's coefficients, its gain is computed from the refined
    // error E_refined (presumably because the coarse filter now matches the
    // refined one), and the reset hangover is (re)started.
    poor_coarse_filter_counters_[ch] =
        output.e2_refined < output.e2_coarse
            ? poor_coarse_filter_counters_[ch] + 1
            : 0;
    if (poor_coarse_filter_counters_[ch] < 5) {
      coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_coarse,
                                 coarse_filter_[ch]->SizePartitions(),
                                 aec_state.SaturatedCapture(), &G);
      coarse_filter_reset_hangover_[ch] =
          std::max(coarse_filter_reset_hangover_[ch] - 1, 0);
    } else {
      poor_coarse_filter_counters_[ch] = 0;
      coarse_filter_[ch]->SetFilter(refined_filters_[ch]->SizePartitions(),
                                    refined_filters_[ch]->GetFilter());
      coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_refined,
                                 coarse_filter_[ch]->SizePartitions(),
                                 aec_state.SaturatedCapture(), &G);
      coarse_filter_reset_hangover_[ch] =
          config_.filter.coarse_reset_hangover_blocks;
    }

    // Coarse impulse responses are only tracked when data dumping is
    // available; only then is storage for them allocated by the constructor.
    if (ApmDataDumper::IsAvailable()) {
      RTC_DCHECK_LT(ch, coarse_impulse_responses_.size());
      coarse_filter_[ch]->Adapt(render_buffer, G,
                                &coarse_impulse_responses_[ch]);
    } else {
      coarse_filter_[ch]->Adapt(render_buffer, G);
    }

    if (ch == 0) {
      data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.re);
      data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.im);
      filter_misadjustment_estimators_[ch].Dump(data_dumper_);
      DumpFilters();
    }

    // Clamp the refined error to the int16 range [-32768, 32767]. This runs
    // after E_refined was computed and both filters were adapted, so only the
    // time-domain output is affected. e_coarse is left unclamped.
    std::for_each(e_refined.begin(), e_refined.end(),
                  [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); });

    if (ch == 0) {
      data_dumper_->DumpWav("aec3_refined_filters_output", kBlockSize,
                            &e_refined[0], 16000, 1);
      data_dumper_->DumpWav("aec3_coarse_filter_output", kBlockSize,
                            &e_coarse[0], 16000, 1);
    }
  }
}
326 
Update(const SubtractorOutput & output)327 void Subtractor::FilterMisadjustmentEstimator::Update(
328     const SubtractorOutput& output) {
329   e2_acum_ += output.e2_refined;
330   y2_acum_ += output.y2;
331   if (++n_blocks_acum_ == n_blocks_) {
332     if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) {
333       float update = (e2_acum_ / y2_acum_);
334       if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) {
335         // Duration equal to blockSizeMs * n_blocks_ * 4.
336         overhang_ = 4;
337       } else {
338         overhang_ = std::max(overhang_ - 1, 0);
339       }
340 
341       if ((update < inv_misadjustment_) || (overhang_ > 0)) {
342         inv_misadjustment_ += 0.1f * (update - inv_misadjustment_);
343       }
344     }
345     e2_acum_ = 0.f;
346     y2_acum_ = 0.f;
347     n_blocks_acum_ = 0;
348   }
349 }
350 
Reset()351 void Subtractor::FilterMisadjustmentEstimator::Reset() {
352   e2_acum_ = 0.f;
353   y2_acum_ = 0.f;
354   n_blocks_acum_ = 0;
355   inv_misadjustment_ = 0.f;
356   overhang_ = 0.f;
357 }
358 
Dump(ApmDataDumper * data_dumper) const359 void Subtractor::FilterMisadjustmentEstimator::Dump(
360     ApmDataDumper* data_dumper) const {
361   data_dumper->DumpRaw("aec3_inv_misadjustment_factor", inv_misadjustment_);
362 }
363 
364 }  // namespace webrtc
365