xref: /aosp_15_r20/external/webrtc/modules/audio_processing/transient/transient_suppressor_impl.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
12 #define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
13 
14 #include <stddef.h>
15 #include <stdint.h>
16 
17 #include <memory>
18 
19 #include "modules/audio_processing/transient/transient_suppressor.h"
20 #include "modules/audio_processing/transient/voice_probability_delay_unit.h"
21 #include "rtc_base/gtest_prod_util.h"
22 
23 namespace webrtc {
24 
25 class TransientDetector;
26 
27 // Detects transients in an audio stream and suppress them using a simple
28 // restoration algorithm that attenuates unexpected spikes in the spectrum.
29 class TransientSuppressorImpl : public TransientSuppressor {
30  public:
31   TransientSuppressorImpl(VadMode vad_mode,
32                           int sample_rate_hz,
33                           int detector_rate_hz,
34                           int num_channels);
35   ~TransientSuppressorImpl() override;
36 
37   void Initialize(int sample_rate_hz,
38                   int detector_rate_hz,
39                   int num_channels) override;
40 
41   float Suppress(float* data,
42                  size_t data_length,
43                  int num_channels,
44                  const float* detection_data,
45                  size_t detection_length,
46                  const float* reference_data,
47                  size_t reference_length,
48                  float voice_probability,
49                  bool key_pressed) override;
50 
51  private:
52   FRIEND_TEST_ALL_PREFIXES(TransientSuppressorVadModeParametrization,
53                            TypingDetectionLogicWorksAsExpectedForMono);
54   void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
55 
56   void UpdateKeypress(bool key_pressed);
57   void UpdateRestoration(float voice_probability);
58 
59   void UpdateBuffers(float* data);
60 
61   void HardRestoration(float* spectral_mean);
62   void SoftRestoration(float* spectral_mean);
63 
64   const VadMode vad_mode_;
65   VoiceProbabilityDelayUnit voice_probability_delay_unit_;
66 
67   std::unique_ptr<TransientDetector> detector_;
68 
69   bool analyzed_audio_is_silent_;
70 
71   size_t data_length_;
72   size_t detection_length_;
73   size_t analysis_length_;
74   size_t buffer_delay_;
75   size_t complex_analysis_length_;
76   int num_channels_;
77   // Input buffer where the original samples are stored.
78   std::unique_ptr<float[]> in_buffer_;
79   std::unique_ptr<float[]> detection_buffer_;
80   // Output buffer where the restored samples are stored.
81   std::unique_ptr<float[]> out_buffer_;
82 
83   // Arrays for fft.
84   std::unique_ptr<size_t[]> ip_;
85   std::unique_ptr<float[]> wfft_;
86 
87   std::unique_ptr<float[]> spectral_mean_;
88 
89   // Stores the data for the fft.
90   std::unique_ptr<float[]> fft_buffer_;
91 
92   std::unique_ptr<float[]> magnitudes_;
93 
94   const float* window_;
95 
96   std::unique_ptr<float[]> mean_factor_;
97 
98   float detector_smoothed_;
99 
100   int keypress_counter_;
101   int chunks_since_keypress_;
102   bool detection_enabled_;
103   bool suppression_enabled_;
104 
105   bool use_hard_restoration_;
106   int chunks_since_voice_change_;
107 
108   uint32_t seed_;
109 
110   bool using_reference_;
111 };
112 
113 }  // namespace webrtc
114 
115 #endif  // MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
116