1 /*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include "modules/audio_coding/neteq/tools/neteq_quality_test.h"

#include <stdio.h>

#include <cmath>
#include <utility>

#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "modules/audio_coding/neteq/default_neteq_factory.h"
#include "modules/audio_coding/neteq/tools/neteq_quality_test.h"
#include "modules/audio_coding/neteq/tools/output_audio_file.h"
#include "modules/audio_coding/neteq/tools/output_wav_file.h"
#include "modules/audio_coding/neteq/tools/resample_input_audio_file.h"
#include "rtc_base/checks.h"
#include "rtc_base/string_encode.h"
#include "system_wrappers/include/clock.h"
#include "test/testsupport/file_utils.h"
28
29 ABSL_FLAG(std::string,
30 in_filename,
31 "audio_coding/speech_mono_16kHz.pcm",
32 "Path of the input file (relative to the resources/ directory) for "
33 "input audio (specify sample rate with --input_sample_rate, "
34 "and channels with --channels).");
35
36 ABSL_FLAG(int, input_sample_rate, 16000, "Sample rate of input file in Hz.");
37
38 ABSL_FLAG(int, channels, 1, "Number of channels in input audio.");
39
40 ABSL_FLAG(std::string,
41 out_filename,
42 "neteq_quality_test_out.pcm",
43 "Name of output audio file, which will be saved in " +
44 ::webrtc::test::OutputPath());
45
46 ABSL_FLAG(
47 int,
48 runtime_ms,
49 10000,
50 "Simulated runtime (milliseconds). -1 will consume the complete file.");
51
52 ABSL_FLAG(int, packet_loss_rate, 10, "Percentile of packet loss.");
53
54 ABSL_FLAG(int,
55 random_loss_mode,
56 ::webrtc::test::kUniformLoss,
57 "Random loss mode: 0--no loss, 1--uniform loss, 2--Gilbert Elliot "
58 "loss, 3--fixed loss.");
59
60 ABSL_FLAG(int,
61 burst_length,
62 30,
63 "Burst length in milliseconds, only valid for Gilbert Elliot loss.");
64
65 ABSL_FLAG(float, drift_factor, 0.0, "Time drift factor.");
66
67 ABSL_FLAG(int,
68 preload_packets,
69 1,
70 "Preload the buffer with this many packets.");
71
72 ABSL_FLAG(std::string,
73 loss_events,
74 "",
75 "List of loss events time and duration separated by comma: "
76 "<first_event_time> <first_event_duration>, <second_event_time> "
77 "<second_event_duration>, ...");
78
79 namespace webrtc {
80 namespace test {
81
82 namespace {
83
CreateNetEq(const NetEq::Config & config,Clock * clock,const rtc::scoped_refptr<AudioDecoderFactory> & decoder_factory)84 std::unique_ptr<NetEq> CreateNetEq(
85 const NetEq::Config& config,
86 Clock* clock,
87 const rtc::scoped_refptr<AudioDecoderFactory>& decoder_factory) {
88 return DefaultNetEqFactory().CreateNetEq(config, decoder_factory, clock);
89 }
90
GetInFilenamePath(absl::string_view file_name)91 const std::string& GetInFilenamePath(absl::string_view file_name) {
92 std::vector<absl::string_view> name_parts = rtc::split(file_name, '.');
93 RTC_CHECK_EQ(name_parts.size(), 2);
94 static const std::string path =
95 ::webrtc::test::ResourcePath(name_parts[0], name_parts[1]);
96 return path;
97 }
98
GetOutFilenamePath(absl::string_view file_name)99 const std::string& GetOutFilenamePath(absl::string_view file_name) {
100 static const std::string path =
101 ::webrtc::test::OutputPath() + std::string(file_name);
102 return path;
103 }
104
105 } // namespace
106
// RTP payload type registered with NetEq in SetUp() and used for every
// generated packet.
constexpr uint8_t kPayloadType = 95;

// Multiplied by the output rate in kHz to get the number of samples per
// channel that DecodeBlock() expects in each output frame.
constexpr int kOutputSizeMs = 10;

// Seed passed to srand() in SetUp().
constexpr int kInitSeed = 0x12345678;

// Time granularity of the loss models: one loss drawing is made per unit.
constexpr int kPacketLossTimeUnitMs = 10;
111
112 // Common validator for file names.
ValidateFilename(absl::string_view value,bool is_output)113 static bool ValidateFilename(absl::string_view value, bool is_output) {
114 if (!is_output) {
115 RTC_CHECK_NE(value.substr(value.find_last_of('.') + 1), "wav")
116 << "WAV file input is not supported";
117 }
118 FILE* fid = is_output ? fopen(std::string(value).c_str(), "wb")
119 : fopen(std::string(value).c_str(), "rb");
120 if (fid == nullptr)
121 return false;
122 fclose(fid);
123 return true;
124 }
125
// ProbTrans00Solver() calculates the transition probability from the no-loss
// state to itself in a modified Gilbert Elliot packet loss model, such that
// the target packet loss rate `loss_rate` is achieved when a packet survives
// only if all `units` drawings within the packet duration result in no-loss.
//
// `prob_trans_10` is the loss -> no-loss transition probability. The result is
// clamped to [0, 1] at every iteration.
static double ProbTrans00Solver(int units,
                                double loss_rate,
                                double prob_trans_10) {
  // With a single drawing per packet there is a closed-form answer.
  if (units == 1) {
    return prob_trans_10 / (1.0f - loss_rate) - prob_trans_10;
  }

  // Otherwise solve
  //   0 == prob_trans_00 ^ (units - 1) + (1 - loss_rate) / prob_trans_10 *
  //        prob_trans_00 - (1 - loss_rate) * (1 + 1 / prob_trans_10),
  // written as f(x) = x ^ (units - 1) + a x + b with derivative
  //   f'(x) = (units - 1) x ^ (units - 2) + a.
  // f is monotonic on [0, 1] with opposite signs at the end points, so the
  // root is unique there. Newton's method is used:
  //   x(k+1) = x(k) - f(x) / f'(x).
  const double kPrecision = 0.001f;
  const int kIterations = 100;
  const double a = (1.0f - loss_rate) / prob_trans_10;
  const double b = (loss_rate - 1.0f) * (1.0f + 1.0f / prob_trans_10);

  double x = 0.0;  // Starting point.
  double f = b;    // Value of f at the starting point.
  for (int iter = 0; iter < kIterations && std::fabs(f) >= kPrecision;
       ++iter) {
    const double f_p = (units - 1.0f) * std::pow(x, units - 2) + a;
    x -= f / f_p;
    // Keep the iterate inside the valid probability range.
    x = (x > 1.0f) ? 1.0 : ((x < 0.0f) ? 0.0 : x);
    f = std::pow(x, units - 1) + a * x + b;
  }
  return x;
}
168
NetEqQualityTest(int block_duration_ms,int in_sampling_khz,int out_sampling_khz,const SdpAudioFormat & format,const rtc::scoped_refptr<AudioDecoderFactory> & decoder_factory)169 NetEqQualityTest::NetEqQualityTest(
170 int block_duration_ms,
171 int in_sampling_khz,
172 int out_sampling_khz,
173 const SdpAudioFormat& format,
174 const rtc::scoped_refptr<AudioDecoderFactory>& decoder_factory)
175 : audio_format_(format),
176 channels_(absl::GetFlag(FLAGS_channels)),
177 decoded_time_ms_(0),
178 decodable_time_ms_(0),
179 drift_factor_(absl::GetFlag(FLAGS_drift_factor)),
180 packet_loss_rate_(absl::GetFlag(FLAGS_packet_loss_rate)),
181 block_duration_ms_(block_duration_ms),
182 in_sampling_khz_(in_sampling_khz),
183 out_sampling_khz_(out_sampling_khz),
184 in_size_samples_(
185 static_cast<size_t>(in_sampling_khz_ * block_duration_ms_)),
186 payload_size_bytes_(0),
187 max_payload_bytes_(0),
188 in_file_(new ResampleInputAudioFile(
189 GetInFilenamePath(absl::GetFlag(FLAGS_in_filename)),
190 absl::GetFlag(FLAGS_input_sample_rate),
191 in_sampling_khz * 1000,
192 absl::GetFlag(FLAGS_runtime_ms) > 0)),
193 rtp_generator_(
194 new RtpGenerator(in_sampling_khz_, 0, 0, decodable_time_ms_)),
195 total_payload_size_bytes_(0) {
196 // Flag validation
197 RTC_CHECK(ValidateFilename(
198 GetInFilenamePath(absl::GetFlag(FLAGS_in_filename)), false))
199 << "Invalid input filename.";
200
201 RTC_CHECK(absl::GetFlag(FLAGS_input_sample_rate) == 8000 ||
202 absl::GetFlag(FLAGS_input_sample_rate) == 16000 ||
203 absl::GetFlag(FLAGS_input_sample_rate) == 32000 ||
204 absl::GetFlag(FLAGS_input_sample_rate) == 48000)
205 << "Invalid sample rate should be 8000, 16000, 32000 or 48000 Hz.";
206
207 RTC_CHECK_EQ(absl::GetFlag(FLAGS_channels), 1)
208 << "Invalid number of channels, current support only 1.";
209
210 RTC_CHECK(ValidateFilename(
211 GetOutFilenamePath(absl::GetFlag(FLAGS_out_filename)), true))
212 << "Invalid output filename.";
213
214 RTC_CHECK(absl::GetFlag(FLAGS_packet_loss_rate) >= 0 &&
215 absl::GetFlag(FLAGS_packet_loss_rate) <= 100)
216 << "Invalid packet loss percentile, should be between 0 and 100.";
217
218 RTC_CHECK(absl::GetFlag(FLAGS_random_loss_mode) >= 0 &&
219 absl::GetFlag(FLAGS_random_loss_mode) < kLastLossMode)
220 << "Invalid random packet loss mode, should be between 0 and "
221 << kLastLossMode - 1 << ".";
222
223 RTC_CHECK_GE(absl::GetFlag(FLAGS_burst_length), kPacketLossTimeUnitMs)
224 << "Invalid burst length, should be greater than or equal to "
225 << kPacketLossTimeUnitMs << " ms.";
226
227 RTC_CHECK_GT(absl::GetFlag(FLAGS_drift_factor), -0.1)
228 << "Invalid drift factor, should be greater than -0.1.";
229
230 RTC_CHECK_GE(absl::GetFlag(FLAGS_preload_packets), 0)
231 << "Invalid number of packets to preload; must be non-negative.";
232
233 const std::string out_filename =
234 GetOutFilenamePath(absl::GetFlag(FLAGS_out_filename));
235 const std::string log_filename = out_filename + ".log";
236 log_file_.open(log_filename.c_str(), std::ofstream::out);
237 RTC_CHECK(log_file_.is_open());
238
239 if (out_filename.size() >= 4 &&
240 out_filename.substr(out_filename.size() - 4) == ".wav") {
241 // Open a wav file.
242 output_.reset(
243 new webrtc::test::OutputWavFile(out_filename, 1000 * out_sampling_khz));
244 } else {
245 // Open a pcm file.
246 output_.reset(new webrtc::test::OutputAudioFile(out_filename));
247 }
248
249 NetEq::Config config;
250 config.sample_rate_hz = out_sampling_khz_ * 1000;
251 neteq_ = CreateNetEq(config, Clock::GetRealTimeClock(), decoder_factory);
252 max_payload_bytes_ = in_size_samples_ * channels_ * sizeof(int16_t);
253 in_data_.reset(new int16_t[in_size_samples_ * channels_]);
254 }
255
~NetEqQualityTest()256 NetEqQualityTest::~NetEqQualityTest() {
257 log_file_.close();
258 }
259
Lost(int now_ms)260 bool NoLoss::Lost(int now_ms) {
261 return false;
262 }
263
UniformLoss(double loss_rate)264 UniformLoss::UniformLoss(double loss_rate) : loss_rate_(loss_rate) {}
265
Lost(int now_ms)266 bool UniformLoss::Lost(int now_ms) {
267 int drop_this = rand();
268 return (drop_this < loss_rate_ * RAND_MAX);
269 }
270
GilbertElliotLoss(double prob_trans_11,double prob_trans_01)271 GilbertElliotLoss::GilbertElliotLoss(double prob_trans_11, double prob_trans_01)
272 : prob_trans_11_(prob_trans_11),
273 prob_trans_01_(prob_trans_01),
274 lost_last_(false),
275 uniform_loss_model_(new UniformLoss(0)) {}
276
~GilbertElliotLoss()277 GilbertElliotLoss::~GilbertElliotLoss() {}
278
Lost(int now_ms)279 bool GilbertElliotLoss::Lost(int now_ms) {
280 // Simulate bursty channel (Gilbert model).
281 // (1st order) Markov chain model with memory of the previous/last
282 // packet state (lost or received).
283 if (lost_last_) {
284 // Previous packet was not received.
285 uniform_loss_model_->set_loss_rate(prob_trans_11_);
286 return lost_last_ = uniform_loss_model_->Lost(now_ms);
287 } else {
288 uniform_loss_model_->set_loss_rate(prob_trans_01_);
289 return lost_last_ = uniform_loss_model_->Lost(now_ms);
290 }
291 }
292
FixedLossModel(std::set<FixedLossEvent,FixedLossEventCmp> loss_events)293 FixedLossModel::FixedLossModel(
294 std::set<FixedLossEvent, FixedLossEventCmp> loss_events)
295 : loss_events_(loss_events) {
296 loss_events_it_ = loss_events_.begin();
297 }
298
~FixedLossModel()299 FixedLossModel::~FixedLossModel() {}
300
Lost(int now_ms)301 bool FixedLossModel::Lost(int now_ms) {
302 if (loss_events_it_ != loss_events_.end() &&
303 now_ms > loss_events_it_->start_ms) {
304 if (now_ms <= loss_events_it_->start_ms + loss_events_it_->duration_ms) {
305 return true;
306 } else {
307 ++loss_events_it_;
308 return false;
309 }
310 }
311 return false;
312 }
313
SetUp()314 void NetEqQualityTest::SetUp() {
315 ASSERT_TRUE(neteq_->RegisterPayloadType(kPayloadType, audio_format_));
316 rtp_generator_->set_drift_factor(drift_factor_);
317
318 int units = block_duration_ms_ / kPacketLossTimeUnitMs;
319 switch (absl::GetFlag(FLAGS_random_loss_mode)) {
320 case kUniformLoss: {
321 // `unit_loss_rate` is the packet loss rate for each unit time interval
322 // (kPacketLossTimeUnitMs). Since a packet loss event is generated if any
323 // of |block_duration_ms_ / kPacketLossTimeUnitMs| unit time intervals of
324 // a full packet duration is drawn with a loss, `unit_loss_rate` fulfills
325 // (1 - unit_loss_rate) ^ (block_duration_ms_ / kPacketLossTimeUnitMs) ==
326 // 1 - packet_loss_rate.
327 double unit_loss_rate =
328 (1.0 - std::pow(1.0 - 0.01 * packet_loss_rate_, 1.0 / units));
329 loss_model_.reset(new UniformLoss(unit_loss_rate));
330 break;
331 }
332 case kGilbertElliotLoss: {
333 // `FLAGS_burst_length` should be integer times of kPacketLossTimeUnitMs.
334 ASSERT_EQ(0, absl::GetFlag(FLAGS_burst_length) % kPacketLossTimeUnitMs);
335
336 // We do not allow 100 percent packet loss in Gilbert Elliot model, which
337 // makes no sense.
338 ASSERT_GT(100, packet_loss_rate_);
339
340 // To guarantee the overall packet loss rate, transition probabilities
341 // need to satisfy:
342 // pi_0 * (1 - prob_trans_01_) ^ units +
343 // pi_1 * prob_trans_10_ ^ (units - 1) == 1 - loss_rate
344 // pi_0 = prob_trans_10 / (prob_trans_10 + prob_trans_01_)
345 // is the stationary state probability of no-loss
346 // pi_1 = prob_trans_01_ / (prob_trans_10 + prob_trans_01_)
347 // is the stationary state probability of loss
348 // After a derivation prob_trans_00 should satisfy:
349 // prob_trans_00 ^ (units - 1) = (loss_rate - 1) / prob_trans_10 *
350 // prob_trans_00 + (1 - loss_rate) * (1 + 1 / prob_trans_10).
351 double loss_rate = 0.01f * packet_loss_rate_;
352 double prob_trans_10 =
353 1.0f * kPacketLossTimeUnitMs / absl::GetFlag(FLAGS_burst_length);
354 double prob_trans_00 = ProbTrans00Solver(units, loss_rate, prob_trans_10);
355 loss_model_.reset(
356 new GilbertElliotLoss(1.0f - prob_trans_10, 1.0f - prob_trans_00));
357 break;
358 }
359 case kFixedLoss: {
360 std::istringstream loss_events_stream(absl::GetFlag(FLAGS_loss_events));
361 std::string loss_event_string;
362 std::set<FixedLossEvent, FixedLossEventCmp> loss_events;
363 while (std::getline(loss_events_stream, loss_event_string, ',')) {
364 std::vector<int> loss_event_params;
365 std::istringstream loss_event_params_stream(loss_event_string);
366 std::copy(std::istream_iterator<int>(loss_event_params_stream),
367 std::istream_iterator<int>(),
368 std::back_inserter(loss_event_params));
369 RTC_CHECK_EQ(loss_event_params.size(), 2);
370 auto result = loss_events.insert(
371 FixedLossEvent(loss_event_params[0], loss_event_params[1]));
372 RTC_CHECK(result.second);
373 }
374 RTC_CHECK_GT(loss_events.size(), 0);
375 loss_model_.reset(new FixedLossModel(loss_events));
376 break;
377 }
378 default: {
379 loss_model_.reset(new NoLoss);
380 break;
381 }
382 }
383
384 // Make sure that the packet loss profile is same for all derived tests.
385 srand(kInitSeed);
386 }
387
Log()388 std::ofstream& NetEqQualityTest::Log() {
389 return log_file_;
390 }
391
PacketLost()392 bool NetEqQualityTest::PacketLost() {
393 int cycles = block_duration_ms_ / kPacketLossTimeUnitMs;
394
395 // The loop is to make sure that codecs with different block lengths share the
396 // same packet loss profile.
397 bool lost = false;
398 for (int idx = 0; idx < cycles; idx++) {
399 if (loss_model_->Lost(decoded_time_ms_)) {
400 // The packet will be lost if any of the drawings indicates a loss, but
401 // the loop has to go on to make sure that codecs with different block
402 // lengths keep the same pace.
403 lost = true;
404 }
405 }
406 return lost;
407 }
408
Transmit()409 int NetEqQualityTest::Transmit() {
410 int packet_input_time_ms = rtp_generator_->GetRtpHeader(
411 kPayloadType, in_size_samples_, &rtp_header_);
412 Log() << "Packet of size " << payload_size_bytes_ << " bytes, for frame at "
413 << packet_input_time_ms << " ms ";
414 if (payload_size_bytes_ > 0) {
415 if (!PacketLost()) {
416 int ret = neteq_->InsertPacket(
417 rtp_header_,
418 rtc::ArrayView<const uint8_t>(payload_.data(), payload_size_bytes_));
419 if (ret != NetEq::kOK)
420 return -1;
421 Log() << "was sent.";
422 } else {
423 Log() << "was lost.";
424 }
425 }
426 Log() << std::endl;
427 return packet_input_time_ms;
428 }
429
DecodeBlock()430 int NetEqQualityTest::DecodeBlock() {
431 bool muted;
432 int ret = neteq_->GetAudio(&out_frame_, &muted);
433 RTC_CHECK(!muted);
434
435 if (ret != NetEq::kOK) {
436 return -1;
437 } else {
438 RTC_DCHECK_EQ(out_frame_.num_channels_, channels_);
439 RTC_DCHECK_EQ(out_frame_.samples_per_channel_,
440 static_cast<size_t>(kOutputSizeMs * out_sampling_khz_));
441 RTC_CHECK(output_->WriteArray(
442 out_frame_.data(),
443 out_frame_.samples_per_channel_ * out_frame_.num_channels_));
444 return static_cast<int>(out_frame_.samples_per_channel_);
445 }
446 }
447
Simulate()448 void NetEqQualityTest::Simulate() {
449 int audio_size_samples;
450 bool end_of_input = false;
451 int runtime_ms = absl::GetFlag(FLAGS_runtime_ms) >= 0
452 ? absl::GetFlag(FLAGS_runtime_ms)
453 : INT_MAX;
454
455 while (!end_of_input && decoded_time_ms_ < runtime_ms) {
456 // Preload the buffer if needed.
457 while (decodable_time_ms_ -
458 absl::GetFlag(FLAGS_preload_packets) * block_duration_ms_ <
459 decoded_time_ms_) {
460 if (!in_file_->Read(in_size_samples_ * channels_, &in_data_[0])) {
461 end_of_input = true;
462 ASSERT_TRUE(end_of_input && absl::GetFlag(FLAGS_runtime_ms) < 0);
463 break;
464 }
465 payload_.Clear();
466 payload_size_bytes_ = EncodeBlock(&in_data_[0], in_size_samples_,
467 &payload_, max_payload_bytes_);
468 total_payload_size_bytes_ += payload_size_bytes_;
469 decodable_time_ms_ = Transmit() + block_duration_ms_;
470 }
471 audio_size_samples = DecodeBlock();
472 if (audio_size_samples > 0) {
473 decoded_time_ms_ += audio_size_samples / out_sampling_khz_;
474 }
475 }
476 Log() << "Average bit rate was "
477 << 8.0f * total_payload_size_bytes_ / absl::GetFlag(FLAGS_runtime_ms)
478 << " kbps" << std::endl;
479 }
480
481 } // namespace test
482 } // namespace webrtc
483