xref: /aosp_15_r20/external/webrtc/modules/audio_processing/transient/voice_probability_delay_unit.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/transient/voice_probability_delay_unit.h"
12 
13 #include <array>
14 
15 #include "rtc_base/checks.h"
16 
17 namespace webrtc {
18 
VoiceProbabilityDelayUnit(int delay_num_samples,int sample_rate_hz)19 VoiceProbabilityDelayUnit::VoiceProbabilityDelayUnit(int delay_num_samples,
20                                                      int sample_rate_hz) {
21   Initialize(delay_num_samples, sample_rate_hz);
22 }
23 
Initialize(int delay_num_samples,int sample_rate_hz)24 void VoiceProbabilityDelayUnit::Initialize(int delay_num_samples,
25                                            int sample_rate_hz) {
26   RTC_DCHECK_GE(delay_num_samples, 0);
27   RTC_DCHECK_LE(delay_num_samples, sample_rate_hz / 50)
28       << "The implementation does not support delays greater than 20 ms.";
29   int frame_size = rtc::CheckedDivExact(sample_rate_hz, 100);  // 10 ms.
30   if (delay_num_samples <= frame_size) {
31     weights_[0] = 0.0f;
32     weights_[1] = static_cast<float>(delay_num_samples) / frame_size;
33     weights_[2] =
34         static_cast<float>(frame_size - delay_num_samples) / frame_size;
35   } else {
36     delay_num_samples -= frame_size;
37     weights_[0] = static_cast<float>(delay_num_samples) / frame_size;
38     weights_[1] =
39         static_cast<float>(frame_size - delay_num_samples) / frame_size;
40     weights_[2] = 0.0f;
41   }
42 
43   // Resets the delay unit.
44   last_probabilities_.fill(0.0f);
45 }
46 
Delay(float voice_probability)47 float VoiceProbabilityDelayUnit::Delay(float voice_probability) {
48   float weighted_probability = weights_[0] * last_probabilities_[0] +
49                                weights_[1] * last_probabilities_[1] +
50                                weights_[2] * voice_probability;
51   last_probabilities_[0] = last_probabilities_[1];
52   last_probabilities_[1] = voice_probability;
53   return weighted_probability;
54 }
55 
56 }  // namespace webrtc
57