//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #include "Wav2LetterMFCC.hpp"
6*89c4ff92SAndroid Build Coastguard Worker #include "MathUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <cfloat>
9*89c4ff92SAndroid Build Coastguard Worker
ApplyMelFilterBank(std::vector<float> & fftVec,std::vector<std::vector<float>> & melFilterBank,std::vector<uint32_t> & filterBankFilterFirst,std::vector<uint32_t> & filterBankFilterLast,std::vector<float> & melEnergies)10*89c4ff92SAndroid Build Coastguard Worker bool Wav2LetterMFCC::ApplyMelFilterBank(
11*89c4ff92SAndroid Build Coastguard Worker std::vector<float>& fftVec,
12*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<float>>& melFilterBank,
13*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t>& filterBankFilterFirst,
14*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t>& filterBankFilterLast,
15*89c4ff92SAndroid Build Coastguard Worker std::vector<float>& melEnergies)
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker const size_t numBanks = melEnergies.size();
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker if (numBanks != filterBankFilterFirst.size() ||
20*89c4ff92SAndroid Build Coastguard Worker numBanks != filterBankFilterLast.size())
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker printf("Unexpected filter bank lengths\n");
23*89c4ff92SAndroid Build Coastguard Worker return false;
24*89c4ff92SAndroid Build Coastguard Worker }
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker for (size_t bin = 0; bin < numBanks; ++bin)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker auto filterBankIter = melFilterBank[bin].begin();
29*89c4ff92SAndroid Build Coastguard Worker auto end = melFilterBank[bin].end();
30*89c4ff92SAndroid Build Coastguard Worker // Avoid log of zero at later stages, same value used in librosa.
31*89c4ff92SAndroid Build Coastguard Worker // The number was used during our default wav2letter model training.
32*89c4ff92SAndroid Build Coastguard Worker float melEnergy = 1e-10;
33*89c4ff92SAndroid Build Coastguard Worker const uint32_t firstIndex = filterBankFilterFirst[bin];
34*89c4ff92SAndroid Build Coastguard Worker const uint32_t lastIndex = std::min<uint32_t>(filterBankFilterLast[bin], fftVec.size() - 1);
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = firstIndex; i <= lastIndex && filterBankIter != end; ++i)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker melEnergy += (*filterBankIter++ * fftVec[i]);
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker melEnergies[bin] = melEnergy;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker return true;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker
ConvertToLogarithmicScale(std::vector<float> & melEnergies)47*89c4ff92SAndroid Build Coastguard Worker void Wav2LetterMFCC::ConvertToLogarithmicScale(std::vector<float>& melEnergies)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker float maxMelEnergy = -FLT_MAX;
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker // Container for natural logarithms of mel energies.
52*89c4ff92SAndroid Build Coastguard Worker std::vector <float> vecLogEnergies(melEnergies.size(), 0.f);
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker // Because we are taking natural logs, we need to multiply by log10(e).
55*89c4ff92SAndroid Build Coastguard Worker // Also, for wav2letter model, we scale our log10 values by 10.
56*89c4ff92SAndroid Build Coastguard Worker constexpr float multiplier = 10.0 * // Default scalar.
57*89c4ff92SAndroid Build Coastguard Worker 0.4342944819032518; // log10f(std::exp(1.0))
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker // Take log of the whole vector.
60*89c4ff92SAndroid Build Coastguard Worker MathUtils::VecLogarithmF32(melEnergies, vecLogEnergies);
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker // Scale the log values and get the max.
63*89c4ff92SAndroid Build Coastguard Worker for (auto iterM = melEnergies.begin(), iterL = vecLogEnergies.begin();
64*89c4ff92SAndroid Build Coastguard Worker iterM != melEnergies.end() && iterL != vecLogEnergies.end(); ++iterM, ++iterL)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker *iterM = *iterL * multiplier;
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker // Save the max mel energy.
70*89c4ff92SAndroid Build Coastguard Worker if (*iterM > maxMelEnergy)
71*89c4ff92SAndroid Build Coastguard Worker {
72*89c4ff92SAndroid Build Coastguard Worker maxMelEnergy = *iterM;
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker // Clamp the mel energies.
77*89c4ff92SAndroid Build Coastguard Worker constexpr float maxDb = 80.0;
78*89c4ff92SAndroid Build Coastguard Worker const float clampLevelLowdB = maxMelEnergy - maxDb;
79*89c4ff92SAndroid Build Coastguard Worker for (float& melEnergy : melEnergies)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker melEnergy = std::max(melEnergy, clampLevelLowdB);
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker
CreateDCTMatrix(const int32_t inputLength,const int32_t coefficientCount)85*89c4ff92SAndroid Build Coastguard Worker std::vector<float> Wav2LetterMFCC::CreateDCTMatrix(
86*89c4ff92SAndroid Build Coastguard Worker const int32_t inputLength,
87*89c4ff92SAndroid Build Coastguard Worker const int32_t coefficientCount)
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker std::vector<float> dctMatix(inputLength * coefficientCount);
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker // Orthonormal normalization.
92*89c4ff92SAndroid Build Coastguard Worker const float normalizerK0 = 2 * sqrtf(1.0f /
93*89c4ff92SAndroid Build Coastguard Worker static_cast<float>(4 * inputLength));
94*89c4ff92SAndroid Build Coastguard Worker const float normalizer = 2 * sqrtf(1.0f /
95*89c4ff92SAndroid Build Coastguard Worker static_cast<float>(2 * inputLength));
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker const float angleIncr = M_PI / inputLength;
98*89c4ff92SAndroid Build Coastguard Worker float angle = angleIncr; // We start using it at k = 1 loop.
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker // First row of DCT will use normalizer K0.
101*89c4ff92SAndroid Build Coastguard Worker for (int32_t n = 0; n < inputLength; ++n)
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker dctMatix[n] = normalizerK0; // cos(0) = 1
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker // Second row (index = 1) onwards, we use standard normalizer.
107*89c4ff92SAndroid Build Coastguard Worker for (int32_t k = 1, m = inputLength; k < coefficientCount; ++k, m += inputLength)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker for (int32_t n = 0; n < inputLength; ++n)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker dctMatix[m+n] = normalizer * cosf((n + 0.5f) * angle);
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker angle += angleIncr;
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker return dctMatix;
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker
GetMelFilterBankNormaliser(const float & leftMel,const float & rightMel,const bool useHTKMethod)118*89c4ff92SAndroid Build Coastguard Worker float Wav2LetterMFCC::GetMelFilterBankNormaliser(
119*89c4ff92SAndroid Build Coastguard Worker const float& leftMel,
120*89c4ff92SAndroid Build Coastguard Worker const float& rightMel,
121*89c4ff92SAndroid Build Coastguard Worker const bool useHTKMethod)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker // Slaney normalization for mel weights.
124*89c4ff92SAndroid Build Coastguard Worker return (2.0f / (MFCC::InverseMelScale(rightMel, useHTKMethod) -
125*89c4ff92SAndroid Build Coastguard Worker MFCC::InverseMelScale(leftMel, useHTKMethod)));
126*89c4ff92SAndroid Build Coastguard Worker }
127