xref: /aosp_15_r20/system/media/audio_utils/tests/mel_aggregator_tests.cpp (revision b9df5ad1c9ac98a7fefaac271a55f7ae3db05414)
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // #define LOG_NDEBUG 0
18 #define LOG_TAG "audio_utils_mel_aggregator_tests"
19 
20 #include <audio_utils/MelAggregator.h>
21 
22 #include <gtest/gtest.h>
23 #include <gmock/gmock.h>
24 
25 namespace android::audio_utils {
26 namespace {
27 
28 constexpr int32_t kTestPortId = 1;
29 constexpr float kFloatError = 0.1f;
30 constexpr float kMelFloatError = 0.0001f;
31 
32 /** Value used for CSD calculation. 3 MELs with this value will cause a change of 1% in CSD. */
33 constexpr float kCustomMelDbA = 107.f;
34 
35 using ::testing::ElementsAre;
36 using ::testing::Pointwise;
37 using ::testing::FloatNear;
38 
TEST(MelAggregatorTest, ResetAggregator) {
    MelAggregator melAggregator{/* csdWindowSeconds */ 100};

    // Populate the MEL cache so that reset() has something to discard.
    melAggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {10.f, 10.f},
                                                        /* timestamp */ 0));

    // Reset to a CSD of 1.f backed by exactly one CSD record.
    melAggregator.reset(1.f, {CsdRecord(1, 1, 1.f, 1.f)});

    // Cached MELs are gone while the CSD state passed to reset() is adopted.
    EXPECT_EQ(melAggregator.getCachedMelRecordsSize(), size_t{0});
    EXPECT_EQ(melAggregator.getCsd(), 1.f);
    EXPECT_EQ(melAggregator.getCsdRecordsSize(), size_t{1});
}
49 
TEST(MelAggregatorTest, AggregateValuesFromDifferentStreams) {
    MelAggregator melAggregator{/* csdWindowSeconds */ 100};

    // Two streams on the same port report identical MELs for the same time interval.
    for (int streamIndex = 0; streamIndex < 2; ++streamIndex) {
        melAggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {10.f, 10.f},
                                                            /* timestamp */ 0));
    }

    // Fully overlapping records merge into one entry; combining two equal levels adds ~3 dB.
    ASSERT_EQ(melAggregator.getCachedMelRecordsSize(), size_t{1});
    melAggregator.foreachCachedMel([](const MelRecord& cachedRecord) {
        EXPECT_EQ(cachedRecord.portId, kTestPortId);
        EXPECT_THAT(cachedRecord.mels, Pointwise(FloatNear(kFloatError), {13.f, 13.f}));
    });
}
64 
TEST(MelAggregatorTest, AggregateWithOlderValues) {
    MelAggregator melAggregator{/* csdWindowSeconds */ 100};

    // First record: one MEL per second for seconds [1, 2].
    melAggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
                                                        /* timestamp */ 1));
    // Second record starts earlier (second 0) and spans [0, 2], covering the first record.
    melAggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {2.f, 2.f, 2.f},
                                                        /* timestamp */ 0));

    // A single merged entry: second 0 keeps 2 dB, seconds 1 and 2 combine 1 dB and 2 dB.
    ASSERT_EQ(melAggregator.getCachedMelRecordsSize(), size_t{1});
    melAggregator.foreachCachedMel([](const MelRecord& cachedRecord) {
        EXPECT_EQ(cachedRecord.portId, kTestPortId);
        EXPECT_THAT(cachedRecord.mels, Pointwise(FloatNear(kFloatError), {2.f, 4.5f, 4.5f}));
    });
}
80 
TEST(MelAggregatorTest, AggregateWithNewerValues) {
    MelAggregator aggregator{/* csdWindowSeconds */ 100};

    // First record covers seconds [1, 2].
    aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
                                                     /* timestamp */ 1));
    // Second record is NEWER than the first one: it covers seconds [2, 3] and overlaps the
    // first record only at second 2.
    aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {2.f, 2.f},
                                                     /* timestamp */ 2));

    // Both records merge into one entry spanning [1, 3]. Only second 2 combines the two levels
    // (1 dB + 2 dB ~= 4.5 dB); the non-overlapping seconds keep their original values.
    ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{1});
    aggregator.foreachCachedMel([](const MelRecord& record) {
        EXPECT_EQ(record.portId, kTestPortId);
        EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {1.f, 4.5f, 2.f}));
    });
}
96 
TEST(MelAggregatorTest, AggregateWithNonOverlappingValues) {
    MelAggregator aggregator{/* csdWindowSeconds */ 100};

    // First record covers seconds [0, 1].
    aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
                                                     /* timestamp */ 0));
    // Second record covers seconds [2, 3]: adjacent to, but not overlapping, the first one.
    aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {1.f, 1.f},
                                                     /* timestamp */ 2));

    // Without overlap nothing is combined: both records stay cached with unchanged values.
    ASSERT_EQ(aggregator.getCachedMelRecordsSize(), size_t{2});
    aggregator.foreachCachedMel([](const MelRecord& record) {
        EXPECT_EQ(record.portId, kTestPortId);
        EXPECT_THAT(record.mels, Pointwise(FloatNear(kFloatError), {1.f, 1.f}));
    });
}
112 
TEST(MelAggregatorTest, CheckMelIntervalSplit) {
    MelAggregator melAggregator{/* csdWindowSeconds */ 100};

    // Short record for seconds [1, 2] ...
    melAggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {3.f, 3.f},
                                                        /* timestamp */ 1));
    // ... followed by a longer record for seconds [0, 3] that surrounds it on both sides.
    melAggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId, {3.f, 3.f, 3.f, 3.f},
                                                        /* timestamp */ 0));

    // The result is one entry: the outer seconds keep 3 dB, the shared middle doubles to ~6 dB.
    ASSERT_EQ(melAggregator.getCachedMelRecordsSize(), size_t{1});

    melAggregator.foreachCachedMel([](const MelRecord& cachedRecord) {
        EXPECT_EQ(cachedRecord.portId, kTestPortId);
        EXPECT_THAT(cachedRecord.mels,
                    Pointwise(FloatNear(kFloatError), {3.f, 6.f, 6.f, 3.f}));
    });
}
127 
TEST(MelAggregatorTest, CsdRollingWindowDiscardsOldElements) {
    MelAggregator aggregator{/* csdWindowSeconds */ 3};

    // Fill the whole 3 second CSD window with MELs for seconds [0, 2].
    aggregator.aggregateAndAddNewMelRecord(MelRecord(kTestPortId,
                                                     std::vector<float>(3, kCustomMelDbA),
                                                     /* timestamp */ 0));
    const float csdValue = aggregator.getCsd();

    // Identical MELs for seconds [3, 5] push the first record out of the rolling window.
    const auto records = aggregator.aggregateAndAddNewMelRecord(
            MelRecord(kTestPortId, std::vector<float>(3, kCustomMelDbA), /* timestamp */ 3));

    // Expect the new CSD record plus the record removing the expired contribution. This must be
    // an ASSERT: the checks below index records[0] and records[1], and continuing after a size
    // mismatch would read out of bounds.
    ASSERT_EQ(records.size(), size_t{2});
    // The added and the removed contributions have opposite signs.
    EXPECT_LT(records[0].value * records[1].value, 0.f);
    // Equal contributions enter and leave the window, so the total CSD is unchanged.
    EXPECT_FLOAT_EQ(csdValue, aggregator.getCsd());
    EXPECT_EQ(aggregator.getCsdRecordsSize(), size_t{1});
}
143 
TEST(MelAggregatorTest, CsdReaches100PercWith107dB) {
    // The 300s window is longer than the recording below, so nothing gets discarded.
    MelAggregator aggregator{/* csdWindowSeconds */ 300};

    // 100% CSD corresponds to 80 dBA for 40h (see CsdReaches100PercWith80dB). With the 3 dB
    // exchange rate, 107 dBA (27 dB = 9 doublings louder) reaches 100% after
    // 40h / 2^9 = 281.25s. 288s of 107 dBA must therefore produce at least 100% CSD.
    constexpr size_t kNumMels = 288;
    auto records = aggregator.aggregateAndAddNewMelRecord(
            MelRecord(kTestPortId, std::vector<float>(kNumMels, kCustomMelDbA),
                      /* timestamp */ 0));

    // Each CSD record is expected to hold between 1% and 2% of CSD, so reaching 100% needs at
    // least 50 records.
    EXPECT_GE(records.size(), size_t{50});
    EXPECT_GE(aggregator.getCsd(), 1.f);
}
155 
TEST(MelAggregatorTest, CsdReaches100PercWith80dB) {
    // 40 hours in seconds: used both as the CSD window and as the number of 1-per-second MELs.
    constexpr int64_t kFortyHoursInSeconds = 40 * 3600;
    MelAggregator aggregator{kFortyHoursInSeconds};

    // A constant 80 dBA over the full 40h window is the reference exposure that should yield
    // (nearly) exactly 100% CSD.
    const auto records = aggregator.aggregateAndAddNewMelRecord(
            MelRecord(kTestPortId, std::vector<float>(kFortyHoursInSeconds, 80.0f),
                      /* timestamp */ 0));

    // Every CSD record is expected to hold between 1% and 2%, hence at least 50 of them.
    EXPECT_GE(records.size(), size_t{50});
    EXPECT_NEAR(aggregator.getCsd(), 1.f, kMelFloatError);
}
170 
171 }  // namespace
172 }  // namespace android
173