xref: /aosp_15_r20/external/webrtc/modules/audio_coding/neteq/neteq_decoder_plc_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 // Test to verify correct operation when using the decoder-internal PLC.
12 
13 #include <memory>
14 #include <utility>
15 #include <vector>
16 
17 #include "absl/types/optional.h"
18 #include "modules/audio_coding/codecs/pcm16b/audio_encoder_pcm16b.h"
19 #include "modules/audio_coding/neteq/tools/audio_checksum.h"
20 #include "modules/audio_coding/neteq/tools/audio_sink.h"
21 #include "modules/audio_coding/neteq/tools/encode_neteq_input.h"
22 #include "modules/audio_coding/neteq/tools/fake_decode_from_file.h"
23 #include "modules/audio_coding/neteq/tools/input_audio_file.h"
24 #include "modules/audio_coding/neteq/tools/neteq_test.h"
25 #include "rtc_base/numerics/safe_conversions.h"
26 #include "test/audio_decoder_proxy_factory.h"
27 #include "test/gtest.h"
28 #include "test/testsupport/file_utils.h"
29 
30 namespace webrtc {
31 namespace test {
32 namespace {
33 
// Sample rate shared by the PCM16b encoder input and the fake PLC decoder.
constexpr int kSampleRateHz{32000};
// Duration of the generated encoder input, in milliseconds.
constexpr int kRunTimeMs{10000};
36 
37 // This class implements a fake decoder. The decoder will read audio from a file
38 // and present as output, both for regular decoding and for PLC.
39 class AudioDecoderPlc : public AudioDecoder {
40  public:
AudioDecoderPlc(std::unique_ptr<InputAudioFile> input,int sample_rate_hz)41   AudioDecoderPlc(std::unique_ptr<InputAudioFile> input, int sample_rate_hz)
42       : input_(std::move(input)), sample_rate_hz_(sample_rate_hz) {}
43 
Reset()44   void Reset() override {}
SampleRateHz() const45   int SampleRateHz() const override { return sample_rate_hz_; }
Channels() const46   size_t Channels() const override { return 1; }
DecodeInternal(const uint8_t *,size_t encoded_len,int sample_rate_hz,int16_t * decoded,SpeechType * speech_type)47   int DecodeInternal(const uint8_t* /*encoded*/,
48                      size_t encoded_len,
49                      int sample_rate_hz,
50                      int16_t* decoded,
51                      SpeechType* speech_type) override {
52     RTC_CHECK_GE(encoded_len / 2, 10 * sample_rate_hz_ / 1000);
53     RTC_CHECK_LE(encoded_len / 2, 2 * 10 * sample_rate_hz_ / 1000);
54     RTC_CHECK_EQ(sample_rate_hz, sample_rate_hz_);
55     RTC_CHECK(decoded);
56     RTC_CHECK(speech_type);
57     RTC_CHECK(input_->Read(encoded_len / 2, decoded));
58     *speech_type = kSpeech;
59     last_was_plc_ = false;
60     return encoded_len / 2;
61   }
62 
GeneratePlc(size_t requested_samples_per_channel,rtc::BufferT<int16_t> * concealment_audio)63   void GeneratePlc(size_t requested_samples_per_channel,
64                    rtc::BufferT<int16_t>* concealment_audio) override {
65     // Instead of generating random data for GeneratePlc we use the same data as
66     // the input, so we can check that we produce the same result independently
67     // of the losses.
68     RTC_DCHECK_EQ(requested_samples_per_channel, 10 * sample_rate_hz_ / 1000);
69 
70     // Must keep a local copy of this since DecodeInternal sets it to false.
71     const bool last_was_plc = last_was_plc_;
72 
73     std::vector<int16_t> decoded(5760);
74     SpeechType speech_type;
75     int dec_len = DecodeInternal(nullptr, 2 * 10 * sample_rate_hz_ / 1000,
76                                  sample_rate_hz_, decoded.data(), &speech_type);
77     concealment_audio->AppendData(decoded.data(), dec_len);
78     concealed_samples_ += rtc::checked_cast<size_t>(dec_len);
79 
80     if (!last_was_plc) {
81       ++concealment_events_;
82     }
83     last_was_plc_ = true;
84   }
85 
concealed_samples()86   size_t concealed_samples() { return concealed_samples_; }
concealment_events()87   size_t concealment_events() { return concealment_events_; }
88 
89  private:
90   const std::unique_ptr<InputAudioFile> input_;
91   const int sample_rate_hz_;
92   size_t concealed_samples_ = 0;
93   size_t concealment_events_ = 0;
94   bool last_was_plc_ = false;
95 };
96 
97 // An input sample generator which generates only zero-samples.
98 class ZeroSampleGenerator : public EncodeNetEqInput::Generator {
99  public:
Generate(size_t num_samples)100   rtc::ArrayView<const int16_t> Generate(size_t num_samples) override {
101     vec.resize(num_samples, 0);
102     rtc::ArrayView<const int16_t> view(vec);
103     RTC_DCHECK_EQ(view.size(), num_samples);
104     return view;
105   }
106 
107  private:
108   std::vector<int16_t> vec;
109 };
110 
111 // A NetEqInput which connects to another NetEqInput, but drops a number of
112 // consecutive packets on the way
113 class LossyInput : public NetEqInput {
114  public:
LossyInput(int loss_cadence,int burst_length,std::unique_ptr<NetEqInput> input)115   LossyInput(int loss_cadence,
116              int burst_length,
117              std::unique_ptr<NetEqInput> input)
118       : loss_cadence_(loss_cadence),
119         burst_length_(burst_length),
120         input_(std::move(input)) {}
121 
NextPacketTime() const122   absl::optional<int64_t> NextPacketTime() const override {
123     return input_->NextPacketTime();
124   }
125 
NextOutputEventTime() const126   absl::optional<int64_t> NextOutputEventTime() const override {
127     return input_->NextOutputEventTime();
128   }
129 
PopPacket()130   std::unique_ptr<PacketData> PopPacket() override {
131     if (loss_cadence_ != 0 && (++count_ % loss_cadence_) == 0) {
132       // Pop `burst_length_` packets to create the loss.
133       auto packet_to_return = input_->PopPacket();
134       for (int i = 0; i < burst_length_; i++) {
135         input_->PopPacket();
136       }
137       return packet_to_return;
138     }
139     return input_->PopPacket();
140   }
141 
AdvanceOutputEvent()142   void AdvanceOutputEvent() override { return input_->AdvanceOutputEvent(); }
143 
ended() const144   bool ended() const override { return input_->ended(); }
145 
NextHeader() const146   absl::optional<RTPHeader> NextHeader() const override {
147     return input_->NextHeader();
148   }
149 
150  private:
151   const int loss_cadence_;
152   const int burst_length_;
153   int count_ = 0;
154   const std::unique_ptr<NetEqInput> input_;
155 };
156 
157 class AudioChecksumWithOutput : public AudioChecksum {
158  public:
AudioChecksumWithOutput(std::string * output_str)159   explicit AudioChecksumWithOutput(std::string* output_str)
160       : output_str_(*output_str) {}
~AudioChecksumWithOutput()161   ~AudioChecksumWithOutput() { output_str_ = Finish(); }
162 
163  private:
164   std::string& output_str_;
165 };
166 
167 struct TestStatistics {
168   NetEqNetworkStatistics network;
169   NetEqLifetimeStatistics lifetime;
170 };
171 
RunTest(int loss_cadence,int burst_length,std::string * checksum)172 TestStatistics RunTest(int loss_cadence,
173                        int burst_length,
174                        std::string* checksum) {
175   NetEq::Config config;
176   config.for_test_no_time_stretching = true;
177 
178   // The input is mostly useless. It sends zero-samples to a PCM16b encoder,
179   // but the actual encoded samples will never be used by the decoder in the
180   // test. See below about the decoder.
181   auto generator = std::make_unique<ZeroSampleGenerator>();
182   constexpr int kPayloadType = 100;
183   AudioEncoderPcm16B::Config encoder_config;
184   encoder_config.sample_rate_hz = kSampleRateHz;
185   encoder_config.payload_type = kPayloadType;
186   auto encoder = std::make_unique<AudioEncoderPcm16B>(encoder_config);
187   auto input = std::make_unique<EncodeNetEqInput>(
188       std::move(generator), std::move(encoder), kRunTimeMs);
189   // Wrap the input in a loss function.
190   auto lossy_input = std::make_unique<LossyInput>(loss_cadence, burst_length,
191                                                   std::move(input));
192 
193   // Setting up decoders.
194   NetEqTest::DecoderMap decoders;
195   // Using a fake decoder which simply reads the output audio from a file.
196   auto input_file = std::make_unique<InputAudioFile>(
197       webrtc::test::ResourcePath("audio_coding/testfile32kHz", "pcm"));
198   AudioDecoderPlc dec(std::move(input_file), kSampleRateHz);
199   // Masquerading as a PCM16b decoder.
200   decoders.emplace(kPayloadType, SdpAudioFormat("l16", 32000, 1));
201 
202   // Output is simply a checksum calculator.
203   auto output = std::make_unique<AudioChecksumWithOutput>(checksum);
204 
205   // No callback objects.
206   NetEqTest::Callbacks callbacks;
207 
208   NetEqTest neteq_test(
209       config, /*decoder_factory=*/
210       rtc::make_ref_counted<test::AudioDecoderProxyFactory>(&dec),
211       /*codecs=*/decoders, /*text_log=*/nullptr, /*neteq_factory=*/nullptr,
212       /*input=*/std::move(lossy_input), std::move(output), callbacks);
213   EXPECT_LE(kRunTimeMs, neteq_test.Run());
214 
215   auto lifetime_stats = neteq_test.LifetimeStats();
216   EXPECT_EQ(dec.concealed_samples(), lifetime_stats.concealed_samples);
217   EXPECT_EQ(dec.concealment_events(), lifetime_stats.concealment_events);
218   return {neteq_test.SimulationStats(), neteq_test.LifetimeStats()};
219 }
220 }  // namespace
221 
222 // Check that some basic metrics are produced in the right direction. In
223 // particular, expand_rate should only increase if there are losses present. Our
224 // dummy decoder is designed such as the checksum should always be the same
225 // regardless of the losses given that calls are executed in the right order.
TEST(NetEqDecoderPlc,BasicMetrics)226 TEST(NetEqDecoderPlc, BasicMetrics) {
227   std::string checksum;
228 
229   // Drop 1 packet every 10 packets.
230   auto stats = RunTest(10, 1, &checksum);
231 
232   std::string checksum_no_loss;
233   auto stats_no_loss = RunTest(0, 0, &checksum_no_loss);
234 
235   EXPECT_EQ(checksum, checksum_no_loss);
236 
237   EXPECT_EQ(stats.network.preemptive_rate,
238             stats_no_loss.network.preemptive_rate);
239   EXPECT_EQ(stats.network.accelerate_rate,
240             stats_no_loss.network.accelerate_rate);
241   EXPECT_EQ(0, stats_no_loss.network.expand_rate);
242   EXPECT_GT(stats.network.expand_rate, 0);
243 }
244 
245 // Checks that interruptions are not counted in small losses but they are
246 // correctly counted in long interruptions.
TEST(NetEqDecoderPlc,CountInterruptions)247 TEST(NetEqDecoderPlc, CountInterruptions) {
248   std::string checksum;
249   std::string checksum_2;
250   std::string checksum_3;
251 
252   // Half of the packets lost but in short interruptions.
253   auto stats_no_interruptions = RunTest(1, 1, &checksum);
254   // One lost of 500 ms (250 packets).
255   auto stats_one_interruption = RunTest(200, 250, &checksum_2);
256   // Two losses of 250ms each (125 packets).
257   auto stats_two_interruptions = RunTest(125, 125, &checksum_3);
258 
259   EXPECT_EQ(checksum, checksum_2);
260   EXPECT_EQ(checksum, checksum_3);
261   EXPECT_GT(stats_no_interruptions.network.expand_rate, 0);
262   EXPECT_EQ(stats_no_interruptions.lifetime.total_interruption_duration_ms, 0);
263   EXPECT_EQ(stats_no_interruptions.lifetime.interruption_count, 0);
264 
265   EXPECT_GT(stats_one_interruption.network.expand_rate, 0);
266   EXPECT_EQ(stats_one_interruption.lifetime.total_interruption_duration_ms,
267             5000);
268   EXPECT_EQ(stats_one_interruption.lifetime.interruption_count, 1);
269 
270   EXPECT_GT(stats_two_interruptions.network.expand_rate, 0);
271   EXPECT_EQ(stats_two_interruptions.lifetime.total_interruption_duration_ms,
272             5000);
273   EXPECT_EQ(stats_two_interruptions.lifetime.interruption_count, 2);
274 }
275 
276 // Checks that small losses do not produce interruptions.
TEST(NetEqDecoderPlc,NoInterruptionsInSmallLosses)277 TEST(NetEqDecoderPlc, NoInterruptionsInSmallLosses) {
278   std::string checksum_1;
279   std::string checksum_4;
280 
281   auto stats_1 = RunTest(300, 1, &checksum_1);
282   auto stats_4 = RunTest(300, 4, &checksum_4);
283 
284   EXPECT_EQ(checksum_1, checksum_4);
285 
286   EXPECT_EQ(stats_1.lifetime.interruption_count, 0);
287   EXPECT_EQ(stats_1.lifetime.total_interruption_duration_ms, 0);
288   EXPECT_EQ(stats_1.lifetime.concealed_samples, 640u);  // 20ms of concealment.
289   EXPECT_EQ(stats_1.lifetime.concealment_events, 1u);   // in just one event.
290 
291   EXPECT_EQ(stats_4.lifetime.interruption_count, 0);
292   EXPECT_EQ(stats_4.lifetime.total_interruption_duration_ms, 0);
293   EXPECT_EQ(stats_4.lifetime.concealed_samples, 2560u);  // 80ms of concealment.
294   EXPECT_EQ(stats_4.lifetime.concealment_events, 1u);    // in just one event.
295 }
296 
297 // Checks that interruptions of different sizes report correct duration.
TEST(NetEqDecoderPlc,InterruptionsReportCorrectSize)298 TEST(NetEqDecoderPlc, InterruptionsReportCorrectSize) {
299   std::string checksum;
300 
301   for (int burst_length = 5; burst_length < 10; burst_length++) {
302     auto stats = RunTest(300, burst_length, &checksum);
303     auto duration = stats.lifetime.total_interruption_duration_ms;
304     if (burst_length < 8) {
305       EXPECT_EQ(duration, 0);
306     } else {
307       EXPECT_EQ(duration, burst_length * 20);
308     }
309   }
310 }
311 
312 }  // namespace test
313 }  // namespace webrtc
314