1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/subband_erle_estimator.h"
12
13 #include <algorithm>
14 #include <functional>
15
16 #include "rtc_base/checks.h"
17 #include "rtc_base/numerics/safe_minmax.h"
18 #include "system_wrappers/include/field_trial.h"
19
20 namespace webrtc {
21
22 namespace {
23
// Render power below which a band is flagged as having low render energy
// (compared against X2[k] when accumulating spectra).
constexpr float kX2BandEnergyThreshold{44015068.0f};
// Number of blocks after an onset during which the onset-compensated ERLE is
// held before it starts to decay.
constexpr int kBlocksToHoldErle{100};
// Value the per-band hold counter is reset to on an onset; the extra 150
// blocks after the hold period are spent decaying before the band is re-armed.
constexpr int kBlocksForOnsetDetection{kBlocksToHoldErle + 150};
// Number of blocks of spectra accumulated before a new ERLE estimate is made.
constexpr int kPointsToAccumulate{6};
28
SetMaxErleBands(float max_erle_l,float max_erle_h)29 std::array<float, kFftLengthBy2Plus1> SetMaxErleBands(float max_erle_l,
30 float max_erle_h) {
31 std::array<float, kFftLengthBy2Plus1> max_erle;
32 std::fill(max_erle.begin(), max_erle.begin() + kFftLengthBy2 / 2, max_erle_l);
33 std::fill(max_erle.begin() + kFftLengthBy2 / 2, max_erle.end(), max_erle_h);
34 return max_erle;
35 }
36
EnableMinErleDuringOnsets()37 bool EnableMinErleDuringOnsets() {
38 return !field_trial::IsEnabled("WebRTC-Aec3MinErleDuringOnsetsKillSwitch");
39 }
40
41 } // namespace
42
// Creates an estimator with one set of per-band state for each of the
// `num_capture_channels` capture channels. ERLE bounds and whether onset
// detection is used are read from `config.erle`. All estimates are then
// initialized by Reset().
SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config,
                                           size_t num_capture_channels)
    : use_onset_detection_(config.erle.onset_detection),
      min_erle_(config.erle.min),
      // Low bands are capped by max_l, high bands by max_h.
      max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)),
      // True unless the corresponding kill-switch field trial is enabled.
      use_min_erle_during_onsets_(EnableMinErleDuringOnsets()),
      accum_spectra_(num_capture_channels),
      erle_(num_capture_channels),
      erle_onset_compensated_(num_capture_channels),
      erle_unbounded_(num_capture_channels),
      erle_during_onsets_(num_capture_channels),
      coming_onset_(num_capture_channels),
      hold_counters_(num_capture_channels) {
  // The containers above are only sized here; Reset() assigns their values.
  Reset();
}
58
// Defaulted destructor, defined out of line.
SubbandErleEstimator::~SubbandErleEstimator() = default;
60
Reset()61 void SubbandErleEstimator::Reset() {
62 const size_t num_capture_channels = erle_.size();
63 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
64 erle_[ch].fill(min_erle_);
65 erle_onset_compensated_[ch].fill(min_erle_);
66 erle_unbounded_[ch].fill(min_erle_);
67 erle_during_onsets_[ch].fill(min_erle_);
68 coming_onset_[ch].fill(true);
69 hold_counters_[ch].fill(0);
70 }
71 ResetAccumulatedSpectra();
72 }
73
Update(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)74 void SubbandErleEstimator::Update(
75 rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
76 rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
77 rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
78 const std::vector<bool>& converged_filters) {
79 UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters);
80 UpdateBands(converged_filters);
81
82 if (use_onset_detection_) {
83 DecreaseErlePerBandForLowRenderSignals();
84 }
85
86 const size_t num_capture_channels = erle_.size();
87 for (size_t ch = 0; ch < num_capture_channels; ++ch) {
88 auto& erle = erle_[ch];
89 erle[0] = erle[1];
90 erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1];
91
92 auto& erle_oc = erle_onset_compensated_[ch];
93 erle_oc[0] = erle_oc[1];
94 erle_oc[kFftLengthBy2] = erle_oc[kFftLengthBy2 - 1];
95
96 auto& erle_u = erle_unbounded_[ch];
97 erle_u[0] = erle_u[1];
98 erle_u[kFftLengthBy2] = erle_u[kFftLengthBy2 - 1];
99 }
100 }
101
Dump(const std::unique_ptr<ApmDataDumper> & data_dumper) const102 void SubbandErleEstimator::Dump(
103 const std::unique_ptr<ApmDataDumper>& data_dumper) const {
104 data_dumper->DumpRaw("aec3_erle_onset", ErleDuringOnsets()[0]);
105 }
106
UpdateBands(const std::vector<bool> & converged_filters)107 void SubbandErleEstimator::UpdateBands(
108 const std::vector<bool>& converged_filters) {
109 const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
110 for (int ch = 0; ch < num_capture_channels; ++ch) {
111 // Note that the use of the converged_filter flag already imposed
112 // a minimum of the erle that can be estimated as that flag would
113 // be false if the filter is performing poorly.
114 if (!converged_filters[ch]) {
115 continue;
116 }
117
118 if (accum_spectra_.num_points[ch] != kPointsToAccumulate) {
119 continue;
120 }
121
122 std::array<float, kFftLengthBy2> new_erle;
123 std::array<bool, kFftLengthBy2> is_erle_updated;
124 is_erle_updated.fill(false);
125
126 for (size_t k = 1; k < kFftLengthBy2; ++k) {
127 if (accum_spectra_.E2[ch][k] > 0.f) {
128 new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k];
129 is_erle_updated[k] = true;
130 }
131 }
132
133 if (use_onset_detection_) {
134 for (size_t k = 1; k < kFftLengthBy2; ++k) {
135 if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) {
136 if (coming_onset_[ch][k]) {
137 coming_onset_[ch][k] = false;
138 if (!use_min_erle_during_onsets_) {
139 float alpha =
140 new_erle[k] < erle_during_onsets_[ch][k] ? 0.3f : 0.15f;
141 erle_during_onsets_[ch][k] = rtc::SafeClamp(
142 erle_during_onsets_[ch][k] +
143 alpha * (new_erle[k] - erle_during_onsets_[ch][k]),
144 min_erle_, max_erle_[k]);
145 }
146 }
147 hold_counters_[ch][k] = kBlocksForOnsetDetection;
148 }
149 }
150 }
151
152 auto update_erle_band = [](float& erle, float new_erle,
153 bool low_render_energy, float min_erle,
154 float max_erle) {
155 float alpha = 0.05f;
156 if (new_erle < erle) {
157 alpha = low_render_energy ? 0.f : 0.1f;
158 }
159 erle =
160 rtc::SafeClamp(erle + alpha * (new_erle - erle), min_erle, max_erle);
161 };
162
163 for (size_t k = 1; k < kFftLengthBy2; ++k) {
164 if (is_erle_updated[k]) {
165 const bool low_render_energy = accum_spectra_.low_render_energy[ch][k];
166 update_erle_band(erle_[ch][k], new_erle[k], low_render_energy,
167 min_erle_, max_erle_[k]);
168 if (use_onset_detection_) {
169 update_erle_band(erle_onset_compensated_[ch][k], new_erle[k],
170 low_render_energy, min_erle_, max_erle_[k]);
171 }
172
173 // Virtually unbounded ERLE.
174 constexpr float kUnboundedErleMax = 100000.0f;
175 update_erle_band(erle_unbounded_[ch][k], new_erle[k], low_render_energy,
176 min_erle_, kUnboundedErleMax);
177 }
178 }
179 }
180 }
181
DecreaseErlePerBandForLowRenderSignals()182 void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() {
183 const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
184 for (int ch = 0; ch < num_capture_channels; ++ch) {
185 for (size_t k = 1; k < kFftLengthBy2; ++k) {
186 --hold_counters_[ch][k];
187 if (hold_counters_[ch][k] <=
188 (kBlocksForOnsetDetection - kBlocksToHoldErle)) {
189 if (erle_onset_compensated_[ch][k] > erle_during_onsets_[ch][k]) {
190 erle_onset_compensated_[ch][k] =
191 std::max(erle_during_onsets_[ch][k],
192 0.97f * erle_onset_compensated_[ch][k]);
193 RTC_DCHECK_LE(min_erle_, erle_onset_compensated_[ch][k]);
194 }
195 if (hold_counters_[ch][k] <= 0) {
196 coming_onset_[ch][k] = true;
197 hold_counters_[ch][k] = 0;
198 }
199 }
200 }
201 }
202 }
203
ResetAccumulatedSpectra()204 void SubbandErleEstimator::ResetAccumulatedSpectra() {
205 for (size_t ch = 0; ch < erle_during_onsets_.size(); ++ch) {
206 accum_spectra_.Y2[ch].fill(0.f);
207 accum_spectra_.E2[ch].fill(0.f);
208 accum_spectra_.num_points[ch] = 0;
209 accum_spectra_.low_render_energy[ch].fill(false);
210 }
211 }
212
UpdateAccumulatedSpectra(rtc::ArrayView<const float,kFftLengthBy2Plus1> X2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2,const std::vector<bool> & converged_filters)213 void SubbandErleEstimator::UpdateAccumulatedSpectra(
214 rtc::ArrayView<const float, kFftLengthBy2Plus1> X2,
215 rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
216 rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
217 const std::vector<bool>& converged_filters) {
218 auto& st = accum_spectra_;
219 RTC_DCHECK_EQ(st.E2.size(), E2.size());
220 RTC_DCHECK_EQ(st.E2.size(), E2.size());
221 const int num_capture_channels = static_cast<int>(Y2.size());
222 for (int ch = 0; ch < num_capture_channels; ++ch) {
223 // Note that the use of the converged_filter flag already imposed
224 // a minimum of the erle that can be estimated as that flag would
225 // be false if the filter is performing poorly.
226 if (!converged_filters[ch]) {
227 continue;
228 }
229
230 if (st.num_points[ch] == kPointsToAccumulate) {
231 st.num_points[ch] = 0;
232 st.Y2[ch].fill(0.f);
233 st.E2[ch].fill(0.f);
234 st.low_render_energy[ch].fill(false);
235 }
236
237 std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(),
238 st.Y2[ch].begin(), std::plus<float>());
239 std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(),
240 st.E2[ch].begin(), std::plus<float>());
241
242 for (size_t k = 0; k < X2.size(); ++k) {
243 st.low_render_energy[ch][k] =
244 st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold;
245 }
246
247 ++st.num_points[ch];
248 }
249 }
250
251 } // namespace webrtc
252