xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/subband_erle_estimator.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/subband_erle_estimator.h"
12 
13 #include <algorithm>
14 #include <functional>
15 
16 #include "rtc_base/checks.h"
17 #include "rtc_base/numerics/safe_minmax.h"
18 #include "system_wrappers/include/field_trial.h"
19 
20 namespace webrtc {
21 
22 namespace {
23 
// Render power X2[k] below which band k is flagged as having low render
// energy for the current accumulation window (see UpdateAccumulatedSpectra).
constexpr float kX2BandEnergyThreshold = 44015068.f;
// Number of hold-counter decrements (one per Update() call with onset
// detection enabled) before the onset-compensated ERLE of a band may start to
// decay after the band was last re-armed.
constexpr int kBlocksToHoldErle = 100;
// Value a band's hold counter is re-armed to on a strong-render update; when
// it counts down to zero the band is considered ready for a new onset.
constexpr int kBlocksForOnsetDetection = 150 + kBlocksToHoldErle;
// Number of Update() calls whose spectra are summed before a new per-band
// ERLE measurement is formed.
constexpr int kPointsToAccumulate = 6;
28 
SetMaxErleBands(float max_erle_l,float max_erle_h)29 std::array<float, kFftLengthBy2Plus1> SetMaxErleBands(float max_erle_l,
30                                                       float max_erle_h) {
31   std::array<float, kFftLengthBy2Plus1> max_erle;
32   std::fill(max_erle.begin(), max_erle.begin() + kFftLengthBy2 / 2, max_erle_l);
33   std::fill(max_erle.begin() + kFftLengthBy2 / 2, max_erle.end(), max_erle_h);
34   return max_erle;
35 }
36 
EnableMinErleDuringOnsets()37 bool EnableMinErleDuringOnsets() {
38   return !field_trial::IsEnabled("WebRTC-Aec3MinErleDuringOnsetsKillSwitch");
39 }
40 
41 }  // namespace
42 
SubbandErleEstimator(const EchoCanceller3Config & config,size_t num_capture_channels)43 SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config,
44                                            size_t num_capture_channels)
45     : use_onset_detection_(config.erle.onset_detection),
46       min_erle_(config.erle.min),
47       max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)),
48       use_min_erle_during_onsets_(EnableMinErleDuringOnsets()),
49       accum_spectra_(num_capture_channels),
50       erle_(num_capture_channels),
51       erle_onset_compensated_(num_capture_channels),
52       erle_unbounded_(num_capture_channels),
53       erle_during_onsets_(num_capture_channels),
54       coming_onset_(num_capture_channels),
55       hold_counters_(num_capture_channels) {
56   Reset();
57 }
58 
59 SubbandErleEstimator::~SubbandErleEstimator() = default;
60 
Reset()61 void SubbandErleEstimator::Reset() {
62   const size_t num_capture_channels = erle_.size();
63   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
64     erle_[ch].fill(min_erle_);
65     erle_onset_compensated_[ch].fill(min_erle_);
66     erle_unbounded_[ch].fill(min_erle_);
67     erle_during_onsets_[ch].fill(min_erle_);
68     coming_onset_[ch].fill(true);
69     hold_counters_[ch].fill(0);
70   }
71   ResetAccumulatedSpectra();
72 }
73 
Update(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)74 void SubbandErleEstimator::Update(
75     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
76     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
77     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
78     const std::vector<bool>& converged_filters) {
79   UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters);
80   UpdateBands(converged_filters);
81 
82   if (use_onset_detection_) {
83     DecreaseErlePerBandForLowRenderSignals();
84   }
85 
86   const size_t num_capture_channels = erle_.size();
87   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
88     auto& erle = erle_[ch];
89     erle[0] = erle[1];
90     erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1];
91 
92     auto& erle_oc = erle_onset_compensated_[ch];
93     erle_oc[0] = erle_oc[1];
94     erle_oc[kFftLengthBy2] = erle_oc[kFftLengthBy2 - 1];
95 
96     auto& erle_u = erle_unbounded_[ch];
97     erle_u[0] = erle_u[1];
98     erle_u[kFftLengthBy2] = erle_u[kFftLengthBy2 - 1];
99   }
100 }
101 
Dump(const std::unique_ptr<ApmDataDumper> & data_dumper) const102 void SubbandErleEstimator::Dump(
103     const std::unique_ptr<ApmDataDumper>& data_dumper) const {
104   data_dumper->DumpRaw("aec3_erle_onset", ErleDuringOnsets()[0]);
105 }
106 
UpdateBands(const std::vector<bool> & converged_filters)107 void SubbandErleEstimator::UpdateBands(
108     const std::vector<bool>& converged_filters) {
109   const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
110   for (int ch = 0; ch < num_capture_channels; ++ch) {
111     // Note that the use of the converged_filter flag already imposed
112     // a minimum of the erle that can be estimated as that flag would
113     // be false if the filter is performing poorly.
114     if (!converged_filters[ch]) {
115       continue;
116     }
117 
118     if (accum_spectra_.num_points[ch] != kPointsToAccumulate) {
119       continue;
120     }
121 
122     std::array<float, kFftLengthBy2> new_erle;
123     std::array<bool, kFftLengthBy2> is_erle_updated;
124     is_erle_updated.fill(false);
125 
126     for (size_t k = 1; k < kFftLengthBy2; ++k) {
127       if (accum_spectra_.E2[ch][k] > 0.f) {
128         new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k];
129         is_erle_updated[k] = true;
130       }
131     }
132 
133     if (use_onset_detection_) {
134       for (size_t k = 1; k < kFftLengthBy2; ++k) {
135         if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) {
136           if (coming_onset_[ch][k]) {
137             coming_onset_[ch][k] = false;
138             if (!use_min_erle_during_onsets_) {
139               float alpha =
140                   new_erle[k] < erle_during_onsets_[ch][k] ? 0.3f : 0.15f;
141               erle_during_onsets_[ch][k] = rtc::SafeClamp(
142                   erle_during_onsets_[ch][k] +
143                       alpha * (new_erle[k] - erle_during_onsets_[ch][k]),
144                   min_erle_, max_erle_[k]);
145             }
146           }
147           hold_counters_[ch][k] = kBlocksForOnsetDetection;
148         }
149       }
150     }
151 
152     auto update_erle_band = [](float& erle, float new_erle,
153                                bool low_render_energy, float min_erle,
154                                float max_erle) {
155       float alpha = 0.05f;
156       if (new_erle < erle) {
157         alpha = low_render_energy ? 0.f : 0.1f;
158       }
159       erle =
160           rtc::SafeClamp(erle + alpha * (new_erle - erle), min_erle, max_erle);
161     };
162 
163     for (size_t k = 1; k < kFftLengthBy2; ++k) {
164       if (is_erle_updated[k]) {
165         const bool low_render_energy = accum_spectra_.low_render_energy[ch][k];
166         update_erle_band(erle_[ch][k], new_erle[k], low_render_energy,
167                          min_erle_, max_erle_[k]);
168         if (use_onset_detection_) {
169           update_erle_band(erle_onset_compensated_[ch][k], new_erle[k],
170                            low_render_energy, min_erle_, max_erle_[k]);
171         }
172 
173         // Virtually unbounded ERLE.
174         constexpr float kUnboundedErleMax = 100000.0f;
175         update_erle_band(erle_unbounded_[ch][k], new_erle[k], low_render_energy,
176                          min_erle_, kUnboundedErleMax);
177       }
178     }
179   }
180 }
181 
DecreaseErlePerBandForLowRenderSignals()182 void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() {
183   const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
184   for (int ch = 0; ch < num_capture_channels; ++ch) {
185     for (size_t k = 1; k < kFftLengthBy2; ++k) {
186       --hold_counters_[ch][k];
187       if (hold_counters_[ch][k] <=
188           (kBlocksForOnsetDetection - kBlocksToHoldErle)) {
189         if (erle_onset_compensated_[ch][k] > erle_during_onsets_[ch][k]) {
190           erle_onset_compensated_[ch][k] =
191               std::max(erle_during_onsets_[ch][k],
192                        0.97f * erle_onset_compensated_[ch][k]);
193           RTC_DCHECK_LE(min_erle_, erle_onset_compensated_[ch][k]);
194         }
195         if (hold_counters_[ch][k] <= 0) {
196           coming_onset_[ch][k] = true;
197           hold_counters_[ch][k] = 0;
198         }
199       }
200     }
201   }
202 }
203 
ResetAccumulatedSpectra()204 void SubbandErleEstimator::ResetAccumulatedSpectra() {
205   for (size_t ch = 0; ch < erle_during_onsets_.size(); ++ch) {
206     accum_spectra_.Y2[ch].fill(0.f);
207     accum_spectra_.E2[ch].fill(0.f);
208     accum_spectra_.num_points[ch] = 0;
209     accum_spectra_.low_render_energy[ch].fill(false);
210   }
211 }
212 
UpdateAccumulatedSpectra(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)213 void SubbandErleEstimator::UpdateAccumulatedSpectra(
214     rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
215     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
216     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
217     const std::vector<bool>& converged_filters) {
218   auto& st = accum_spectra_;
219   RTC_DCHECK_EQ(st.E2.size(), E2.size());
220   RTC_DCHECK_EQ(st.E2.size(), E2.size());
221   const int num_capture_channels = static_cast<int>(Y2.size());
222   for (int ch = 0; ch < num_capture_channels; ++ch) {
223     // Note that the use of the converged_filter flag already imposed
224     // a minimum of the erle that can be estimated as that flag would
225     // be false if the filter is performing poorly.
226     if (!converged_filters[ch]) {
227       continue;
228     }
229 
230     if (st.num_points[ch] == kPointsToAccumulate) {
231       st.num_points[ch] = 0;
232       st.Y2[ch].fill(0.f);
233       st.E2[ch].fill(0.f);
234       st.low_render_energy[ch].fill(false);
235     }
236 
237     std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(),
238                    st.Y2[ch].begin(), std::plus<float>());
239     std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(),
240                    st.E2[ch].begin(), std::plus<float>());
241 
242     for (size_t k = 0; k < X2.size(); ++k) {
243       st.low_render_energy[ch][k] =
244           st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold;
245     }
246 
247     ++st.num_points[ch];
248   }
249 }
250 
251 }  // namespace webrtc
252