xref: /aosp_15_r20/external/webrtc/modules/audio_processing/transient/transient_suppressor_impl.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/transient/transient_suppressor_impl.h"
12 
13 #include <string.h>
14 
15 #include <algorithm>
16 #include <cmath>
17 #include <complex>
18 #include <deque>
19 #include <limits>
20 #include <set>
21 #include <string>
22 
23 #include "common_audio/include/audio_util.h"
24 #include "common_audio/signal_processing/include/signal_processing_library.h"
25 #include "common_audio/third_party/ooura/fft_size_256/fft4g.h"
26 #include "modules/audio_processing/transient/common.h"
27 #include "modules/audio_processing/transient/transient_detector.h"
28 #include "modules/audio_processing/transient/transient_suppressor.h"
29 #include "modules/audio_processing/transient/windows_private.h"
30 #include "rtc_base/checks.h"
31 #include "rtc_base/logging.h"
32 
33 namespace webrtc {
34 
// Weight given to the newest magnitude when the per-bin spectral mean is
// updated recursively; the previous mean keeps (1 - kMeanIIRCoefficient).
constexpr float kMeanIIRCoefficient = 0.5f;

// TODO(aluebs): Check if these values work also for 48kHz.
// Bin range [kMinVoiceBin, kMaxVoiceBin) averaged by the soft restoration, and
// the centers of the two sigmoids that build the per-bin mean factor.
constexpr size_t kMinVoiceBin = 3;
constexpr size_t kMaxVoiceBin = 60;
40 
41 namespace {
42 
// Inexpensive magnitude estimate for the complex value (a, b): the sum of the
// absolute values of both components rather than sqrt(a^2 + b^2).
float ComplexMagnitude(float a, float b) {
  const float abs_first = std::abs(a);
  const float abs_second = std::abs(b);
  return abs_first + abs_second;
}
46 
GetVadModeLabel(TransientSuppressor::VadMode vad_mode)47 std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
48   switch (vad_mode) {
49     case TransientSuppressor::VadMode::kDefault:
50       return "default";
51     case TransientSuppressor::VadMode::kRnnVad:
52       return "RNN VAD";
53     case TransientSuppressor::VadMode::kNoVad:
54       return "no VAD";
55   }
56 }
57 
58 }  // namespace
59 
TransientSuppressorImpl(VadMode vad_mode,int sample_rate_hz,int detector_rate_hz,int num_channels)60 TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode,
61                                                  int sample_rate_hz,
62                                                  int detector_rate_hz,
63                                                  int num_channels)
64     : vad_mode_(vad_mode),
65       voice_probability_delay_unit_(/*delay_num_samples=*/0, sample_rate_hz),
66       analyzed_audio_is_silent_(false),
67       data_length_(0),
68       detection_length_(0),
69       analysis_length_(0),
70       buffer_delay_(0),
71       complex_analysis_length_(0),
72       num_channels_(0),
73       window_(NULL),
74       detector_smoothed_(0.f),
75       keypress_counter_(0),
76       chunks_since_keypress_(0),
77       detection_enabled_(false),
78       suppression_enabled_(false),
79       use_hard_restoration_(false),
80       chunks_since_voice_change_(0),
81       seed_(182),
82       using_reference_(false) {
83   RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
84   Initialize(sample_rate_hz, detector_rate_hz, num_channels);
85 }
86 
~TransientSuppressorImpl()87 TransientSuppressorImpl::~TransientSuppressorImpl() {}
88 
Initialize(int sample_rate_hz,int detection_rate_hz,int num_channels)89 void TransientSuppressorImpl::Initialize(int sample_rate_hz,
90                                          int detection_rate_hz,
91                                          int num_channels) {
92   RTC_DCHECK(sample_rate_hz == ts::kSampleRate8kHz ||
93              sample_rate_hz == ts::kSampleRate16kHz ||
94              sample_rate_hz == ts::kSampleRate32kHz ||
95              sample_rate_hz == ts::kSampleRate48kHz);
96   RTC_DCHECK(detection_rate_hz == ts::kSampleRate8kHz ||
97              detection_rate_hz == ts::kSampleRate16kHz ||
98              detection_rate_hz == ts::kSampleRate32kHz ||
99              detection_rate_hz == ts::kSampleRate48kHz);
100   RTC_DCHECK_GT(num_channels, 0);
101 
102   switch (sample_rate_hz) {
103     case ts::kSampleRate8kHz:
104       analysis_length_ = 128u;
105       window_ = kBlocks80w128;
106       break;
107     case ts::kSampleRate16kHz:
108       analysis_length_ = 256u;
109       window_ = kBlocks160w256;
110       break;
111     case ts::kSampleRate32kHz:
112       analysis_length_ = 512u;
113       window_ = kBlocks320w512;
114       break;
115     case ts::kSampleRate48kHz:
116       analysis_length_ = 1024u;
117       window_ = kBlocks480w1024;
118       break;
119     default:
120       RTC_DCHECK_NOTREACHED();
121       return;
122   }
123 
124   detector_.reset(new TransientDetector(detection_rate_hz));
125   data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
126   RTC_DCHECK_LE(data_length_, analysis_length_);
127   buffer_delay_ = analysis_length_ - data_length_;
128 
129   voice_probability_delay_unit_.Initialize(/*delay_num_samples=*/buffer_delay_,
130                                            sample_rate_hz);
131 
132   complex_analysis_length_ = analysis_length_ / 2 + 1;
133   RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
134   num_channels_ = num_channels;
135   in_buffer_.reset(new float[analysis_length_ * num_channels_]);
136   memset(in_buffer_.get(), 0,
137          analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
138   detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
139   detection_buffer_.reset(new float[detection_length_]);
140   memset(detection_buffer_.get(), 0,
141          detection_length_ * sizeof(detection_buffer_[0]));
142   out_buffer_.reset(new float[analysis_length_ * num_channels_]);
143   memset(out_buffer_.get(), 0,
144          analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
145   // ip[0] must be zero to trigger initialization using rdft().
146   size_t ip_length = 2 + sqrtf(analysis_length_);
147   ip_.reset(new size_t[ip_length]());
148   memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
149   wfft_.reset(new float[complex_analysis_length_ - 1]);
150   memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
151   spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
152   memset(spectral_mean_.get(), 0,
153          complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
154   fft_buffer_.reset(new float[analysis_length_ + 2]);
155   memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
156   magnitudes_.reset(new float[complex_analysis_length_]);
157   memset(magnitudes_.get(), 0,
158          complex_analysis_length_ * sizeof(magnitudes_[0]));
159   mean_factor_.reset(new float[complex_analysis_length_]);
160 
161   static const float kFactorHeight = 10.f;
162   static const float kLowSlope = 1.f;
163   static const float kHighSlope = 0.3f;
164   for (size_t i = 0; i < complex_analysis_length_; ++i) {
165     mean_factor_[i] =
166         kFactorHeight /
167             (1.f + std::exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
168         kFactorHeight /
169             (1.f + std::exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
170   }
171   detector_smoothed_ = 0.f;
172   keypress_counter_ = 0;
173   chunks_since_keypress_ = 0;
174   detection_enabled_ = false;
175   suppression_enabled_ = false;
176   use_hard_restoration_ = false;
177   chunks_since_voice_change_ = 0;
178   seed_ = 182;
179   using_reference_ = false;
180 }
181 
Suppress(float * data,size_t data_length,int num_channels,const float * detection_data,size_t detection_length,const float * reference_data,size_t reference_length,float voice_probability,bool key_pressed)182 float TransientSuppressorImpl::Suppress(float* data,
183                                         size_t data_length,
184                                         int num_channels,
185                                         const float* detection_data,
186                                         size_t detection_length,
187                                         const float* reference_data,
188                                         size_t reference_length,
189                                         float voice_probability,
190                                         bool key_pressed) {
191   if (!data || data_length != data_length_ || num_channels != num_channels_ ||
192       detection_length != detection_length_ || voice_probability < 0 ||
193       voice_probability > 1) {
194     // The audio is not modified, so the voice probability is returned as is
195     // (delay not applied).
196     return voice_probability;
197   }
198 
199   UpdateKeypress(key_pressed);
200   UpdateBuffers(data);
201 
202   if (detection_enabled_) {
203     UpdateRestoration(voice_probability);
204 
205     if (!detection_data) {
206       // Use the input data  of the first channel if special detection data is
207       // not supplied.
208       detection_data = &in_buffer_[buffer_delay_];
209     }
210 
211     float detector_result = detector_->Detect(detection_data, detection_length,
212                                               reference_data, reference_length);
213     if (detector_result < 0) {
214       // The audio is not modified, so the voice probability is returned as is
215       // (delay not applied).
216       return voice_probability;
217     }
218 
219     using_reference_ = detector_->using_reference();
220 
221     // `detector_smoothed_` follows the `detector_result` when this last one is
222     // increasing, but has an exponential decaying tail to be able to suppress
223     // the ringing of keyclicks.
224     float smooth_factor = using_reference_ ? 0.6 : 0.1;
225     detector_smoothed_ = detector_result >= detector_smoothed_
226                              ? detector_result
227                              : smooth_factor * detector_smoothed_ +
228                                    (1 - smooth_factor) * detector_result;
229 
230     for (int i = 0; i < num_channels_; ++i) {
231       Suppress(&in_buffer_[i * analysis_length_],
232                &spectral_mean_[i * complex_analysis_length_],
233                &out_buffer_[i * analysis_length_]);
234     }
235   }
236 
237   // If the suppression isn't enabled, we use the in buffer to delay the signal
238   // appropriately. This also gives time for the out buffer to be refreshed with
239   // new data between detection and suppression getting enabled.
240   for (int i = 0; i < num_channels_; ++i) {
241     memcpy(&data[i * data_length_],
242            suppression_enabled_ ? &out_buffer_[i * analysis_length_]
243                                 : &in_buffer_[i * analysis_length_],
244            data_length_ * sizeof(*data));
245   }
246 
247   // The audio has been modified, return the delayed voice probability.
248   return voice_probability_delay_unit_.Delay(voice_probability);
249 }
250 
251 // This should only be called when detection is enabled. UpdateBuffers() must
252 // have been called. At return, `out_buffer_` will be filled with the
253 // processed output.
Suppress(float * in_ptr,float * spectral_mean,float * out_ptr)254 void TransientSuppressorImpl::Suppress(float* in_ptr,
255                                        float* spectral_mean,
256                                        float* out_ptr) {
257   // Go to frequency domain.
258   for (size_t i = 0; i < analysis_length_; ++i) {
259     // TODO(aluebs): Rename windows
260     fft_buffer_[i] = in_ptr[i] * window_[i];
261   }
262 
263   WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
264 
265   // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
266   // for convenience.
267   fft_buffer_[analysis_length_] = fft_buffer_[1];
268   fft_buffer_[analysis_length_ + 1] = 0.f;
269   fft_buffer_[1] = 0.f;
270 
271   for (size_t i = 0; i < complex_analysis_length_; ++i) {
272     magnitudes_[i] =
273         ComplexMagnitude(fft_buffer_[i * 2], fft_buffer_[i * 2 + 1]);
274   }
275   // Restore audio if necessary.
276   if (suppression_enabled_) {
277     if (use_hard_restoration_) {
278       HardRestoration(spectral_mean);
279     } else {
280       SoftRestoration(spectral_mean);
281     }
282   }
283 
284   // Update the spectral mean.
285   for (size_t i = 0; i < complex_analysis_length_; ++i) {
286     spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
287                        kMeanIIRCoefficient * magnitudes_[i];
288   }
289 
290   // Back to time domain.
291   // Put R[n/2] back in fft_buffer_[1].
292   fft_buffer_[1] = fft_buffer_[analysis_length_];
293 
294   WebRtc_rdft(analysis_length_, -1, fft_buffer_.get(), ip_.get(), wfft_.get());
295   const float fft_scaling = 2.f / analysis_length_;
296 
297   for (size_t i = 0; i < analysis_length_; ++i) {
298     out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
299   }
300 }
301 
UpdateKeypress(bool key_pressed)302 void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
303   const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
304   const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
305   const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
306 
307   if (key_pressed) {
308     keypress_counter_ += kKeypressPenalty;
309     chunks_since_keypress_ = 0;
310     detection_enabled_ = true;
311   }
312   keypress_counter_ = std::max(0, keypress_counter_ - 1);
313 
314   if (keypress_counter_ > kIsTypingThreshold) {
315     if (!suppression_enabled_) {
316       RTC_LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
317     }
318     suppression_enabled_ = true;
319     keypress_counter_ = 0;
320   }
321 
322   if (detection_enabled_ && ++chunks_since_keypress_ > kChunksUntilNotTyping) {
323     if (suppression_enabled_) {
324       RTC_LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
325     }
326     detection_enabled_ = false;
327     suppression_enabled_ = false;
328     keypress_counter_ = 0;
329   }
330 }
331 
UpdateRestoration(float voice_probability)332 void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
333   bool not_voiced;
334   switch (vad_mode_) {
335     case TransientSuppressor::VadMode::kDefault: {
336       constexpr float kVoiceThreshold = 0.02f;
337       not_voiced = voice_probability < kVoiceThreshold;
338       break;
339     }
340     case TransientSuppressor::VadMode::kRnnVad: {
341       constexpr float kVoiceThreshold = 0.7f;
342       not_voiced = voice_probability < kVoiceThreshold;
343       break;
344     }
345     case TransientSuppressor::VadMode::kNoVad:
346       // Always assume that voice is detected.
347       not_voiced = false;
348       break;
349   }
350 
351   if (not_voiced == use_hard_restoration_) {
352     chunks_since_voice_change_ = 0;
353   } else {
354     ++chunks_since_voice_change_;
355 
356     // Number of 10 ms frames to wait to transition to and from hard
357     // restoration.
358     constexpr int kHardRestorationOffsetDelay = 3;
359     constexpr int kHardRestorationOnsetDelay = 80;
360 
361     if ((use_hard_restoration_ &&
362          chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
363         (!use_hard_restoration_ &&
364          chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
365       use_hard_restoration_ = not_voiced;
366       chunks_since_voice_change_ = 0;
367     }
368   }
369 }
370 
371 // Shift buffers to make way for new data. Must be called after
372 // `detection_enabled_` is updated by UpdateKeypress().
UpdateBuffers(float * data)373 void TransientSuppressorImpl::UpdateBuffers(float* data) {
374   // TODO(aluebs): Change to ring buffer.
375   memmove(in_buffer_.get(), &in_buffer_[data_length_],
376           (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
377               sizeof(in_buffer_[0]));
378   // Copy new chunk to buffer.
379   for (int i = 0; i < num_channels_; ++i) {
380     memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
381            &data[i * data_length_], data_length_ * sizeof(*data));
382   }
383   if (detection_enabled_) {
384     // Shift previous chunk in out buffer.
385     memmove(out_buffer_.get(), &out_buffer_[data_length_],
386             (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
387                 sizeof(out_buffer_[0]));
388     // Initialize new chunk in out buffer.
389     for (int i = 0; i < num_channels_; ++i) {
390       memset(&out_buffer_[buffer_delay_ + i * analysis_length_], 0,
391              data_length_ * sizeof(out_buffer_[0]));
392     }
393   }
394 }
395 
396 // Restores the unvoiced signal if a click is present.
397 // Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds
398 // the spectral mean. The attenuation depends on `detector_smoothed_`.
399 // If a restoration takes place, the `magnitudes_` are updated to the new value.
HardRestoration(float * spectral_mean)400 void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
401   const float detector_result =
402       1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
403   // To restore, we get the peaks in the spectrum. If higher than the previous
404   // spectral mean we adjust them.
405   for (size_t i = 0; i < complex_analysis_length_; ++i) {
406     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
407       // RandU() generates values on [0, int16::max()]
408       const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
409                           std::numeric_limits<int16_t>::max();
410       const float scaled_mean = detector_result * spectral_mean[i];
411 
412       fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
413                            scaled_mean * cosf(phase);
414       fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
415                                scaled_mean * sinf(phase);
416       magnitudes_[i] = magnitudes_[i] -
417                        detector_result * (magnitudes_[i] - spectral_mean[i]);
418     }
419   }
420 }
421 
422 // Restores the voiced signal if a click is present.
423 // Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds
424 // the spectral mean and that is lower than some function of the current block
425 // frequency mean. The attenuation depends on `detector_smoothed_`.
426 // If a restoration takes place, the `magnitudes_` are updated to the new value.
SoftRestoration(float * spectral_mean)427 void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) {
428   // Get the spectral magnitude mean of the current block.
429   float block_frequency_mean = 0;
430   for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
431     block_frequency_mean += magnitudes_[i];
432   }
433   block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
434 
435   // To restore, we get the peaks in the spectrum. If higher than the
436   // previous spectral mean and lower than a factor of the block mean
437   // we adjust them. The factor is a double sigmoid that has a minimum in the
438   // voice frequency range (300Hz - 3kHz).
439   for (size_t i = 0; i < complex_analysis_length_; ++i) {
440     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
441         (using_reference_ ||
442          magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
443       const float new_magnitude =
444           magnitudes_[i] -
445           detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
446       const float magnitude_ratio = new_magnitude / magnitudes_[i];
447 
448       fft_buffer_[i * 2] *= magnitude_ratio;
449       fft_buffer_[i * 2 + 1] *= magnitude_ratio;
450       magnitudes_[i] = new_magnitude;
451     }
452   }
453 }
454 
455 }  // namespace webrtc
456