xref: /aosp_15_r20/external/webrtc/modules/audio_processing/aec3/filter_analyzer.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/aec3/filter_analyzer.h"
12 
13 #include <math.h>
14 
15 #include <algorithm>
16 #include <array>
17 #include <numeric>
18 
19 #include "modules/audio_processing/aec3/aec3_common.h"
20 #include "modules/audio_processing/aec3/render_buffer.h"
21 #include "modules/audio_processing/logging/apm_data_dumper.h"
22 #include "rtc_base/checks.h"
23 
24 namespace webrtc {
25 namespace {
26 
FindPeakIndex(rtc::ArrayView<const float> filter_time_domain,size_t peak_index_in,size_t start_sample,size_t end_sample)27 size_t FindPeakIndex(rtc::ArrayView<const float> filter_time_domain,
28                      size_t peak_index_in,
29                      size_t start_sample,
30                      size_t end_sample) {
31   size_t peak_index_out = peak_index_in;
32   float max_h2 =
33       filter_time_domain[peak_index_out] * filter_time_domain[peak_index_out];
34   for (size_t k = start_sample; k <= end_sample; ++k) {
35     float tmp = filter_time_domain[k] * filter_time_domain[k];
36     if (tmp > max_h2) {
37       peak_index_out = k;
38       max_h2 = tmp;
39     }
40   }
41 
42   return peak_index_out;
43 }
44 
45 }  // namespace
46 
// Counts the FilterAnalyzer instances constructed so far; the constructor
// increments it and hands the resulting number to the instance's data dumper.
std::atomic<int> FilterAnalyzer::instance_count_(0);
48 
FilterAnalyzer(const EchoCanceller3Config & config,size_t num_capture_channels)49 FilterAnalyzer::FilterAnalyzer(const EchoCanceller3Config& config,
50                                size_t num_capture_channels)
51     : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)),
52       bounded_erl_(config.ep_strength.bounded_erl),
53       default_gain_(config.ep_strength.default_gain),
54       h_highpass_(num_capture_channels,
55                   std::vector<float>(
56                       GetTimeDomainLength(config.filter.refined.length_blocks),
57                       0.f)),
58       filter_analysis_states_(num_capture_channels,
59                               FilterAnalysisState(config)),
60       filter_delays_blocks_(num_capture_channels, 0) {
61   Reset();
62 }
63 
// Defaulted destructor, defined out of line in this translation unit.
FilterAnalyzer::~FilterAnalyzer() = default;
65 
Reset()66 void FilterAnalyzer::Reset() {
67   blocks_since_reset_ = 0;
68   ResetRegion();
69   for (auto& state : filter_analysis_states_) {
70     state.Reset(default_gain_);
71   }
72   std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(), 0);
73 }
74 
Update(rtc::ArrayView<const std::vector<float>> filters_time_domain,const RenderBuffer & render_buffer,bool * any_filter_consistent,float * max_echo_path_gain)75 void FilterAnalyzer::Update(
76     rtc::ArrayView<const std::vector<float>> filters_time_domain,
77     const RenderBuffer& render_buffer,
78     bool* any_filter_consistent,
79     float* max_echo_path_gain) {
80   RTC_DCHECK(any_filter_consistent);
81   RTC_DCHECK(max_echo_path_gain);
82   RTC_DCHECK_EQ(filters_time_domain.size(), filter_analysis_states_.size());
83   RTC_DCHECK_EQ(filters_time_domain.size(), h_highpass_.size());
84 
85   ++blocks_since_reset_;
86   SetRegionToAnalyze(filters_time_domain[0].size());
87   AnalyzeRegion(filters_time_domain, render_buffer);
88 
89   // Aggregate the results for all capture channels.
90   auto& st_ch0 = filter_analysis_states_[0];
91   *any_filter_consistent = st_ch0.consistent_estimate;
92   *max_echo_path_gain = st_ch0.gain;
93   min_filter_delay_blocks_ = filter_delays_blocks_[0];
94   for (size_t ch = 1; ch < filters_time_domain.size(); ++ch) {
95     auto& st_ch = filter_analysis_states_[ch];
96     *any_filter_consistent =
97         *any_filter_consistent || st_ch.consistent_estimate;
98     *max_echo_path_gain = std::max(*max_echo_path_gain, st_ch.gain);
99     min_filter_delay_blocks_ =
100         std::min(min_filter_delay_blocks_, filter_delays_blocks_[ch]);
101   }
102 }
103 
AnalyzeRegion(rtc::ArrayView<const std::vector<float>> filters_time_domain,const RenderBuffer & render_buffer)104 void FilterAnalyzer::AnalyzeRegion(
105     rtc::ArrayView<const std::vector<float>> filters_time_domain,
106     const RenderBuffer& render_buffer) {
107   // Preprocess the filter to avoid issues with low-frequency components in the
108   // filter.
109   PreProcessFilters(filters_time_domain);
110   data_dumper_->DumpRaw("aec3_linear_filter_processed_td", h_highpass_[0]);
111 
112   constexpr float kOneByBlockSize = 1.f / kBlockSize;
113   for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) {
114     RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size());
115     RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size());
116 
117     auto& st_ch = filter_analysis_states_[ch];
118     RTC_DCHECK_EQ(h_highpass_[ch].size(), filters_time_domain[ch].size());
119     RTC_DCHECK_GT(h_highpass_[ch].size(), 0);
120     st_ch.peak_index = std::min(st_ch.peak_index, h_highpass_[ch].size() - 1);
121 
122     st_ch.peak_index =
123         FindPeakIndex(h_highpass_[ch], st_ch.peak_index, region_.start_sample_,
124                       region_.end_sample_);
125     filter_delays_blocks_[ch] = st_ch.peak_index >> kBlockSizeLog2;
126     UpdateFilterGain(h_highpass_[ch], &st_ch);
127     st_ch.filter_length_blocks =
128         filters_time_domain[ch].size() * kOneByBlockSize;
129 
130     st_ch.consistent_estimate = st_ch.consistent_filter_detector.Detect(
131         h_highpass_[ch], region_,
132         render_buffer.GetBlock(-filter_delays_blocks_[ch]), st_ch.peak_index,
133         filter_delays_blocks_[ch]);
134   }
135 }
136 
UpdateFilterGain(rtc::ArrayView<const float> filter_time_domain,FilterAnalysisState * st)137 void FilterAnalyzer::UpdateFilterGain(
138     rtc::ArrayView<const float> filter_time_domain,
139     FilterAnalysisState* st) {
140   bool sufficient_time_to_converge =
141       blocks_since_reset_ > 5 * kNumBlocksPerSecond;
142 
143   if (sufficient_time_to_converge && st->consistent_estimate) {
144     st->gain = fabsf(filter_time_domain[st->peak_index]);
145   } else {
146     // TODO(peah): Verify whether this check against a float is ok.
147     if (st->gain) {
148       st->gain = std::max(st->gain, fabsf(filter_time_domain[st->peak_index]));
149     }
150   }
151 
152   if (bounded_erl_ && st->gain) {
153     st->gain = std::max(st->gain, 0.01f);
154   }
155 }
156 
PreProcessFilters(rtc::ArrayView<const std::vector<float>> filters_time_domain)157 void FilterAnalyzer::PreProcessFilters(
158     rtc::ArrayView<const std::vector<float>> filters_time_domain) {
159   for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) {
160     RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size());
161     RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size());
162 
163     RTC_DCHECK_GE(h_highpass_[ch].capacity(), filters_time_domain[ch].size());
164     h_highpass_[ch].resize(filters_time_domain[ch].size());
165     // Minimum phase high-pass filter with cutoff frequency at about 600 Hz.
166     constexpr std::array<float, 3> h = {
167         {0.7929742f, -0.36072128f, -0.47047766f}};
168 
169     std::fill(h_highpass_[ch].begin() + region_.start_sample_,
170               h_highpass_[ch].begin() + region_.end_sample_ + 1, 0.f);
171     float* h_highpass_ch = h_highpass_[ch].data();
172     const float* filters_time_domain_ch = filters_time_domain[ch].data();
173     const size_t region_end = region_.end_sample_;
174     for (size_t k = std::max(h.size() - 1, region_.start_sample_);
175          k <= region_end; ++k) {
176       float tmp = h_highpass_ch[k];
177       for (size_t j = 0; j < h.size(); ++j) {
178         tmp += filters_time_domain_ch[k - j] * h[j];
179       }
180       h_highpass_ch[k] = tmp;
181     }
182   }
183 }
184 
ResetRegion()185 void FilterAnalyzer::ResetRegion() {
186   region_.start_sample_ = 0;
187   region_.end_sample_ = 0;
188 }
189 
SetRegionToAnalyze(size_t filter_size)190 void FilterAnalyzer::SetRegionToAnalyze(size_t filter_size) {
191   constexpr size_t kNumberBlocksToUpdate = 1;
192   auto& r = region_;
193   r.start_sample_ = r.end_sample_ >= filter_size - 1 ? 0 : r.end_sample_ + 1;
194   r.end_sample_ =
195       std::min(r.start_sample_ + kNumberBlocksToUpdate * kBlockSize - 1,
196                filter_size - 1);
197 
198   // Check range.
199   RTC_DCHECK_LT(r.start_sample_, filter_size);
200   RTC_DCHECK_LT(r.end_sample_, filter_size);
201   RTC_DCHECK_LE(r.start_sample_, r.end_sample_);
202 }
203 
// Derives the render-activity energy threshold from the configured
// active-render limit (the limit squared, scaled by kFftLengthBy2) and puts
// the detector into its reset state.
FilterAnalyzer::ConsistentFilterDetector::ConsistentFilterDetector(
    const EchoCanceller3Config& config)
    : active_render_threshold_(config.render_levels.active_render_limit *
                               config.render_levels.active_render_limit *
                               kFftLengthBy2) {
  Reset();
}
211 
Reset()212 void FilterAnalyzer::ConsistentFilterDetector::Reset() {
213   significant_peak_ = false;
214   filter_floor_accum_ = 0.f;
215   filter_secondary_peak_ = 0.f;
216   filter_floor_low_limit_ = 0;
217   filter_floor_high_limit_ = 0;
218   consistent_estimate_counter_ = 0;
219   consistent_delay_reference_ = -10;
220 }
221 
Detect(rtc::ArrayView<const float> filter_to_analyze,const FilterRegion & region,const Block & x_block,size_t peak_index,int delay_blocks)222 bool FilterAnalyzer::ConsistentFilterDetector::Detect(
223     rtc::ArrayView<const float> filter_to_analyze,
224     const FilterRegion& region,
225     const Block& x_block,
226     size_t peak_index,
227     int delay_blocks) {
228   if (region.start_sample_ == 0) {
229     filter_floor_accum_ = 0.f;
230     filter_secondary_peak_ = 0.f;
231     filter_floor_low_limit_ = peak_index < 64 ? 0 : peak_index - 64;
232     filter_floor_high_limit_ =
233         peak_index > filter_to_analyze.size() - 129 ? 0 : peak_index + 128;
234   }
235 
236   float filter_floor_accum = filter_floor_accum_;
237   float filter_secondary_peak = filter_secondary_peak_;
238   for (size_t k = region.start_sample_;
239        k < std::min(region.end_sample_ + 1, filter_floor_low_limit_); ++k) {
240     float abs_h = fabsf(filter_to_analyze[k]);
241     filter_floor_accum += abs_h;
242     filter_secondary_peak = std::max(filter_secondary_peak, abs_h);
243   }
244 
245   for (size_t k = std::max(filter_floor_high_limit_, region.start_sample_);
246        k <= region.end_sample_; ++k) {
247     float abs_h = fabsf(filter_to_analyze[k]);
248     filter_floor_accum += abs_h;
249     filter_secondary_peak = std::max(filter_secondary_peak, abs_h);
250   }
251   filter_floor_accum_ = filter_floor_accum;
252   filter_secondary_peak_ = filter_secondary_peak;
253 
254   if (region.end_sample_ == filter_to_analyze.size() - 1) {
255     float filter_floor = filter_floor_accum_ /
256                          (filter_floor_low_limit_ + filter_to_analyze.size() -
257                           filter_floor_high_limit_);
258 
259     float abs_peak = fabsf(filter_to_analyze[peak_index]);
260     significant_peak_ = abs_peak > 10.f * filter_floor &&
261                         abs_peak > 2.f * filter_secondary_peak_;
262   }
263 
264   if (significant_peak_) {
265     bool active_render_block = false;
266     for (int ch = 0; ch < x_block.NumChannels(); ++ch) {
267       rtc::ArrayView<const float, kBlockSize> x_channel =
268           x_block.View(/*band=*/0, ch);
269       const float x_energy = std::inner_product(
270           x_channel.begin(), x_channel.end(), x_channel.begin(), 0.f);
271       if (x_energy > active_render_threshold_) {
272         active_render_block = true;
273         break;
274       }
275     }
276 
277     if (consistent_delay_reference_ == delay_blocks) {
278       if (active_render_block) {
279         ++consistent_estimate_counter_;
280       }
281     } else {
282       consistent_estimate_counter_ = 0;
283       consistent_delay_reference_ = delay_blocks;
284     }
285   }
286   return consistent_estimate_counter_ > 1.5f * kNumBlocksPerSecond;
287 }
288 
289 }  // namespace webrtc
290