1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // #define LOG_NDEBUG 0
18 #define LOG_TAG "audio_utils_mel_aggregator_tests"
19
20 #include <audio_utils/MelAggregator.h>
21
22 #include <gtest/gtest.h>
23 #include <gmock/gmock.h>
24
25 namespace android::audio_utils {
26 namespace {
27
28 constexpr int32_t kTestPortId = 1;
29 constexpr float kFloatError = 0.1f;
30 constexpr float kMelFloatError = 0.0001f;
31
32 /** Value used for CSD calculation. 3 MELs with this value will cause a change of 1% in CSD. */
33 constexpr float kCustomMelDbA = 107.f;
34
35 using ::testing::ElementsAre;
36 using ::testing::Pointwise;
37 using ::testing::FloatNear;
38
TEST(MelAggregatorTest,ResetAggregator)39 TEST(MelAggregatorTest, ResetAggregator) {
40 MelAggregator aggregator{100};
41
42 aggregator.aggregateAndAddNewMelRecord(MelRecord(1, {10.f, 10.f}, 0));
43 aggregator.reset(1.f, {CsdRecord(1, 1, 1.f, 1.f)});
44
45 EXPECT_EQ(aggregator.getCachedMelRecordsSize(), size_t{0});
46 EXPECT_EQ(aggregator.getCsd(), 1.f);
47 EXPECT_EQ(aggregator.getCsdRecordsSize(), size_t{1});
48 }
49
TEST(MelAggregatorTest,AggregateValuesFromDifferentStreams)50 TEST(MelAggregatorTest, AggregateValuesFromDifferentStreams) {
51 MelAggregator aggregator{/* csdWindowSeconds */ 100};
52
53 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {10.f, 10.f},
54 /* timestamp */0));
55 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {10.f, 10.f},
56 /* timestamp */0));
57
58 ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{1});
59 aggregator.foreachCachedMel([](const MelRecord &record) {
60 EXPECT_EQ(record.portId, kTestPortId);
61 EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {13.f, 13.f}));
62 });
63 }
64
TEST(MelAggregatorTest,AggregateWithOlderValues)65 TEST(MelAggregatorTest, AggregateWithOlderValues) {
66 MelAggregator aggregator{/* csdWindowSeconds */ 100};
67
68 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
69 /* timestamp */1));
70 // second mel array contains values that are older than the first entry
71 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {2.f, 2.f, 2.f},
72 /* timestamp */0));
73
74 ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{1});
75 aggregator.foreachCachedMel([](const MelRecord &record) {
76 EXPECT_EQ(record.portId, kTestPortId);
77 EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {2.f, 4.5f, 4.5f}));
78 });
79 }
80
TEST(MelAggregatorTest,AggregateWithNewerValues)81 TEST(MelAggregatorTest, AggregateWithNewerValues) {
82 MelAggregator aggregator{/* csdWindowSeconds */ 100};
83
84 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
85 /* timestamp */1));
86 // second mel array contains values that are older than the first entry
87 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {2.f, 2.f},
88 /* timestamp */2));
89
90 ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{1});
91 aggregator.foreachCachedMel([](const MelRecord &record) {
92 EXPECT_EQ(record.portId, kTestPortId);
93 EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {1.f, 4.5f, 2.f}));
94 });
95 }
96
TEST(MelAggregatorTest,AggregateWithNonOverlappingValues)97 TEST(MelAggregatorTest, AggregateWithNonOverlappingValues) {
98 MelAggregator aggregator{/* csdWindowSeconds */ 100};
99
100 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
101 /* timestamp */0));
102 // second mel array contains values that are older than the first entry
103 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
104 /* timestamp */2));
105
106 ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{2});
107 aggregator.foreachCachedMel([](const MelRecord &record) {
108 EXPECT_EQ(record.portId, kTestPortId);
109 EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {1.f, 1.f}));
110 });
111 }
112
TEST(MelAggregatorTest,CheckMelIntervalSplit)113 TEST(MelAggregatorTest, CheckMelIntervalSplit) {
114 MelAggregator aggregator{/* csdWindowSeconds */ 100};
115
116 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {3.f, 3.f}, /* timestamp */1));
117 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {3.f, 3.f, 3.f, 3.f},
118 /* timestamp */0));
119
120 ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{1});
121
122 aggregator.foreachCachedMel([](const MelRecord &record) {
123 EXPECT_EQ(record.portId, kTestPortId);
124 EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {3.f, 6.f, 6.f, 3.f}));
125 });
126 }
127
TEST(MelAggregatorTest,CsdRollingWindowDiscardsOldElements)128 TEST(MelAggregatorTest, CsdRollingWindowDiscardsOldElements) {
129 MelAggregator aggregator{/* csdWindowSeconds */ 3};
130
131 aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId,
132 std::vector<float>(3, kCustomMelDbA),
133 /* timestamp */0));
134 float csdValue = aggregator.getCsd();
135 auto records = aggregator.aggregateAndAddNewMelRecord(
136 MelRecord(kTestPortId, std::vector<float>(3, kCustomMelDbA), /* timestamp */3));
137
138 EXPECT_EQ(records.size(), size_t{2}); // new record and record to remove
139 EXPECT_TRUE(records[0].value * records[1].value < 0.f);
140 EXPECT_EQ(csdValue, aggregator.getCsd());
141 EXPECT_EQ(aggregator.getCsdRecordsSize(), size_t{1});
142 }
143
TEST(MelAggregatorTest,CsdReaches100PercWith107dB)144 TEST(MelAggregatorTest, CsdReaches100PercWith107dB) {
145 MelAggregator aggregator{/* csdWindowSeconds */ 300};
146
147 // 287s of 107dB should produce at least 100% CSD
148 auto records = aggregator.aggregateAndAddNewMelRecord(
149 MelRecord(kTestPortId, std::vector<float>(288, kCustomMelDbA), /* timestamp */0));
150
151 // each record should have a CSD value between 1% and 2%
152 EXPECT_GE(records.size(), size_t{50});
153 EXPECT_GE(aggregator.getCsd(), 1.f);
154 }
155
TEST(MelAggregatorTest,CsdReaches100PercWith80dB)156 TEST(MelAggregatorTest, CsdReaches100PercWith80dB) {
157 constexpr int64_t seconds40h = 40*3600;
158 MelAggregator aggregator{seconds40h};
159
160 // 40h of 80dB should produce (near) exactly 100% CSD
161 auto records = aggregator.aggregateAndAddNewMelRecord(
162 MelRecord(kTestPortId,
163 std::vector<float>(seconds40h, 80.0f),
164 /* timestamp */0));
165
166 // each record should have a CSD value between 1% and 2%
167 EXPECT_GE(records.size(), size_t{50});
168 EXPECT_NEAR(aggregator.getCsd(), 1.f, kMelFloatError);
169 }
170
171 } // namespace
172 } // namespace android
173