1 /*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
12
13 #include <algorithm>
14 #include <fstream>
15 #include <memory>
16 #include <string>
17 #include <type_traits>
18 #include <vector>
19
20 #include "absl/strings/string_view.h"
21 #include "rtc_base/checks.h"
22 #include "rtc_base/numerics/safe_compare.h"
23 #include "test/gtest.h"
24 #include "test/testsupport/file_utils.h"
25
26 namespace webrtc {
27 namespace rnn_vad {
28 namespace {
29
30 // File reader for binary files that contain a sequence of values with
31 // arithmetic type `T`. The values of type `T` that are read are cast to float.
32 template <typename T>
33 class FloatFileReader : public FileReader {
34 public:
35 static_assert(std::is_arithmetic<T>::value, "");
FloatFileReader(absl::string_view filename)36 explicit FloatFileReader(absl::string_view filename)
37 : is_(std::string(filename), std::ios::binary | std::ios::ate),
38 size_(is_.tellg() / sizeof(T)) {
39 RTC_CHECK(is_);
40 SeekBeginning();
41 }
42 FloatFileReader(const FloatFileReader&) = delete;
43 FloatFileReader& operator=(const FloatFileReader&) = delete;
44 ~FloatFileReader() = default;
45
size() const46 int size() const override { return size_; }
ReadChunk(rtc::ArrayView<float> dst)47 bool ReadChunk(rtc::ArrayView<float> dst) override {
48 const std::streamsize bytes_to_read = dst.size() * sizeof(T);
49 if (std::is_same<T, float>::value) {
50 is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
51 } else {
52 buffer_.resize(dst.size());
53 is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
54 std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
55 [](const T& v) -> float { return static_cast<float>(v); });
56 }
57 return is_.gcount() == bytes_to_read;
58 }
ReadValue(float & dst)59 bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
SeekForward(int hop)60 void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
SeekBeginning()61 void SeekBeginning() override { is_.seekg(0, is_.beg); }
62
63 private:
64 std::ifstream is_;
65 const int size_;
66 std::vector<T> buffer_;
67 };
68
69 } // namespace
70
71 using webrtc::test::ResourcePath;
72
ExpectEqualFloatArray(rtc::ArrayView<const float> expected,rtc::ArrayView<const float> computed)73 void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
74 rtc::ArrayView<const float> computed) {
75 ASSERT_EQ(expected.size(), computed.size());
76 for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
77 SCOPED_TRACE(i);
78 EXPECT_FLOAT_EQ(expected[i], computed[i]);
79 }
80 }
81
ExpectNearAbsolute(rtc::ArrayView<const float> expected,rtc::ArrayView<const float> computed,float tolerance)82 void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
83 rtc::ArrayView<const float> computed,
84 float tolerance) {
85 ASSERT_EQ(expected.size(), computed.size());
86 for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
87 SCOPED_TRACE(i);
88 EXPECT_NEAR(expected[i], computed[i], tolerance);
89 }
90 }
91
CreatePcmSamplesReader()92 std::unique_ptr<FileReader> CreatePcmSamplesReader() {
93 return std::make_unique<FloatFileReader<int16_t>>(
94 /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
95 "pcm"));
96 }
97
CreatePitchBuffer24kHzReader()98 ChunksFileReader CreatePitchBuffer24kHzReader() {
99 auto reader = std::make_unique<FloatFileReader<float>>(
100 /*filename=*/test::ResourcePath(
101 "audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
102 const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
103 return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
104 }
105
CreateLpResidualAndPitchInfoReader()106 ChunksFileReader CreateLpResidualAndPitchInfoReader() {
107 constexpr int kPitchInfoSize = 2; // Pitch period and strength.
108 constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
109 auto reader = std::make_unique<FloatFileReader<float>>(
110 /*filename=*/test::ResourcePath(
111 "audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
112 const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
113 return {kChunkSize, num_chunks, std::move(reader)};
114 }
115
CreateGruInputReader()116 std::unique_ptr<FileReader> CreateGruInputReader() {
117 return std::make_unique<FloatFileReader<float>>(
118 /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
119 "dat"));
120 }
121
CreateVadProbsReader()122 std::unique_ptr<FileReader> CreateVadProbsReader() {
123 return std::make_unique<FloatFileReader<float>>(
124 /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
125 "dat"));
126 }
127
PitchTestData()128 PitchTestData::PitchTestData() {
129 FloatFileReader<float> reader(
130 /*filename=*/ResourcePath(
131 "audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
132 reader.ReadChunk(pitch_buffer_24k_);
133 reader.ReadChunk(square_energies_24k_);
134 reader.ReadChunk(auto_correlation_12k_);
135 // Reverse the order of the squared energy values.
136 // Required after the WebRTC CL 191703 which switched to forward computation.
137 std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
138 }
139
140 PitchTestData::~PitchTestData() = default;
141
142 } // namespace rnn_vad
143 } // namespace webrtc
144