xref: /btstack/3rd-party/lc3-google/src/tns.c (revision 4a9eead824c50b40e12b6f72611a74a3f57a47f6)
1 /******************************************************************************
2  *
3  *  Copyright 2021 Google, Inc.
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
19 #include "tns.h"
20 #include "tables.h"
21 
22 
23 /* ----------------------------------------------------------------------------
24  *  Filter Coefficients
25  * -------------------------------------------------------------------------- */
26 
27 /**
28  * Resolve LPC Weighting indication according bitrate
29  * dt, nbytes      Duration and size of the frame
30  * return          True when LPC Weighting enabled
31  */
32 static bool resolve_lpc_weighting(enum lc3_dt dt, int nbytes)
33 {
34     return nbytes < (dt == LC3_DT_7M5 ? 360/8 : 480/8);
35 }
36 
/**
 * Return dot product of 2 vectors
 * a, b, n         The 2 vectors of size `n`
 * return          sum( a[i] * b[i] ), i = [0..n-1]; 0 when `n` <= 0
 */
static inline float dot(const float *a, const float *b, int n)
{
    float v = 0;

    /* A bounded up-count makes a non-positive `n` yield 0, whereas
     * `while (n--)` on a negative count overflows and reads out of
     * bounds. Terms are still accumulated in index order. */
    for (int i = 0; i < n; i++)
        v += a[i] * b[i];

    return v;
}
51 
52 /**
53  * LPC Coefficients
54  * dt, bw          Duration and bandwidth of the frame
55  * x               Spectral coefficients
56  * gain, a         Output the prediction gains and LPC coefficients
57  */
58 static void compute_lpc_coeffs(enum lc3_dt dt, enum lc3_bandwidth bw,
59     const float *x, float *gain, float (*a)[9])
60 {
61     static const int sub_7m5_nb[]   = {  9, 26,  43,  60 };
62     static const int sub_7m5_wb[]   = {  9, 46,  83, 120 };
63     static const int sub_7m5_sswb[] = {  9, 66, 123, 180 };
64     static const int sub_7m5_swb[]  = {  9, 46,  82, 120, 159, 200, 240 };
65     static const int sub_7m5_fb[]   = {  9, 56, 103, 150, 200, 250, 300 };
66 
67     static const int sub_10m_nb[]   = { 12, 34,  57,  80 };
68     static const int sub_10m_wb[]   = { 12, 61, 110, 160 };
69     static const int sub_10m_sswb[] = { 12, 88, 164, 240 };
70     static const int sub_10m_swb[]  = { 12, 61, 110, 160, 213, 266, 320 };
71     static const int sub_10m_fb[]   = { 12, 74, 137, 200, 266, 333, 400 };
72 
73     /* --- Normalized autocorrelation --- */
74 
75     static const float lag_window[] = {
76         1.00000000e+00, 9.98028026e-01, 9.92135406e-01, 9.82391584e-01,
77         9.68910791e-01, 9.51849807e-01, 9.31404933e-01, 9.07808230e-01,
78         8.81323137e-01
79     };
80 
81     const int *sub = (const int * const [LC3_NUM_DT][LC3_NUM_SRATE]){
82         { sub_7m5_nb, sub_7m5_wb, sub_7m5_sswb, sub_7m5_swb, sub_7m5_fb },
83         { sub_10m_nb, sub_10m_wb, sub_10m_sswb, sub_10m_swb, sub_10m_fb },
84     }[dt][bw];
85 
86     int nfilters = 1 + (bw >= LC3_BANDWIDTH_SWB);
87 
88     const float *xs, *xe = x + *sub;
89     float r[2][9];
90 
91     for (int f = 0; f < nfilters; f++) {
92         float c[9][3];
93 
94         for (int s = 0; s < 3; s++) {
95             xs = xe, xe = x + *(++sub);
96 
97             for (int k = 0; k < 9; k++)
98                 c[k][s] = dot(xs, xs + k, (xe - xs) - k);
99         }
100 
101         float e0 = c[0][0], e1 = c[0][1], e2 = c[0][2];
102 
103         r[f][0] = 3;
104         for (int k = 1; k < 9; k++)
105             r[f][k] = e0 == 0 || e1 == 0 || e2 == 0 ? 0 :
106                 (c[k][0]/e0 + c[k][1]/e1 + c[k][2]/e2) * lag_window[k];
107     }
108 
109     /* --- Levinson-Durbin recursion --- */
110 
111     for (int f = 0; f < nfilters; f++) {
112         float *a0 = a[f], a1[9];
113         float err = r[f][0], rc;
114 
115         gain[f] = err;
116 
117         a0[0] = 1;
118         for (int k = 1; k < 9; ) {
119 
120             rc = -r[f][k];
121             for (int i = 1; i < k; i++)
122                 rc -= a0[i] * r[f][k-i];
123 
124             rc /= err;
125             err *= 1 - rc * rc;
126 
127             for (int i = 1; i < k; i++)
128                 a1[i] = a0[i] + rc * a0[k-i];
129             a1[k++] = rc;
130 
131             rc = -r[f][k];
132             for (int i = 1; i < k; i++)
133                 rc -= a1[i] * r[f][k-i];
134 
135             rc /= err;
136             err *= 1 - rc * rc;
137 
138             for (int i = 1; i < k; i++)
139                 a0[i] = a1[i] + rc * a1[k-i];
140             a0[k++] = rc;
141         }
142 
143         gain[f] /= err;
144     }
145 }
146 
/**
 * LPC weighting
 * pred_gain       Prediction gain of the filter, selecting the strength
 * a               LPC coefficients; a[k] is scaled by gamma^k, k = 1..8
 *
 * gamma goes linearly from 0.85 (pred_gain = 1.5) to 1 (pred_gain = 2).
 */
