xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/aec_state.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/aec_state.h"
12 
13 #include <math.h>
14 
15 #include <algorithm>
16 #include <numeric>
17 #include <vector>
18 
19 #include "absl/types/optional.h"
20 #include "api/array_view.h"
21 #include "modules/audio_processing/aec3/aec3_common.h"
22 #include "modules/audio_processing/logging/apm_data_dumper.h"
23 #include "rtc_base/checks.h"
24 #include "system_wrappers/include/field_trial.h"
25 
26 namespace webrtc {
27 namespace {
28 
DeactivateInitialStateResetAtEchoPathChange()29 bool DeactivateInitialStateResetAtEchoPathChange() {
30   return field_trial::IsEnabled(
31       "WebRTC-Aec3DeactivateInitialStateResetKillSwitch");
32 }
33 
FullResetAtEchoPathChange()34 bool FullResetAtEchoPathChange() {
35   return !field_trial::IsEnabled("WebRTC-Aec3AecStateFullResetKillSwitch");
36 }
37 
SubtractorAnalyzerResetAtEchoPathChange()38 bool SubtractorAnalyzerResetAtEchoPathChange() {
39   return !field_trial::IsEnabled(
40       "WebRTC-Aec3AecStateSubtractorAnalyzerResetKillSwitch");
41 }
42 
ComputeAvgRenderReverb(const SpectrumBuffer & spectrum_buffer,int delay_blocks,float reverb_decay,ReverbModel * reverb_model,rtc::ArrayView<float,kFftLengthBy2Plus1> reverb_power_spectrum)43 void ComputeAvgRenderReverb(
44     const SpectrumBuffer& spectrum_buffer,
45     int delay_blocks,
46     float reverb_decay,
47     ReverbModel* reverb_model,
48     rtc::ArrayView<float, kFftLengthBy2Plus1> reverb_power_spectrum) {
49   RTC_DCHECK(reverb_model);
50   const size_t num_render_channels = spectrum_buffer.buffer[0].size();
51   int idx_at_delay =
52       spectrum_buffer.OffsetIndex(spectrum_buffer.read, delay_blocks);
53   int idx_past = spectrum_buffer.IncIndex(idx_at_delay);
54 
55   std::array<float, kFftLengthBy2Plus1> X2_data;
56   rtc::ArrayView<const float> X2;
57   if (num_render_channels > 1) {
58     auto average_channels =
59         [](size_t num_render_channels,
60            rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
61                spectrum_band_0,
62            rtc::ArrayView<float, kFftLengthBy2Plus1> render_power) {
63           std::fill(render_power.begin(), render_power.end(), 0.f);
64           for (size_t ch = 0; ch < num_render_channels; ++ch) {
65             for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
66               render_power[k] += spectrum_band_0[ch][k];
67             }
68           }
69           const float normalizer = 1.f / num_render_channels;
70           for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
71             render_power[k] *= normalizer;
72           }
73         };
74     average_channels(num_render_channels, spectrum_buffer.buffer[idx_past],
75                      X2_data);
76     reverb_model->UpdateReverbNoFreqShaping(
77         X2_data, /*power_spectrum_scaling=*/1.0f, reverb_decay);
78 
79     average_channels(num_render_channels, spectrum_buffer.buffer[idx_at_delay],
80                      X2_data);
81     X2 = X2_data;
82   } else {
83     reverb_model->UpdateReverbNoFreqShaping(
84         spectrum_buffer.buffer[idx_past][/*channel=*/0],
85         /*power_spectrum_scaling=*/1.0f, reverb_decay);
86 
87     X2 = spectrum_buffer.buffer[idx_at_delay][/*channel=*/0];
88   }
89 
90   rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb_power =
91       reverb_model->reverb();
92   for (size_t k = 0; k < X2.size(); ++k) {
93     reverb_power_spectrum[k] = X2[k] + reverb_power[k];
94   }
95 }
96 
97 }  // namespace
98 
99 std::atomic<int> AecState::instance_count_(0);
100 
GetResidualEchoScaling(rtc::ArrayView<float> residual_scaling) const101 void AecState::GetResidualEchoScaling(
102     rtc::ArrayView<float> residual_scaling) const {
103   bool filter_has_had_time_to_converge;
104   if (config_.filter.conservative_initial_phase) {
105     filter_has_had_time_to_converge =
106         strong_not_saturated_render_blocks_ >= 1.5f * kNumBlocksPerSecond;
107   } else {
108     filter_has_had_time_to_converge =
109         strong_not_saturated_render_blocks_ >= 0.8f * kNumBlocksPerSecond;
110   }
111   echo_audibility_.GetResidualEchoScaling(filter_has_had_time_to_converge,
112                                           residual_scaling);
113 }
114 
AecState(const EchoCanceller3Config & config,size_t num_capture_channels)115 AecState::AecState(const EchoCanceller3Config& config,
116                    size_t num_capture_channels)
117     : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)),
118       config_(config),
119       num_capture_channels_(num_capture_channels),
120       deactivate_initial_state_reset_at_echo_path_change_(
121           DeactivateInitialStateResetAtEchoPathChange()),
122       full_reset_at_echo_path_change_(FullResetAtEchoPathChange()),
123       subtractor_analyzer_reset_at_echo_path_change_(
124           SubtractorAnalyzerResetAtEchoPathChange()),
125       initial_state_(config_),
126       delay_state_(config_, num_capture_channels_),
127       transparent_state_(TransparentMode::Create(config_)),
128       filter_quality_state_(config_, num_capture_channels_),
129       erl_estimator_(2 * kNumBlocksPerSecond),
130       erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels_),
131       filter_analyzer_(config_, num_capture_channels_),
132       echo_audibility_(
133           config_.echo_audibility.use_stationarity_properties_at_init),
134       reverb_model_estimator_(config_, num_capture_channels_),
135       subtractor_output_analyzer_(num_capture_channels_) {}
136 
137 AecState::~AecState() = default;
138 
HandleEchoPathChange(const EchoPathVariability & echo_path_variability)139 void AecState::HandleEchoPathChange(
140     const EchoPathVariability& echo_path_variability) {
141   const auto full_reset = [&]() {
142     filter_analyzer_.Reset();
143     capture_signal_saturation_ = false;
144     strong_not_saturated_render_blocks_ = 0;
145     blocks_with_active_render_ = 0;
146     if (!deactivate_initial_state_reset_at_echo_path_change_) {
147       initial_state_.Reset();
148     }
149     if (transparent_state_) {
150       transparent_state_->Reset();
151     }
152     erle_estimator_.Reset(true);
153     erl_estimator_.Reset();
154     filter_quality_state_.Reset();
155   };
156 
157   // TODO(peah): Refine the reset scheme according to the type of gain and
158   // delay adjustment.
159 
160   if (full_reset_at_echo_path_change_ &&
161       echo_path_variability.delay_change !=
162           EchoPathVariability::DelayAdjustment::kNone) {
163     full_reset();
164   } else if (echo_path_variability.gain_change) {
165     erle_estimator_.Reset(false);
166   }
167   if (subtractor_analyzer_reset_at_echo_path_change_) {
168     subtractor_output_analyzer_.HandleEchoPathChange();
169   }
170 }
171 
Update(const absl::optional<DelayEstimate> & external_delay,rtc::ArrayView<const std::vector<std::array<float,kFftLengthBy2Plus1>>> adaptive_filter_frequency_responses,rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_responses,const RenderBuffer & render_buffer,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> E2_refined,rtc::ArrayView<const std::array<float,kFftLengthBy2Plus1>> Y2,rtc::ArrayView<const SubtractorOutput> subtractor_output)172 void AecState::Update(
173     const absl::optional<DelayEstimate>& external_delay,
174     rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>>
175         adaptive_filter_frequency_responses,
176     rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_responses,
177     const RenderBuffer& render_buffer,
178     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2_refined,
179     rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
180     rtc::ArrayView<const SubtractorOutput> subtractor_output) {
181   RTC_DCHECK_EQ(num_capture_channels_, Y2.size());
182   RTC_DCHECK_EQ(num_capture_channels_, subtractor_output.size());
183   RTC_DCHECK_EQ(num_capture_channels_,
184                 adaptive_filter_frequency_responses.size());
185   RTC_DCHECK_EQ(num_capture_channels_,
186                 adaptive_filter_impulse_responses.size());
187 
188   // Analyze the filter outputs and filters.
189   bool any_filter_converged;
190   bool any_coarse_filter_converged;
191   bool all_filters_diverged;
192   subtractor_output_analyzer_.Update(subtractor_output, &any_filter_converged,
193                                      &any_coarse_filter_converged,
194                                      &all_filters_diverged);
195 
196   bool any_filter_consistent;
197   float max_echo_path_gain;
198   filter_analyzer_.Update(adaptive_filter_impulse_responses, render_buffer,
199                           &any_filter_consistent, &max_echo_path_gain);
200 
201   // Estimate the direct path delay of the filter.
202   if (config_.filter.use_linear_filter) {
203     delay_state_.Update(filter_analyzer_.FilterDelaysBlocks(), external_delay,
204                         strong_not_saturated_render_blocks_);
205   }
206 
207   const Block& aligned_render_block =
208       render_buffer.GetBlock(-delay_state_.MinDirectPathFilterDelay());
209 
210   // Update render counters.
211   bool active_render = false;
212   for (int ch = 0; ch < aligned_render_block.NumChannels(); ++ch) {
213     const float render_energy =
214         std::inner_product(aligned_render_block.begin(/*block=*/0, ch),
215                            aligned_render_block.end(/*block=*/0, ch),
216                            aligned_render_block.begin(/*block=*/0, ch), 0.f);
217     if (render_energy > (config_.render_levels.active_render_limit *
218                          config_.render_levels.active_render_limit) *
219                             kFftLengthBy2) {
220       active_render = true;
221       break;
222     }
223   }
224   blocks_with_active_render_ += active_render ? 1 : 0;
225   strong_not_saturated_render_blocks_ +=
226       active_render && !SaturatedCapture() ? 1 : 0;
227 
228   std::array<float, kFftLengthBy2Plus1> avg_render_spectrum_with_reverb;
229 
230   ComputeAvgRenderReverb(render_buffer.GetSpectrumBuffer(),
231                          delay_state_.MinDirectPathFilterDelay(),
232                          ReverbDecay(/*mild=*/false), &avg_render_reverb_,
233                          avg_render_spectrum_with_reverb);
234 
235   if (config_.echo_audibility.use_stationarity_properties) {
236     // Update the echo audibility evaluator.
237     echo_audibility_.Update(render_buffer, avg_render_reverb_.reverb(),
238                             delay_state_.MinDirectPathFilterDelay(),
239                             delay_state_.ExternalDelayReported());
240   }
241 
242   // Update the ERL and ERLE measures.
243   if (initial_state_.TransitionTriggered()) {
244     erle_estimator_.Reset(false);
245   }
246 
247   erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses,
248                          avg_render_spectrum_with_reverb, Y2, E2_refined,
249                          subtractor_output_analyzer_.ConvergedFilters());
250 
251   erl_estimator_.Update(
252       subtractor_output_analyzer_.ConvergedFilters(),
253       render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay()), Y2);
254 
255   // Detect and flag echo saturation.
256   if (config_.ep_strength.echo_can_saturate) {
257     saturation_detector_.Update(aligned_render_block, SaturatedCapture(),
258                                 UsableLinearEstimate(), subtractor_output,
259                                 max_echo_path_gain);
260   } else {
261     RTC_DCHECK(!saturation_detector_.SaturatedEcho());
262   }
263 
264   // Update the decision on whether to use the initial state parameter set.
265   initial_state_.Update(active_render, SaturatedCapture());
266 
267   // Detect whether the transparent mode should be activated.
268   if (transparent_state_) {
269     transparent_state_->Update(
270         delay_state_.MinDirectPathFilterDelay(), any_filter_consistent,
271         any_filter_converged, any_coarse_filter_converged, all_filters_diverged,
272         active_render, SaturatedCapture());
273   }
274 
275   // Analyze the quality of the filter.
276   filter_quality_state_.Update(active_render, TransparentModeActive(),
277                                SaturatedCapture(), external_delay,
278                                any_filter_converged);
279 
280   // Update the reverb estimate.
281   const bool stationary_block =
282       config_.echo_audibility.use_stationarity_properties &&
283       echo_audibility_.IsBlockStationary();
284 
285   reverb_model_estimator_.Update(
286       filter_analyzer_.GetAdjustedFilters(),
287       adaptive_filter_frequency_responses,
288       erle_estimator_.GetInstLinearQualityEstimates(),
289       delay_state_.DirectPathFilterDelays(),
290       filter_quality_state_.UsableLinearFilterOutputs(), stationary_block);
291 
292   erle_estimator_.Dump(data_dumper_);
293   reverb_model_estimator_.Dump(data_dumper_.get());
294   data_dumper_->DumpRaw("aec3_active_render", active_render);
295   data_dumper_->DumpRaw("aec3_erl", Erl());
296   data_dumper_->DumpRaw("aec3_erl_time_domain", ErlTimeDomain());
297   data_dumper_->DumpRaw("aec3_erle", Erle(/*onset_compensated=*/false)[0]);
298   data_dumper_->DumpRaw("aec3_erle_onset_compensated",
299                         Erle(/*onset_compensated=*/true)[0]);
300   data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate());
301   data_dumper_->DumpRaw("aec3_transparent_mode", TransparentModeActive());
302   data_dumper_->DumpRaw("aec3_filter_delay",
303                         filter_analyzer_.MinFilterDelayBlocks());
304 
305   data_dumper_->DumpRaw("aec3_any_filter_consistent", any_filter_consistent);
306   data_dumper_->DumpRaw("aec3_initial_state",
307                         initial_state_.InitialStateActive());
308   data_dumper_->DumpRaw("aec3_capture_saturation", SaturatedCapture());
309   data_dumper_->DumpRaw("aec3_echo_saturation", SaturatedEcho());
310   data_dumper_->DumpRaw("aec3_any_filter_converged", any_filter_converged);
311   data_dumper_->DumpRaw("aec3_any_coarse_filter_converged",
312                         any_coarse_filter_converged);
313   data_dumper_->DumpRaw("aec3_all_filters_diverged", all_filters_diverged);
314 
315   data_dumper_->DumpRaw("aec3_external_delay_avaliable",
316                         external_delay ? 1 : 0);
317   data_dumper_->DumpRaw("aec3_filter_tail_freq_resp_est",
318                         GetReverbFrequencyResponse());
319   data_dumper_->DumpRaw("aec3_subtractor_y2", subtractor_output[0].y2);
320   data_dumper_->DumpRaw("aec3_subtractor_e2_coarse",
321                         subtractor_output[0].e2_coarse);
322   data_dumper_->DumpRaw("aec3_subtractor_e2_refined",
323                         subtractor_output[0].e2_refined);
324 }
325 
InitialState(const EchoCanceller3Config & config)326 AecState::InitialState::InitialState(const EchoCanceller3Config& config)
327     : conservative_initial_phase_(config.filter.conservative_initial_phase),
328       initial_state_seconds_(config.filter.initial_state_seconds) {
329   Reset();
330 }
Reset()331 void AecState::InitialState::InitialState::Reset() {
332   initial_state_ = true;
333   strong_not_saturated_render_blocks_ = 0;
334 }
Update(bool active_render,bool saturated_capture)335 void AecState::InitialState::InitialState::Update(bool active_render,
336                                                   bool saturated_capture) {
337   strong_not_saturated_render_blocks_ +=
338       active_render && !saturated_capture ? 1 : 0;
339 
340   // Flag whether the initial state is still active.
341   bool prev_initial_state = initial_state_;
342   if (conservative_initial_phase_) {
343     initial_state_ =
344         strong_not_saturated_render_blocks_ < 5 * kNumBlocksPerSecond;
345   } else {
346     initial_state_ = strong_not_saturated_render_blocks_ <
347                      initial_state_seconds_ * kNumBlocksPerSecond;
348   }
349 
350   // Flag whether the transition from the initial state has started.
351   transition_triggered_ = !initial_state_ && prev_initial_state;
352 }
353 
FilterDelay(const EchoCanceller3Config & config,size_t num_capture_channels)354 AecState::FilterDelay::FilterDelay(const EchoCanceller3Config& config,
355                                    size_t num_capture_channels)
356     : delay_headroom_blocks_(config.delay.delay_headroom_samples / kBlockSize),
357       filter_delays_blocks_(num_capture_channels, delay_headroom_blocks_),
358       min_filter_delay_(delay_headroom_blocks_) {}
359 
Update(rtc::ArrayView<const int> analyzer_filter_delay_estimates_blocks,const absl::optional<DelayEstimate> & external_delay,size_t blocks_with_proper_filter_adaptation)360 void AecState::FilterDelay::Update(
361     rtc::ArrayView<const int> analyzer_filter_delay_estimates_blocks,
362     const absl::optional<DelayEstimate>& external_delay,
363     size_t blocks_with_proper_filter_adaptation) {
364   // Update the delay based on the external delay.
365   if (external_delay &&
366       (!external_delay_ || external_delay_->delay != external_delay->delay)) {
367     external_delay_ = external_delay;
368     external_delay_reported_ = true;
369   }
370 
371   // Override the estimated delay if it is not certain that the filter has had
372   // time to converge.
373   const bool delay_estimator_may_not_have_converged =
374       blocks_with_proper_filter_adaptation < 2 * kNumBlocksPerSecond;
375   if (delay_estimator_may_not_have_converged && external_delay_) {
376     const int delay_guess = delay_headroom_blocks_;
377     std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(),
378               delay_guess);
379   } else {
380     RTC_DCHECK_EQ(filter_delays_blocks_.size(),
381                   analyzer_filter_delay_estimates_blocks.size());
382     std::copy(analyzer_filter_delay_estimates_blocks.begin(),
383               analyzer_filter_delay_estimates_blocks.end(),
384               filter_delays_blocks_.begin());
385   }
386 
387   min_filter_delay_ = *std::min_element(filter_delays_blocks_.begin(),
388                                         filter_delays_blocks_.end());
389 }
390 
FilteringQualityAnalyzer(const EchoCanceller3Config & config,size_t num_capture_channels)391 AecState::FilteringQualityAnalyzer::FilteringQualityAnalyzer(
392     const EchoCanceller3Config& config,
393     size_t num_capture_channels)
394     : use_linear_filter_(config.filter.use_linear_filter),
395       usable_linear_filter_estimates_(num_capture_channels, false) {}
396 
Reset()397 void AecState::FilteringQualityAnalyzer::Reset() {
398   std::fill(usable_linear_filter_estimates_.begin(),
399             usable_linear_filter_estimates_.end(), false);
400   overall_usable_linear_estimates_ = false;
401   filter_update_blocks_since_reset_ = 0;
402 }
403 
Update(bool active_render,bool transparent_mode,bool saturated_capture,const absl::optional<DelayEstimate> & external_delay,bool any_filter_converged)404 void AecState::FilteringQualityAnalyzer::Update(
405     bool active_render,
406     bool transparent_mode,
407     bool saturated_capture,
408     const absl::optional<DelayEstimate>& external_delay,
409     bool any_filter_converged) {
410   // Update blocks counter.
411   const bool filter_update = active_render && !saturated_capture;
412   filter_update_blocks_since_reset_ += filter_update ? 1 : 0;
413   filter_update_blocks_since_start_ += filter_update ? 1 : 0;
414 
415   // Store convergence flag when observed.
416   convergence_seen_ = convergence_seen_ || any_filter_converged;
417 
418   // Verify requirements for achieving a decent filter. The requirements for
419   // filter adaptation at call startup are more restrictive than after an
420   // in-call reset.
421   const bool sufficient_data_to_converge_at_startup =
422       filter_update_blocks_since_start_ > kNumBlocksPerSecond * 0.4f;
423   const bool sufficient_data_to_converge_at_reset =
424       sufficient_data_to_converge_at_startup &&
425       filter_update_blocks_since_reset_ > kNumBlocksPerSecond * 0.2f;
426 
427   // The linear filter can only be used if it has had time to converge.
428   overall_usable_linear_estimates_ = sufficient_data_to_converge_at_startup &&
429                                      sufficient_data_to_converge_at_reset;
430 
431   // The linear filter can only be used if an external delay or convergence have
432   // been identified
433   overall_usable_linear_estimates_ =
434       overall_usable_linear_estimates_ && (external_delay || convergence_seen_);
435 
436   // If transparent mode is on, deactivate usign the linear filter.
437   overall_usable_linear_estimates_ =
438       overall_usable_linear_estimates_ && !transparent_mode;
439 
440   if (use_linear_filter_) {
441     std::fill(usable_linear_filter_estimates_.begin(),
442               usable_linear_filter_estimates_.end(),
443               overall_usable_linear_estimates_);
444   }
445 }
446 
Update(const Block & x,bool saturated_capture,bool usable_linear_estimate,rtc::ArrayView<const SubtractorOutput> subtractor_output,float echo_path_gain)447 void AecState::SaturationDetector::Update(
448     const Block& x,
449     bool saturated_capture,
450     bool usable_linear_estimate,
451     rtc::ArrayView<const SubtractorOutput> subtractor_output,
452     float echo_path_gain) {
453   saturated_echo_ = false;
454   if (!saturated_capture) {
455     return;
456   }
457 
458   if (usable_linear_estimate) {
459     constexpr float kSaturationThreshold = 20000.f;
460     for (size_t ch = 0; ch < subtractor_output.size(); ++ch) {
461       saturated_echo_ =
462           saturated_echo_ ||
463           (subtractor_output[ch].s_refined_max_abs > kSaturationThreshold ||
464            subtractor_output[ch].s_coarse_max_abs > kSaturationThreshold);
465     }
466   } else {
467     float max_sample = 0.f;
468     for (int ch = 0; ch < x.NumChannels(); ++ch) {
469       rtc::ArrayView<const float, kBlockSize> x_ch = x.View(/*band=*/0, ch);
470       for (float sample : x_ch) {
471         max_sample = std::max(max_sample, fabsf(sample));
472       }
473     }
474 
475     const float kMargin = 10.f;
476     float peak_echo_amplitude = max_sample * echo_path_gain * kMargin;
477     saturated_echo_ = saturated_echo_ || peak_echo_amplitude > 32000;
478   }
479 }
480 
481 }  // namespace webrtc
482