xref: /aosp_15_r20/external/webrtc/modules/audio_processing/agc2/rnn_vad/test_utils.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
13 
#include <array>
#include <fstream>
#include <limits>
#include <memory>
#include <string>

#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
24 
25 namespace webrtc {
26 namespace rnn_vad {
27 
// Smallest positive normalized `float` value. Note that for floating point
// types `std::numeric_limits<T>::min()` is NOT the most negative value (that is
// `lowest()`). Requires <limits>.
constexpr float kFloatMin = std::numeric_limits<float>::min();
29 
30 // Fails for every pair from two equally sized rtc::ArrayView<float> views such
31 // that the values in the pair do not match.
32 void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
33                            rtc::ArrayView<const float> computed);
34 
35 // Fails for every pair from two equally sized rtc::ArrayView<float> views such
36 // that their absolute error is above a given threshold.
37 void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
38                         rtc::ArrayView<const float> computed,
39                         float tolerance);
40 
41 // File reader interface.
42 class FileReader {
43  public:
44   virtual ~FileReader() = default;
45   // Number of values in the file.
46   virtual int size() const = 0;
47   // Reads `dst.size()` float values into `dst`, advances the internal file
48   // position according to the number of read bytes and returns true if the
49   // values are correctly read. If the number of remaining bytes in the file is
50   // not sufficient to read `dst.size()` float values, `dst` is partially
51   // modified and false is returned.
52   virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
53   // Reads a single float value, advances the internal file position according
54   // to the number of read bytes and returns true if the value is correctly
55   // read. If the number of remaining bytes in the file is not sufficient to
56   // read one float, `dst` is not modified and false is returned.
57   virtual bool ReadValue(float& dst) = 0;
58   // Advances the internal file position by `hop` float values.
59   virtual void SeekForward(int hop) = 0;
60   // Resets the internal file position to BOF.
61   virtual void SeekBeginning() = 0;
62 };
63 
64 // File reader for files that contain `num_chunks` chunks with size equal to
65 // `chunk_size`.
66 struct ChunksFileReader {
67   const int chunk_size;
68   const int num_chunks;
69   std::unique_ptr<FileReader> reader;
70 };
71 
72 // Creates a reader for the PCM S16 samples file.
73 std::unique_ptr<FileReader> CreatePcmSamplesReader();
74 
75 // Creates a reader for the 24 kHz pitch buffer test data.
76 ChunksFileReader CreatePitchBuffer24kHzReader();
77 
78 // Creates a reader for the LP residual and pitch information test data.
79 ChunksFileReader CreateLpResidualAndPitchInfoReader();
80 
81 // Creates a reader for the sequence of GRU input vectors.
82 std::unique_ptr<FileReader> CreateGruInputReader();
83 
84 // Creates a reader for the VAD probabilities test data.
85 std::unique_ptr<FileReader> CreateVadProbsReader();
86 
87 // Class to retrieve a test pitch buffer content and the expected output for the
88 // analysis steps.
89 class PitchTestData {
90  public:
91   PitchTestData();
92   ~PitchTestData();
PitchBuffer24kHzView()93   rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
94     return pitch_buffer_24k_;
95   }
SquareEnergies24kHzView()96   rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
97       const {
98     return square_energies_24k_;
99   }
AutoCorrelation12kHzView()100   rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
101     return auto_correlation_12k_;
102   }
103 
104  private:
105   std::array<float, kBufSize24kHz> pitch_buffer_24k_;
106   std::array<float, kRefineNumLags24kHz> square_energies_24k_;
107   std::array<float, kNumLags12kHz> auto_correlation_12k_;
108 };
109 
110 // Writer for binary files.
111 class FileWriter {
112  public:
FileWriter(absl::string_view file_path)113   explicit FileWriter(absl::string_view file_path)
114       : os_(std::string(file_path), std::ios::binary) {}
115   FileWriter(const FileWriter&) = delete;
116   FileWriter& operator=(const FileWriter&) = delete;
117   ~FileWriter() = default;
WriteChunk(rtc::ArrayView<const float> value)118   void WriteChunk(rtc::ArrayView<const float> value) {
119     const std::streamsize bytes_to_write = value.size() * sizeof(float);
120     os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
121   }
122 
123  private:
124   std::ofstream os_;
125 };
126 
127 }  // namespace rnn_vad
128 }  // namespace webrtc
129 
130 #endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
131