xref: /aosp_15_r20/external/armnn/samples/SpeechRecognition/src/Wav2LetterMFCC.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterMFCC.hpp"
6*89c4ff92SAndroid Build Coastguard Worker #include "MathUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <cfloat>
9*89c4ff92SAndroid Build Coastguard Worker 
ApplyMelFilterBank(std::vector<float> & fftVec,std::vector<std::vector<float>> & melFilterBank,std::vector<uint32_t> & filterBankFilterFirst,std::vector<uint32_t> & filterBankFilterLast,std::vector<float> & melEnergies)10*89c4ff92SAndroid Build Coastguard Worker bool Wav2LetterMFCC::ApplyMelFilterBank(
11*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>&                 fftVec,
12*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::vector<float>>&    melFilterBank,
13*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint32_t>&               filterBankFilterFirst,
14*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint32_t>&               filterBankFilterLast,
15*89c4ff92SAndroid Build Coastguard Worker         std::vector<float>&                 melEnergies)
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker     const size_t numBanks = melEnergies.size();
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker     if (numBanks != filterBankFilterFirst.size() ||
20*89c4ff92SAndroid Build Coastguard Worker             numBanks != filterBankFilterLast.size())
21*89c4ff92SAndroid Build Coastguard Worker     {
22*89c4ff92SAndroid Build Coastguard Worker         printf("Unexpected filter bank lengths\n");
23*89c4ff92SAndroid Build Coastguard Worker         return false;
24*89c4ff92SAndroid Build Coastguard Worker     }
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     for (size_t bin = 0; bin < numBanks; ++bin)
27*89c4ff92SAndroid Build Coastguard Worker     {
28*89c4ff92SAndroid Build Coastguard Worker         auto filterBankIter = melFilterBank[bin].begin();
29*89c4ff92SAndroid Build Coastguard Worker         auto end = melFilterBank[bin].end();
30*89c4ff92SAndroid Build Coastguard Worker         // Avoid log of zero at later stages, same value used in librosa.
31*89c4ff92SAndroid Build Coastguard Worker         // The number was used during our default wav2letter model training.
32*89c4ff92SAndroid Build Coastguard Worker         float melEnergy = 1e-10;
33*89c4ff92SAndroid Build Coastguard Worker         const uint32_t firstIndex = filterBankFilterFirst[bin];
34*89c4ff92SAndroid Build Coastguard Worker         const uint32_t lastIndex = std::min<uint32_t>(filterBankFilterLast[bin], fftVec.size() - 1);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker         for (uint32_t i = firstIndex; i <= lastIndex && filterBankIter != end; ++i)
37*89c4ff92SAndroid Build Coastguard Worker         {
38*89c4ff92SAndroid Build Coastguard Worker             melEnergy += (*filterBankIter++ * fftVec[i]);
39*89c4ff92SAndroid Build Coastguard Worker         }
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker         melEnergies[bin] = melEnergy;
42*89c4ff92SAndroid Build Coastguard Worker     }
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     return true;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker 
ConvertToLogarithmicScale(std::vector<float> & melEnergies)47*89c4ff92SAndroid Build Coastguard Worker void Wav2LetterMFCC::ConvertToLogarithmicScale(std::vector<float>& melEnergies)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker     float maxMelEnergy = -FLT_MAX;
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     // Container for natural logarithms of mel energies.
52*89c4ff92SAndroid Build Coastguard Worker     std::vector <float> vecLogEnergies(melEnergies.size(), 0.f);
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     // Because we are taking natural logs, we need to multiply by log10(e).
55*89c4ff92SAndroid Build Coastguard Worker     // Also, for wav2letter model, we scale our log10 values by 10.
56*89c4ff92SAndroid Build Coastguard Worker     constexpr float multiplier = 10.0 *  // Default scalar.
57*89c4ff92SAndroid Build Coastguard Worker                                   0.4342944819032518;  // log10f(std::exp(1.0))
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     // Take log of the whole vector.
60*89c4ff92SAndroid Build Coastguard Worker     MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies);
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     // Scale the log values and get the max.
63*89c4ff92SAndroid Build Coastguard Worker     for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin();
64*89c4ff92SAndroid Build Coastguard Worker               iterM != melEnergies.end() && iterL != vecLogEnergies.end(); ++iterM, ++iterL)
65*89c4ff92SAndroid Build Coastguard Worker     {
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker         *iterM = *iterL * multiplier;
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker         // Save the max mel energy.
70*89c4ff92SAndroid Build Coastguard Worker         if (*iterM > maxMelEnergy)
71*89c4ff92SAndroid Build Coastguard Worker         {
72*89c4ff92SAndroid Build Coastguard Worker             maxMelEnergy = *iterM;
73*89c4ff92SAndroid Build Coastguard Worker         }
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     // Clamp the mel energies.
77*89c4ff92SAndroid Build Coastguard Worker     constexpr float maxDb = 80.0;
78*89c4ff92SAndroid Build Coastguard Worker     const float clampLevelLowdB = maxMelEnergy - maxDb;
79*89c4ff92SAndroid Build Coastguard Worker     for (float& melEnergy : melEnergies)
80*89c4ff92SAndroid Build Coastguard Worker     {
81*89c4ff92SAndroid Build Coastguard Worker         melEnergy = std::max(melEnergy, clampLevelLowdB);
82*89c4ff92SAndroid Build Coastguard Worker     }
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker 
CreateDCTMatrix(const int32_t inputLength,const int32_t coefficientCount)85*89c4ff92SAndroid Build Coastguard Worker std::vector<float> Wav2LetterMFCC::CreateDCTMatrix(
86*89c4ff92SAndroid Build Coastguard Worker                                     const int32_t inputLength,
87*89c4ff92SAndroid Build Coastguard Worker                                     const int32_t coefficientCount)
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> dctMatix(inputLength * coefficientCount);
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     // Orthonormal normalization.
92*89c4ff92SAndroid Build Coastguard Worker     const float normalizerK0 = 2 * sqrtf(1.0f /
93*89c4ff92SAndroid Build Coastguard Worker                                     static_cast<float>(4 * inputLength));
94*89c4ff92SAndroid Build Coastguard Worker     const float normalizer = 2 * sqrtf(1.0f /
95*89c4ff92SAndroid Build Coastguard Worker                                     static_cast<float>(2 * inputLength));
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     const float angleIncr = M_PI / inputLength;
98*89c4ff92SAndroid Build Coastguard Worker     float angle = angleIncr;  // We start using it at k = 1 loop.
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     // First row of DCT will use normalizer K0.
101*89c4ff92SAndroid Build Coastguard Worker     for (int32_t n = 0; n < inputLength; ++n)
102*89c4ff92SAndroid Build Coastguard Worker     {
103*89c4ff92SAndroid Build Coastguard Worker         dctMatix[n] = normalizerK0;  // cos(0) = 1
104*89c4ff92SAndroid Build Coastguard Worker     }
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     // Second row (index = 1) onwards, we use standard normalizer.
107*89c4ff92SAndroid Build Coastguard Worker     for (int32_t k = 1, m = inputLength; k < coefficientCount; ++k, m += inputLength)
108*89c4ff92SAndroid Build Coastguard Worker     {
109*89c4ff92SAndroid Build Coastguard Worker         for (int32_t n = 0; n < inputLength; ++n)
110*89c4ff92SAndroid Build Coastguard Worker         {
111*89c4ff92SAndroid Build Coastguard Worker             dctMatix[m+n] = normalizer * cosf((n + 0.5f) * angle);
112*89c4ff92SAndroid Build Coastguard Worker         }
113*89c4ff92SAndroid Build Coastguard Worker         angle += angleIncr;
114*89c4ff92SAndroid Build Coastguard Worker     }
115*89c4ff92SAndroid Build Coastguard Worker     return dctMatix;
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker 
GetMelFilterBankNormaliser(const float & leftMel,const float & rightMel,const bool useHTKMethod)118*89c4ff92SAndroid Build Coastguard Worker float Wav2LetterMFCC::GetMelFilterBankNormaliser(
119*89c4ff92SAndroid Build Coastguard Worker                                 const float&    leftMel,
120*89c4ff92SAndroid Build Coastguard Worker                                 const float&    rightMel,
121*89c4ff92SAndroid Build Coastguard Worker                                 const bool      useHTKMethod)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker     // Slaney normalization for mel weights.
124*89c4ff92SAndroid Build Coastguard Worker     return (2.0f / (MFCC::InverseMelScale(rightMel, useHTKMethod) -
125*89c4ff92SAndroid Build Coastguard Worker             MFCC::InverseMelScale(leftMel, useHTKMethod)));
126*89c4ff92SAndroid Build Coastguard Worker }
127