xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/test_utils.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
12 
13 #include <algorithm>
14 #include <fstream>
15 #include <memory>
16 #include <string>
17 #include <type_traits>
18 #include <vector>
19 
20 #include "absl/strings/string_view.h"
21 #include "rtc_base/checks.h"
22 #include "rtc_base/numerics/safe_compare.h"
23 #include "test/gtest.h"
24 #include "test/testsupport/file_utils.h"
25 
26 namespace webrtc {
27 namespace rnn_vad {
28 namespace {
29 
30 // File reader for binary files that contain a sequence of values with
31 // arithmetic type `T`. The values of type `T` that are read are cast to float.
32 template <typename T>
33 class FloatFileReader : public FileReader {
34  public:
35   static_assert(std::is_arithmetic<T>::value, "");
FloatFileReader(absl::string_view filename)36   explicit FloatFileReader(absl::string_view filename)
37       : is_(std::string(filename), std::ios::binary | std::ios::ate),
38         size_(is_.tellg() / sizeof(T)) {
39     RTC_CHECK(is_);
40     SeekBeginning();
41   }
42   FloatFileReader(const FloatFileReader&) = delete;
43   FloatFileReader& operator=(const FloatFileReader&) = delete;
44   ~FloatFileReader() = default;
45 
size() const46   int size() const override { return size_; }
ReadChunk(rtc::ArrayView<float> dst)47   bool ReadChunk(rtc::ArrayView<float> dst) override {
48     const std::streamsize bytes_to_read = dst.size() * sizeof(T);
49     if (std::is_same<T, float>::value) {
50       is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
51     } else {
52       buffer_.resize(dst.size());
53       is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
54       std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
55                      [](const T& v) -> float { return static_cast<float>(v); });
56     }
57     return is_.gcount() == bytes_to_read;
58   }
ReadValue(float & dst)59   bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
SeekForward(int hop)60   void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
SeekBeginning()61   void SeekBeginning() override { is_.seekg(0, is_.beg); }
62 
63  private:
64   std::ifstream is_;
65   const int size_;
66   std::vector<T> buffer_;
67 };
68 
69 }  // namespace
70 
71 using webrtc::test::ResourcePath;
72 
ExpectEqualFloatArray(rtc::ArrayView<const float> expected,rtc::ArrayView<const float> computed)73 void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
74                            rtc::ArrayView<const float> computed) {
75   ASSERT_EQ(expected.size(), computed.size());
76   for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
77     SCOPED_TRACE(i);
78     EXPECT_FLOAT_EQ(expected[i], computed[i]);
79   }
80 }
81 
ExpectNearAbsolute(rtc::ArrayView<const float> expected,rtc::ArrayView<const float> computed,float tolerance)82 void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
83                         rtc::ArrayView<const float> computed,
84                         float tolerance) {
85   ASSERT_EQ(expected.size(), computed.size());
86   for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
87     SCOPED_TRACE(i);
88     EXPECT_NEAR(expected[i], computed[i], tolerance);
89   }
90 }
91 
CreatePcmSamplesReader()92 std::unique_ptr<FileReader> CreatePcmSamplesReader() {
93   return std::make_unique<FloatFileReader<int16_t>>(
94       /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
95                                       "pcm"));
96 }
97 
CreatePitchBuffer24kHzReader()98 ChunksFileReader CreatePitchBuffer24kHzReader() {
99   auto reader = std::make_unique<FloatFileReader<float>>(
100       /*filename=*/test::ResourcePath(
101           "audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
102   const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
103   return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
104 }
105 
CreateLpResidualAndPitchInfoReader()106 ChunksFileReader CreateLpResidualAndPitchInfoReader() {
107   constexpr int kPitchInfoSize = 2;  // Pitch period and strength.
108   constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
109   auto reader = std::make_unique<FloatFileReader<float>>(
110       /*filename=*/test::ResourcePath(
111           "audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
112   const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
113   return {kChunkSize, num_chunks, std::move(reader)};
114 }
115 
CreateGruInputReader()116 std::unique_ptr<FileReader> CreateGruInputReader() {
117   return std::make_unique<FloatFileReader<float>>(
118       /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
119                                       "dat"));
120 }
121 
CreateVadProbsReader()122 std::unique_ptr<FileReader> CreateVadProbsReader() {
123   return std::make_unique<FloatFileReader<float>>(
124       /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
125                                       "dat"));
126 }
127 
PitchTestData()128 PitchTestData::PitchTestData() {
129   FloatFileReader<float> reader(
130       /*filename=*/ResourcePath(
131           "audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
132   reader.ReadChunk(pitch_buffer_24k_);
133   reader.ReadChunk(square_energies_24k_);
134   reader.ReadChunk(auto_correlation_12k_);
135   // Reverse the order of the squared energy values.
136   // Required after the WebRTC CL 191703 which switched to forward computation.
137   std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
138 }
139 
140 PitchTestData::~PitchTestData() = default;
141 
142 }  // namespace rnn_vad
143 }  // namespace webrtc
144