static void lpc_weighting(float pred_gain, float *a)
{
    const float gamma = 1. - (1. - 0.85) * (2. - pred_gain) / (2. - 1.5);
    float w = 1;

    for (int k = 1; k <= 8; k++) {
        w *= gamma;
        a[k] *= w;
    }
}
157 
/**
 * LPC to reflection coefficients, by step-down recursion
 * a               LPC coefficients of order 8 (a[1..8] are read)
 * rc              Output the 8 reflection coefficients
 */
static void lpc_reflection(const float *a, float *rc)
{
    float buf[2][7];
    float *src, *dst = buf[1];
    float norm;

    /* Step down from order 8 to 7, reading the LPC vector directly */

    rc[7] = a[8];
    norm = 1 - rc[7] * rc[7];

    for (int n = 0; n < 7; n++)
        dst[n] = (a[n + 1] - rc[7] * a[7 - n]) / norm;

    /* Step down from order k+1 to k, alternating between the 2 buffers */

    for (int k = 6; k >= 1; k--) {
        src = dst;
        dst = buf[k % 2];

        rc[k] = src[k];
        norm = 1 - rc[k] * rc[k];

        for (int n = 0; n < k; n++)
            dst[n] = (src[n] - rc[k] * src[k - 1 - n]) / norm;
    }

    rc[0] = dst[0];
}
186 
/**
 * Quantization of reflection coefficients
 * rc              Reflection coefficients, 8 values
 * rc_order        Return order of coefficients: index of the last non-zero
 *                 quantized coefficient plus one, or 0 when all are zero
 * rc_q            Return the 8 quantized coefficients, in range [-8..8]
 */
static void quantize_rc(const float *rc, int *rc_order, int *rc_q)
{
    /* Quantization thresholds, sin(delta * (i + 0.5)), delta = Pi / 17.
     * Read-only: `const` lets it live in ROM instead of writable data. */

    static const float q_thr[] = {
        9.22683595e-02, 2.73662990e-01, 4.45738356e-01, 6.02634636e-01,
        7.39008917e-01, 8.50217136e-01, 9.32472229e-01, 9.82973100e-01
    };

    int order = 0;

    for (int i = 0; i < 8; i++) {
        float rc_m = fabsf(rc[i]);

        /* The magnitude index is the count of thresholds <= |rc|. Start
         * in the upper half when it applies, then scan at most 4 more
         * thresholds (which also keeps q_thr[] accesses in bounds). */
        int q = rc_m >= q_thr[4] ? 4 : 0;
        for (int j = 0; j < 4 && rc_m >= q_thr[q]; j++)
            q++;

        rc_q[i] = rc[i] < 0 ? -q : q;

        /* Trailing zero coefficients are not part of the filter */
        if (rc_q[i] != 0)
            order = i + 1;
    }

    *rc_order = order;
}
216 
/**
 * Unquantization of reflection coefficients
 * rc_q            Quantized coefficients, in range [-8..8]
 * rc_order        Order of coefficients
 * rc              Return reflection coefficients; only the first
 *                 `rc_order` values are written, the others are left as is
 */
