xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h"
12 
13 #include <algorithm>
14 #include <functional>
15 #include <numeric>
16 
17 #include "modules/audio_processing/aec3/spectrum_buffer.h"
18 #include "rtc_base/numerics/safe_minmax.h"
19 
20 namespace webrtc {
21 
22 namespace {
23 
24 constexpr std::array<size_t, SignalDependentErleEstimator::kSubbands + 1>
25     kBandBoundaries = {1, 8, 16, 24, 32, 48, kFftLengthBy2Plus1};
26 
FormSubbandMap()27 std::array<size_t, kFftLengthBy2Plus1> FormSubbandMap() {
28   std::array<size_t, kFftLengthBy2Plus1> map_band_to_subband;
29   size_t subband = 1;
30   for (size_t k = 0; k < map_band_to_subband.size(); ++k) {
31     RTC_DCHECK_LT(subband, kBandBoundaries.size());
32     if (k >= kBandBoundaries[subband]) {
33       subband++;
34       RTC_DCHECK_LT(k, kBandBoundaries[subband]);
35     }
36     map_band_to_subband[k] = subband - 1;
37   }
38   return map_band_to_subband;
39 }
40 
// Defines the size in blocks of the sections that are used for dividing the
// linear filter. The sections are split in a non-linear manner so that lower
// sections that typically represent the direct path have a larger resolution
// than the higher sections which typically represent more reverberant acoustic
// paths.
//
// The first sections get sizes 2, 4, 8, ... for as long as the blocks that
// remain are more than enough to give every remaining section that size. The
// blocks left after that are split evenly among the remaining sections, and
// the last section absorbs the remainder of that division.
//
// `delay_headroom_blocks` is the number of leading blocks excluded from the
// partitioning, `num_blocks` is the total filter length in blocks and
// `num_sections` is the number of sections to produce. Returns one size per
// section; the sizes add up to `num_blocks - delay_headroom_blocks`, or to zero
// when the headroom is at least as long as the filter. An empty vector is
// returned when `num_sections` is zero.
std::vector<size_t> DefineFilterSectionSizes(size_t delay_headroom_blocks,
                                             size_t num_blocks,
                                             size_t num_sections) {
  std::vector<size_t> section_sizes(num_sections, 0);
  if (num_sections == 0) {
    // Nothing to partition. Returning here also avoids the division by zero
    // and the out-of-range write that the code below would otherwise perform.
    return section_sizes;
  }

  // Saturate at zero rather than letting the unsigned subtraction wrap around
  // when the headroom is longer than the filter.
  const size_t filter_length_blocks =
      num_blocks > delay_headroom_blocks ? num_blocks - delay_headroom_blocks
                                         : 0;

  size_t remaining_blocks = filter_length_blocks;
  size_t remaining_sections = num_sections;
  size_t estimator_size = 2;
  size_t idx = 0;
  // `idx` equals `num_sections - remaining_sections`, so the condition
  // `remaining_sections > 1` keeps it strictly below `num_sections - 1`.
  while (remaining_sections > 1 &&
         remaining_blocks > estimator_size * remaining_sections) {
    section_sizes[idx] = estimator_size;
    remaining_blocks -= estimator_size;
    --remaining_sections;
    estimator_size *= 2;
    ++idx;
  }

  // Split what is left evenly; the last section takes the division remainder.
  const size_t last_groups_size = remaining_blocks / remaining_sections;
  for (; idx < num_sections; ++idx) {
    section_sizes[idx] = last_groups_size;
  }
  section_sizes[num_sections - 1] +=
      remaining_blocks - last_groups_size * remaining_sections;
  return section_sizes;
}
73 
74 // Forms the limits in blocks for each filter section. Those sections
75 // are used for analyzing the echo estimates and investigating which
76 // linear filter sections contribute most to the echo estimate energy.
SetSectionsBoundaries(size_t delay_headroom_blocks,size_t num_blocks,size_t num_sections)77 std::vector<size_t> SetSectionsBoundaries(size_t delay_headroom_blocks,
78                                           size_t num_blocks,
79                                           size_t num_sections) {
80   std::vector<size_t> estimator_boundaries_blocks(num_sections + 1);
81   if (estimator_boundaries_blocks.size() == 2) {
82     estimator_boundaries_blocks[0] = 0;
83     estimator_boundaries_blocks[1] = num_blocks;
84     return estimator_boundaries_blocks;
85   }
86   RTC_DCHECK_GT(estimator_boundaries_blocks.size(), 2);
87   const std::vector<size_t> section_sizes =
88       DefineFilterSectionSizes(delay_headroom_blocks, num_blocks,
89                                estimator_boundaries_blocks.size() - 1);
90 
91   size_t idx = 0;
92   size_t current_size_block = 0;
93   RTC_DCHECK_EQ(section_sizes.size() + 1, estimator_boundaries_blocks.size());
94   estimator_boundaries_blocks[0] = delay_headroom_blocks;
95   for (size_t k = delay_headroom_blocks; k < num_blocks; ++k) {
96     current_size_block++;
97     if (current_size_block >= section_sizes[idx]) {
98       idx = idx + 1;
99       if (idx == section_sizes.size()) {
100         break;
101       }
102       estimator_boundaries_blocks[idx] = k + 1;
103       current_size_block = 0;
104     }
105   }
106   estimator_boundaries_blocks[section_sizes.size()] = num_blocks;
107   return estimator_boundaries_blocks;
108 }
109 
110 std::array<float, SignalDependentErleEstimator::kSubbands>
SetMaxErleSubbands(float max_erle_l,float max_erle_h,size_t limit_subband_l)111 SetMaxErleSubbands(float max_erle_l, float max_erle_h, size_t limit_subband_l) {
112   std::array<float, SignalDependentErleEstimator::kSubbands> max_erle;
113   std::fill(max_erle.begin(), max_erle.begin() + limit_subband_l, max_erle_l);
114   std::fill(max_erle.begin() + limit_subband_l, max_erle.end(), max_erle_h);
115   return max_erle;
116 }
117 
118 }  // namespace
119 
SignalDependentErleEstimator(const EchoCanceller3Config & config,size_t num_capture_channels)120 SignalDependentErleEstimator::SignalDependentErleEstimator(
121     const EchoCanceller3Config& config,
122     size_t num_capture_channels)
123     : min_erle_(config.erle.min),
124       num_sections_(config.erle.num_sections),
125       num_blocks_(config.filter.refined.length_blocks),
126       delay_headroom_blocks_(config.delay.delay_headroom_samples / kBlockSize),
127       band_to_subband_(FormSubbandMap()),
128       max_erle_(SetMaxErleSubbands(config.erle.max_l,
129                                    config.erle.max_h,
130                                    band_to_subband_[kFftLengthBy2 / 2])),
131       section_boundaries_blocks_(SetSectionsBoundaries(delay_headroom_blocks_,
132                                                        num_blocks_,
133                                                        num_sections_)),
134       use_onset_detection_(config.erle.onset_detection),
135       erle_(num_capture_channels),
136       erle_onset_compensated_(num_capture_channels),
137       S2_section_accum_(
138           num_capture_channels,
139           std::vector<std::array<float, kFftLengthBy2Plus1>>(num_sections_)),
140       erle_estimators_(
141           num_capture_channels,
142           std::vector<std::array<float, kSubbands>>(num_sections_)),
143       erle_ref_(num_capture_channels),
144       correction_factors_(
145           num_capture_channels,
146           std::vector<std::array<float, kSubbands>>(num_sections_)),
147       num_updates_(num_capture_channels),
148       n_active_sections_(num_capture_channels) {
149   RTC_DCHECK_LE(num_sections_, num_blocks_);
150   RTC_DCHECK_GE(num_sections_, 1);
151   Reset();
152 }
153 
154 SignalDependentErleEstimator::~SignalDependentErleEstimator() = default;
155 
Reset()156 void SignalDependentErleEstimator::Reset() {
157   for (size_t ch = 0; ch < erle_.size(); ++ch) {
158     erle_[ch].fill(min_erle_);
159     erle_onset_compensated_[ch].fill(min_erle_);
160     for (auto& erle_estimator : erle_estimators_[ch]) {
161       erle_estimator.fill(min_erle_);
162     }
163     erle_ref_[ch].fill(min_erle_);
164     for (auto& factor : correction_factors_[ch]) {
165       factor.fill(1.0f);
166     }
167     num_updates_[ch].fill(0);
168     n_active_sections_[ch].fill(0);
169   }
170 }
171 
172 // Updates the Erle estimate by analyzing the current input signals. It takes
173 // the render buffer and the filter frequency response in order to do an
174 // estimation of the number of sections of the linear filter that are needed
175 // for getting the majority of the energy in the echo estimate. Based on that
176 // number of sections, it updates the erle estimation by introducing a
177 // correction factor to the erle that is given as an input to this method.
Update(const RenderBuffer & render_buffer,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> filter_frequency_responses,rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> average_erle,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> average_erle_onset_compensated,const std::vector<bool> & converged_filters)178 void SignalDependentErleEstimator::Update(
179     const RenderBuffer& render_buffer,
180     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
181         filter_frequency_responses,
182     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
183     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
184     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
185     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle,
186     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
187         average_erle_onset_compensated,
188     const std::vector<bool>& converged_filters) {
189   RTC_DCHECK_GT(num_sections_, 1);
190 
191   // Gets the number of filter sections that are needed for achieving 90 %
192   // of the power spectrum energy of the echo estimate.
193   ComputeNumberOfActiveFilterSections(render_buffer,
194                                       filter_frequency_responses);
195 
196   // Updates the correction factors that is used for correcting the erle and
197   // adapt it to the particular characteristics of the input signal.
198   UpdateCorrectionFactors(X2, Y2, E2, converged_filters);
199 
200   // Applies the correction factor to the input erle for getting a more refined
201   // erle estimation for the current input signal.
202   for (size_t ch = 0; ch < erle_.size(); ++ch) {
203     for (size_t k = 0; k < kFftLengthBy2; ++k) {
204       RTC_DCHECK_GT(correction_factors_[ch].size(), n_active_sections_[ch][k]);
205       float correction_factor =
206           correction_factors_[ch][n_active_sections_[ch][k]]
207                              [band_to_subband_[k]];
208       erle_[ch][k] = rtc::SafeClamp(average_erle[ch][k] * correction_factor,
209                                     min_erle_, max_erle_[band_to_subband_[k]]);
210       if (use_onset_detection_) {
211         erle_onset_compensated_[ch][k] = rtc::SafeClamp(
212             average_erle_onset_compensated[ch][k] * correction_factor,
213             min_erle_, max_erle_[band_to_subband_[k]]);
214       }
215     }
216   }
217 }
218 
Dump(const std::unique_ptr<ApmDataDumper> & data_dumper) const219 void SignalDependentErleEstimator::Dump(
220     const std::unique_ptr<ApmDataDumper>& data_dumper) const {
221   for (auto& erle : erle_estimators_[0]) {
222     data_dumper->DumpRaw("aec3_all_erle", erle);
223   }
224   data_dumper->DumpRaw("aec3_ref_erle", erle_ref_[0]);
225   for (auto& factor : correction_factors_[0]) {
226     data_dumper->DumpRaw("aec3_erle_correction_factor", factor);
227   }
228 }
229 
230 // Estimates for each band the smallest number of sections in the filter that
231 // together constitute 90% of the estimated echo energy.
ComputeNumberOfActiveFilterSections(const RenderBuffer & render_buffer,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> filter_frequency_responses)232 void SignalDependentErleEstimator::ComputeNumberOfActiveFilterSections(
233     const RenderBuffer& render_buffer,
234     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
235         filter_frequency_responses) {
236   RTC_DCHECK_GT(num_sections_, 1);
237   // Computes an approximation of the power spectrum if the filter would have
238   // been limited to a certain number of filter sections.
239   ComputeEchoEstimatePerFilterSection(render_buffer,
240                                       filter_frequency_responses);
241   // For each band, computes the number of filter sections that are needed for
242   // achieving the 90 % energy in the echo estimate.
243   ComputeActiveFilterSections();
244 }
245 
UpdateCorrectionFactors(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)246 void SignalDependentErleEstimator::UpdateCorrectionFactors(
247     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
248     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
249     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
250     const std::vector<bool>& converged_filters) {
251   for (size_t ch = 0; ch < converged_filters.size(); ++ch) {
252     if (converged_filters[ch]) {
253       constexpr float kX2BandEnergyThreshold = 44015068.0f;
254       constexpr float kSmthConstantDecreases = 0.1f;
255       constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f;
256       auto subband_powers = [](rtc::ArrayView<const float> power_spectrum,
257                                rtc::ArrayView<float> power_spectrum_subbands) {
258         for (size_t subband = 0; subband < kSubbands; ++subband) {
259           RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size());
260           power_spectrum_subbands[subband] = std::accumulate(
261               power_spectrum.begin() + kBandBoundaries[subband],
262               power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f);
263         }
264       };
265 
266       std::array<float, kSubbands> X2_subbands, E2_subbands, Y2_subbands;
267       subband_powers(X2, X2_subbands);
268       subband_powers(E2[ch], E2_subbands);
269       subband_powers(Y2[ch], Y2_subbands);
270       std::array<size_t, kSubbands> idx_subbands;
271       for (size_t subband = 0; subband < kSubbands; ++subband) {
272         // When aggregating the number of active sections in the filter for
273         // different bands we choose to take the minimum of all of them. As an
274         // example, if for one of the bands it is the direct path its refined
275         // contributor to the final echo estimate, we consider the direct path
276         // is as well the refined contributor for the subband that contains that
277         // particular band. That aggregate number of sections will be later used
278         // as the identifier of the erle estimator that needs to be updated.
279         RTC_DCHECK_LE(kBandBoundaries[subband + 1],
280                       n_active_sections_[ch].size());
281         idx_subbands[subband] = *std::min_element(
282             n_active_sections_[ch].begin() + kBandBoundaries[subband],
283             n_active_sections_[ch].begin() + kBandBoundaries[subband + 1]);
284       }
285 
286       std::array<float, kSubbands> new_erle;
287       std::array<bool, kSubbands> is_erle_updated;
288       is_erle_updated.fill(false);
289       new_erle.fill(0.f);
290       for (size_t subband = 0; subband < kSubbands; ++subband) {
291         if (X2_subbands[subband] > kX2BandEnergyThreshold &&
292             E2_subbands[subband] > 0) {
293           new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband];
294           RTC_DCHECK_GT(new_erle[subband], 0);
295           is_erle_updated[subband] = true;
296           ++num_updates_[ch][subband];
297         }
298       }
299 
300       for (size_t subband = 0; subband < kSubbands; ++subband) {
301         const size_t idx = idx_subbands[subband];
302         RTC_DCHECK_LT(idx, erle_estimators_[ch].size());
303         float alpha = new_erle[subband] > erle_estimators_[ch][idx][subband]
304                           ? kSmthConstantIncreases
305                           : kSmthConstantDecreases;
306         alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
307         erle_estimators_[ch][idx][subband] +=
308             alpha * (new_erle[subband] - erle_estimators_[ch][idx][subband]);
309         erle_estimators_[ch][idx][subband] = rtc::SafeClamp(
310             erle_estimators_[ch][idx][subband], min_erle_, max_erle_[subband]);
311       }
312 
313       for (size_t subband = 0; subband < kSubbands; ++subband) {
314         float alpha = new_erle[subband] > erle_ref_[ch][subband]
315                           ? kSmthConstantIncreases
316                           : kSmthConstantDecreases;
317         alpha = static_cast<float>(is_erle_updated[subband]) * alpha;
318         erle_ref_[ch][subband] +=
319             alpha * (new_erle[subband] - erle_ref_[ch][subband]);
320         erle_ref_[ch][subband] = rtc::SafeClamp(erle_ref_[ch][subband],
321                                                 min_erle_, max_erle_[subband]);
322       }
323 
324       for (size_t subband = 0; subband < kSubbands; ++subband) {
325         constexpr int kNumUpdateThr = 50;
326         if (is_erle_updated[subband] &&
327             num_updates_[ch][subband] > kNumUpdateThr) {
328           const size_t idx = idx_subbands[subband];
329           RTC_DCHECK_GT(erle_ref_[ch][subband], 0.f);
330           // Computes the ratio between the erle that is updated using all the
331           // points and the erle that is updated only on signals that share the
332           // same number of active filter sections.
333           float new_correction_factor =
334               erle_estimators_[ch][idx][subband] / erle_ref_[ch][subband];
335 
336           correction_factors_[ch][idx][subband] +=
337               0.1f *
338               (new_correction_factor - correction_factors_[ch][idx][subband]);
339         }
340       }
341     }
342   }
343 }
344 
ComputeEchoEstimatePerFilterSection(const RenderBuffer & render_buffer,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> filter_frequency_responses)345 void SignalDependentErleEstimator::ComputeEchoEstimatePerFilterSection(
346     const RenderBuffer& render_buffer,
347     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
348         filter_frequency_responses) {
349   const SpectrumBuffer& spectrum_render_buffer =
350       render_buffer.GetSpectrumBuffer();
351   const size_t num_render_channels = spectrum_render_buffer.buffer[0].size();
352   const size_t num_capture_channels = S2_section_accum_.size();
353   const float one_by_num_render_channels = 1.f / num_render_channels;
354 
355   RTC_DCHECK_EQ(S2_section_accum_.size(), filter_frequency_responses.size());
356 
357   for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
358     RTC_DCHECK_EQ(S2_section_accum_[capture_ch].size() + 1,
359                   section_boundaries_blocks_.size());
360     size_t idx_render = render_buffer.Position();
361     idx_render = spectrum_render_buffer.OffsetIndex(
362         idx_render, section_boundaries_blocks_[0]);
363 
364     for (size_t section = 0; section < num_sections_; ++section) {
365       std::array<float, kFftLengthBy2Plus1> X2_section;
366       std::array<float, kFftLengthBy2Plus1> H2_section;
367       X2_section.fill(0.f);
368       H2_section.fill(0.f);
369       const size_t block_limit =
370           std::min(section_boundaries_blocks_[section + 1],
371                    filter_frequency_responses[capture_ch].size());
372       for (size_t block = section_boundaries_blocks_[section];
373            block < block_limit; ++block) {
374         for (size_t render_ch = 0;
375              render_ch < spectrum_render_buffer.buffer[idx_render].size();
376              ++render_ch) {
377           for (size_t k = 0; k < X2_section.size(); ++k) {
378             X2_section[k] +=
379                 spectrum_render_buffer.buffer[idx_render][render_ch][k] *
380                 one_by_num_render_channels;
381           }
382         }
383         std::transform(H2_section.begin(), H2_section.end(),
384                        filter_frequency_responses[capture_ch][block].begin(),
385                        H2_section.begin(), std::plus<float>());
386         idx_render = spectrum_render_buffer.IncIndex(idx_render);
387       }
388 
389       std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(),
390                      S2_section_accum_[capture_ch][section].begin(),
391                      std::multiplies<float>());
392     }
393 
394     for (size_t section = 1; section < num_sections_; ++section) {
395       std::transform(S2_section_accum_[capture_ch][section - 1].begin(),
396                      S2_section_accum_[capture_ch][section - 1].end(),
397                      S2_section_accum_[capture_ch][section].begin(),
398                      S2_section_accum_[capture_ch][section].begin(),
399                      std::plus<float>());
400     }
401   }
402 }
403 
ComputeActiveFilterSections()404 void SignalDependentErleEstimator::ComputeActiveFilterSections() {
405   for (size_t ch = 0; ch < n_active_sections_.size(); ++ch) {
406     std::fill(n_active_sections_[ch].begin(), n_active_sections_[ch].end(), 0);
407     for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
408       size_t section = num_sections_;
409       float target = 0.9f * S2_section_accum_[ch][num_sections_ - 1][k];
410       while (section > 0 && S2_section_accum_[ch][section - 1][k] >= target) {
411         n_active_sections_[ch][k] = --section;
412       }
413     }
414   }
415 }
416 }  // namespace webrtc
417