static void unquantize_rc(const int *rc_q, int rc_order, float rc[8])
{
    /* Dequantization table, sin(delta * i), delta = Pi / 17.
     * Read-only: `const` lets it live in ROM instead of writable data. */

    static const float q_inv[] = {
        0.00000000e+00, 1.83749517e-01, 3.61241664e-01, 5.26432173e-01,
        6.73695641e-01, 7.98017215e-01, 8.95163302e-01, 9.61825645e-01,
        9.95734176e-01
    };

    for (int i = 0; i < rc_order; i++) {
        int q = rc_q[i];
        float rc_m = q_inv[q < 0 ? -q : q];

        rc[i] = q < 0 ? -rc_m : rc_m;
    }
}
240 
241 
242 /* ----------------------------------------------------------------------------
243  *  Filtering
244  * -------------------------------------------------------------------------- */
245 
246 /**
247  * Forward filtering
248  * dt, bw          Duration and bandwidth of the frame
249  * rc_order, rc    Order of coefficients, and coefficients
250  * x               Spectral coefficients, filtered as output
251  */
252 static void forward_filtering(
253     enum lc3_dt dt, enum lc3_bandwidth bw,
254     const int rc_order[2], const float rc[2][8], float *x)
255 {
256     int nfilters = 1 + (bw >= LC3_BANDWIDTH_SWB);
257     int nf = LC3_NE(dt, bw) >> (nfilters - 1);
258     int i0, ie = 3*(3 + dt);
259 
260     float s[8] = { 0 };
261 
262     for (int f = 0; f < nfilters; f++) {
263 
264         i0 = ie;
265         ie = nf * (1 + f);
266 
267         if (!rc_order[f])
268             continue;
269 
270         for (int i = i0; i < ie; i++) {
271             float xi = x[i];
272             float s0, s1 = xi;
273 
274             for (int k = 0; k < rc_order[f]; k++) {
275                 s0 = s[k];
276                 s[k] = s1;
277 
278                 s1  = rc[f][k] * xi + s0;
279                 xi += rc[f][k] * s0;
280             }
281 
282             x[i] = xi;
283         }
284     }
285 }
286 
287 /**
288  * Inverse filtering
289  * dt, bw          Duration and bandwidth of the frame
290  * rc_order, rc    Order of coefficients, and unquantized coefficients
291  * x               Spectral coefficients, filtered as output
292  */
293 static void inverse_filtering(
294     enum lc3_dt dt, enum lc3_bandwidth bw,
295     const int rc_order[2], const float rc[2][8], float *x)
296 {
297     int nfilters = 1 + (bw >= LC3_BANDWIDTH_SWB);
298     int nf = LC3_NE(dt, bw) >> (nfilters - 1);
299     int i0, ie = 3*(3 + dt);
300 
301     float s[8] = { 0 };
302 
303     for (int f = 0; f < nfilters; f++) {
304 
305         i0 = ie;
306         ie = nf * (1 + f);
307 
308         if (!rc_order[f])
309             continue;
310 
311         for (int i = i0; i < ie; i++) {
312             float xi = x[i];
313 
314             xi -= s[7] * rc[f][7];
315             for (int k = 6; k >= 0; k--) {
316                 xi -= s[k] * rc[f][k];
317                 s[k+1] = s[k] + rc[f][k] * xi;
318             }
319             s[0] = xi;
320             x[i] = xi;
321         }
322 
323         for (int k = 7; k >= rc_order[f]; k--)
324             s[k] = 0;
325     }
326 }
327 
328 
329 /* ----------------------------------------------------------------------------
330  *  Interface
331  * -------------------------------------------------------------------------- */
332 
333 /**
334  * TNS analysis
335  */
336 void lc3_tns_analyze(enum lc3_dt dt, enum lc3_bandwidth bw,
337     bool nn_flag, int nbytes, struct lc3_tns_data *data, float *x)
338 {
339     /* Processing steps :
340      * - Determine the LPC (Linear Predictive Coding) Coefficients
341      * - Check is the filtering is disabled
342      * - The coefficients are weighted on low bitrates and predicition gain
343      * - Convert to reflection coefficients and quantize
344      * - Finally filter the spectral coefficients */
345 
346     float pred_gain[2], a[2][9];
347     float rc[2][8];
348 
349     data->nfilters = 1 + (bw >= LC3_BANDWIDTH_SWB);
350     data->lpc_weighting = resolve_lpc_weighting(dt, nbytes);
351 
352     compute_lpc_coeffs(dt, bw, x, pred_gain, a);
353 
354     for (int f = 0; f < data->nfilters; f++) {
355 
356         data->rc_order[f] = 0;
357         if (nn_flag || pred_gain[f] <= 1.5)
358             continue;
359 
360         if (data->lpc_weighting && pred_gain[f] < 2)
361             lpc_weighting(pred_gain[f], a[f]);
362 
363         lpc_reflection(a[f], rc[f]);
364 
365         quantize_rc(rc[f], &data->rc_order[f], data->rc[f]);
366         unquantize_rc(data->rc[f], data->rc_order[f], rc[f]);
367     }
368 
369     forward_filtering(dt, bw, data->rc_order, rc, x);
370 }
371 
372 /**
373  * TNS synthesis
374  */
375 void lc3_tns_synthesize(enum lc3_dt dt, enum lc3_bandwidth bw,
376     const struct lc3_tns_data *data, float *x)
377 {
378     float rc[2][8] = { };
379 
380     for (int f = 0; f < data->nfilters; f++)
381         if (data->rc_order[f])
382             unquantize_rc(data->rc[f], data->rc_order[f], rc[f]);
383 
384     inverse_filtering(dt, bw, data->rc_order, rc, x);
385 }
386 
387 /**
388  * Bit consumption of bitstream data
389  */
390 int lc3_tns_get_nbits(const struct lc3_tns_data *data)
391 {
392     int nbits = 0;
393 
394     for (int f = 0; f < data->nfilters; f++) {
395 
396         int nbits_2048 = 2048;
397         int rc_order = data->rc_order[f];
398 
399         nbits_2048 += rc_order > 0 ? lc3_tns_order_bits
400             [data->lpc_weighting][rc_order-1] : 0;
401 
402         for (int i = 0; i < rc_order; i++)
403             nbits_2048 += lc3_tns_coeffs_bits[i][8 + data->rc[f][i]];
404 
405         nbits += (nbits_2048 + (1 << 11) - 1) >> 11;
406     }
407 
408     return nbits;
409 }
410 
411 /**
412  * Put bitstream data
413  */
414 void lc3_tns_put_data(lc3_bits_t *bits, const struct lc3_tns_data *data)
415 {
416     for (int f = 0; f < data->nfilters; f++) {
417         int rc_order = data->rc_order[f];
418 
419         lc3_put_bits(bits, rc_order > 0, 1);
420         if (rc_order <= 0)
421             continue;
422 
423         lc3_put_symbol(bits,
424             lc3_tns_order_models + data->lpc_weighting, rc_order-1);
425 
426         for (int i = 0; i < rc_order; i++)
427             lc3_put_symbol(bits,
428                 lc3_tns_coeffs_models + i, 8 + data->rc[f][i]);
429     }
430 }
431 
432 /**
433  * Get bitstream data
434  */
435 void lc3_tns_get_data(lc3_bits_t *bits,
436     enum lc3_dt dt, enum lc3_bandwidth bw, int nbytes, lc3_tns_data_t *data)
437 {
438     data->nfilters = 1 + (bw >= LC3_BANDWIDTH_SWB);
439     data->lpc_weighting = resolve_lpc_weighting(dt, nbytes);
440 
441     for (int f = 0; f < data->nfilters; f++) {
442 
443         data->rc_order[f] = lc3_get_bit(bits);
444         if (!data->rc_order[f])
445             continue;
446 
447         data->rc_order[f] += lc3_get_symbol(bits,
448             lc3_tns_order_models + data->lpc_weighting);
449 
450         for (int i = 0; i < data->rc_order[f]; i++)
451             data->rc[f][i] = (int)lc3_get_symbol(bits,
452                 lc3_tns_coeffs_models + i) - 8;
453     }
454 }
455