xref: /btstack/3rd-party/lc3-google/src/lc3.c (revision 2281ada7921e55288a45180ce35a34d9af2af65c)
1 /******************************************************************************
2  *
3  *  Copyright 2021 Google, Inc.
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
19 #include <lc3.h>
20 
21 #include "common.h"
22 #include "bits.h"
23 
24 #include "attdet.h"
25 #include "bwdet.h"
26 #include "ltpf.h"
27 #include "mdct.h"
28 #include "energy.h"
29 #include "sns.h"
30 #include "tns.h"
31 #include "spec.h"
32 #include "plc.h"
33 
34 
/**
 * Frame side data
 *
 * Per-frame parameters carried next to the quantized spectrum. On the
 * encoder side they are filled by `analyze()` and serialized by `encode()`;
 * on the decoder side they are parsed by `decode()` and consumed by
 * `synthesize()`.
 */

struct side_data {
    enum lc3_bandwidth bw;      /* Audio bandwidth (detected / signaled) */
    bool pitch_present;         /* Pitch bit: LTPF data follows in the frame */
    lc3_ltpf_data_t ltpf;       /* LTPF parameters, serialized only when
                                   `pitch_present` is set */
    lc3_sns_data_t sns;         /* Spectral noise shaping parameters */
    lc3_tns_data_t tns;         /* Temporal noise shaping parameters */
    lc3_spec_side_t spec;       /* Side information of the spectral coder */
};
47 
48 
49 /* ----------------------------------------------------------------------------
50  *  General
51  * -------------------------------------------------------------------------- */
52 
53 /**
54  * Resolve frame duration in us
55  * us              Frame duration in us
56  * return          Frame duration identifier, or LC3_NUM_DT
57  */
58 static enum lc3_dt resolve_dt(int us)
59 {
60     return us ==  7500 ? LC3_DT_7M5 :
61            us == 10000 ? LC3_DT_10M : LC3_NUM_DT;
62 }
63 
64 /**
65  * Resolve samplerate in Hz
66  * hz              Samplerate in Hz
67  * return          Sample rate identifier, or LC3_NUM_SRATE
68  */
69 static enum lc3_srate resolve_sr(int hz)
70 {
71     return hz ==  8000 ? LC3_SRATE_8K  : hz == 16000 ? LC3_SRATE_16K :
72            hz == 24000 ? LC3_SRATE_24K : hz == 32000 ? LC3_SRATE_32K :
73            hz == 48000 ? LC3_SRATE_48K : LC3_NUM_SRATE;
74 }
75 
76 /**
77  * Return the number of PCM samples in a frame
78  */
79 int lc3_frame_samples(int dt_us, int sr_hz)
80 {
81     enum lc3_dt dt = resolve_dt(dt_us);
82     enum lc3_srate sr = resolve_sr(sr_hz);
83 
84     if (dt >= LC3_NUM_DT || sr >= LC3_NUM_SRATE)
85         return -1;
86 
87     return LC3_NS(dt, sr);
88 }
89 
90 /**
91  * Return the size of frames, from bitrate
92  */
93 int lc3_frame_bytes(int dt_us, int bitrate)
94 {
95     if (resolve_dt(dt_us) >= LC3_NUM_DT)
96         return -1;
97 
98     if (bitrate < LC3_MIN_BITRATE)
99         return LC3_MIN_FRAME_BYTES;
100 
101     if (bitrate > LC3_MAX_BITRATE)
102         return LC3_MAX_FRAME_BYTES;
103 
104     int nbytes = ((unsigned)bitrate * dt_us) / (1000*1000*8);
105 
106     return LC3_CLIP(nbytes, LC3_MIN_FRAME_BYTES, LC3_MAX_FRAME_BYTES);
107 }
108 
109 /**
110  * Resolve the bitrate, from the size of frames
111  */
112 int lc3_resolve_bitrate(int dt_us, int nbytes)
113 {
114     if (resolve_dt(dt_us) >= LC3_NUM_DT)
115         return -1;
116 
117     if (nbytes < LC3_MIN_FRAME_BYTES)
118         return LC3_MIN_BITRATE;
119 
120     if (nbytes > LC3_MAX_FRAME_BYTES)
121         return LC3_MAX_BITRATE;
122 
123     int bitrate = ((unsigned)nbytes * (1000*1000*8) + dt_us/2) / dt_us;
124 
125     return LC3_CLIP(bitrate, LC3_MIN_BITRATE, LC3_MAX_BITRATE);
126 }
127 
128 /**
129  * Return algorithmic delay, as a number of samples
130  */
131 int lc3_delay_samples(int dt_us, int sr_hz)
132 {
133     enum lc3_dt dt = resolve_dt(dt_us);
134     enum lc3_srate sr = resolve_sr(sr_hz);
135 
136     if (dt >= LC3_NUM_DT || sr >= LC3_NUM_SRATE)
137         return -1;
138 
139     return (dt == LC3_DT_7M5 ? 8 : 5) * (LC3_SRATE_KHZ(sr) / 2);
140 }
141 
142 
143 /* ----------------------------------------------------------------------------
144  *  Encoder
145  * -------------------------------------------------------------------------- */
146 
147 /**
148  * Input PCM Samples from signed 16 bits
149  * encoder         Encoder state
150  * pcm, stride     Input PCM samples, and count between two consecutives
151  */
152 static void load_s16(
153     struct lc3_encoder *encoder, const void *_pcm, int stride)
154 {
155     const int16_t *pcm = _pcm;
156 
157     enum lc3_dt dt = encoder->dt;
158     enum lc3_srate sr = encoder->sr_pcm;
159     float *xs = encoder->xs;
160     int ns = LC3_NS(dt, sr);
161 
162     for (int i = 0; i < ns; i++)
163         xs[i] = pcm[i*stride];
164 }
165 
166 /**
167  * Input PCM Samples from signed 24 bits
168  * encoder         Encoder state
169  * pcm, stride     Input PCM samples, and count between two consecutives
170  */
171 static void load_s24(
172     struct lc3_encoder *encoder, const void *_pcm, int stride)
173 {
174     const int32_t *pcm = _pcm;
175 
176     enum lc3_dt dt = encoder->dt;
177     enum lc3_srate sr = encoder->sr_pcm;
178     float *xs = encoder->xs;
179     int ns = LC3_NS(dt, sr);
180 
181     for (int i = 0; i < ns; i++)
182         xs[i] = ldexpf(pcm[i*stride], -8);
183 }
184 
/**
 * Frame Analysis
 * encoder         Encoder state, input samples are read from `encoder->xs`
 * nbytes          Size in bytes of the frame
 * side, xq        Return frame data: side parameters and quantized spectrum
 *
 * Runs the temporal analysis (attack detection, LTPF pitch analysis) on the
 * time samples, then the spectral chain (MDCT, band energies, bandwidth
 * detection, SNS, TNS, quantization) on `encoder->xf`.
 * The order of the steps matters: the history shift must follow the MDCT,
 * and the LTPF may be disabled once the band energies are known.
 */
static void analyze(struct lc3_encoder *encoder,
    int nbytes, struct side_data *side, int16_t *xq)
{
    enum lc3_dt dt = encoder->dt;
    enum lc3_srate sr = encoder->sr;
    enum lc3_srate sr_pcm = encoder->sr_pcm;
    int ns = LC3_NS(dt, sr_pcm);
    int nd = LC3_ND(dt, sr_pcm);

    float *xs = encoder->xs;
    float *xf = encoder->xf;

    /* --- Temporal --- */

    /* Attack flag, later given to the SNS analysis */
    bool att = lc3_attdet_run(dt, sr_pcm, nbytes, &encoder->attdet, xs);

    side->pitch_present =
        lc3_ltpf_analyse(dt, sr_pcm, &encoder->ltpf, xs, &side->ltpf);

    /* --- Spectral --- */

    float e[LC3_NUM_BANDS];

    lc3_mdct_forward(dt, sr_pcm, sr, xs, xf);

    /* Keep the last `nd` input samples in the `nd` slots placed in front of
     * `xs` by `lc3_setup_encoder()`, presumably as the look-back of the next
     * frame -- TODO confirm against the MDCT / LTPF implementations */
    memmove(xs - nd, xs + ns-nd, nd * sizeof(float));

    /* When the flag is returned set, the LTPF is disabled for this frame;
     * note that `pitch_present` itself is left unchanged */
    bool nn_flag = lc3_energy_compute(dt, sr, xf, e);
    if (nn_flag)
        lc3_ltpf_disable(&side->ltpf);

    side->bw = lc3_bwdet_run(dt, sr, e);

    /* `xf` is passed as both input and output: the spectrum is shaped
     * in place */
    lc3_sns_analyze(dt, sr, e, att, &side->sns, xf, xf);

    lc3_tns_analyze(dt, side->bw, nn_flag, nbytes, &side->tns, xf);

    /* Quantize the spectrum into `xq`, filling the spectral side data */
    lc3_spec_analyze(dt, sr,
        nbytes, side->pitch_present, &side->tns,
        &encoder->spec, xf, xq, &side->spec);
}
231 
/**
 * Encode bitstream
 * encoder         Encoder state
 * side, xq        The frame data
 * nbytes          Target size of the frame (20 to 400)
 * buffer          Output bitstream buffer of `nbytes` size
 *
 * The order of the writes below defines the bitstream layout, and mirrors
 * the reads done by `decode()`: bandwidth, spectral side information,
 * TNS data, pitch bit, SNS data, LTPF data (only when the pitch bit is
 * set), then the spectrum itself.
 */
static void encode(struct lc3_encoder *encoder,
    const struct side_data *side, int16_t *xq, int nbytes, void *buffer)
{
    enum lc3_dt dt = encoder->dt;
    enum lc3_srate sr = encoder->sr;
    enum lc3_bandwidth bw = side->bw;
    float *xf = encoder->xf;

    lc3_bits_t bits;

    lc3_setup_bits(&bits, LC3_BITS_MODE_WRITE, buffer, nbytes);

    lc3_bwdet_put_bw(&bits, sr, bw);

    lc3_spec_put_side(&bits, dt, sr, &side->spec);

    lc3_tns_put_data(&bits, &side->tns);

    lc3_put_bit(&bits, side->pitch_present);

    lc3_sns_put_data(&bits, &side->sns);

    /* LTPF parameters are only present when the pitch bit is set */
    if (side->pitch_present)
        lc3_ltpf_put_data(&bits, &side->ltpf);

    lc3_spec_encode(&bits,
        dt, sr, bw, nbytes, xq, &side->spec, xf);

    /* Complete the output buffer with any pending bits */
    lc3_flush_bits(&bits);
}
269 
270 /**
271  * Return size needed for an encoder
272  */
273 unsigned lc3_encoder_size(int dt_us, int sr_hz)
274 {
275     if (resolve_dt(dt_us) >= LC3_NUM_DT ||
276         resolve_sr(sr_hz) >= LC3_NUM_SRATE)
277         return 0;
278 
279     return sizeof(struct lc3_encoder) +
280         LC3_ENCODER_BUFFER_COUNT(dt_us, sr_hz) * sizeof(float);
281 }
282 
283 /**
284  * Setup encoder
285  */
286 struct lc3_encoder *lc3_setup_encoder(
287     int dt_us, int sr_hz, int sr_pcm_hz, void *mem)
288 {
289     if (sr_pcm_hz <= 0)
290         sr_pcm_hz = sr_hz;
291 
292     enum lc3_dt dt = resolve_dt(dt_us);
293     enum lc3_srate sr = resolve_sr(sr_hz);
294     enum lc3_srate sr_pcm = resolve_sr(sr_pcm_hz);
295 
296     if (dt >= LC3_NUM_DT || sr_pcm >= LC3_NUM_SRATE || sr > sr_pcm || !mem)
297         return NULL;
298 
299     struct lc3_encoder *encoder = mem;
300     int ns = LC3_NS(dt, sr_pcm);
301     int nd = LC3_ND(dt, sr_pcm);
302 
303     *encoder = (struct lc3_encoder){
304         .dt = dt, .sr = sr,
305         .sr_pcm = sr_pcm,
306         .xs = encoder->s + nd,
307         .xf = encoder->s + nd+ns,
308     };
309 
310     memset(encoder->s, 0,
311         LC3_ENCODER_BUFFER_COUNT(dt_us, sr_pcm_hz) * sizeof(float));
312 
313     return encoder;
314 }
315 
316 /**
317  * Encode a frame
318  */
319 int lc3_encode(struct lc3_encoder *encoder, enum lc3_pcm_format fmt,
320     const void *pcm, int stride, int nbytes, void *out)
321 {
322     static void (* const load[])(struct lc3_encoder *, const void *, int) = {
323         [LC3_PCM_FORMAT_S16] = load_s16,
324         [LC3_PCM_FORMAT_S24] = load_s24,
325     };
326 
327     /* --- Check parameters --- */
328 
329     if (!encoder || nbytes < LC3_MIN_FRAME_BYTES
330                  || nbytes > LC3_MAX_FRAME_BYTES)
331         return -1;
332 
333     /* --- Processing --- */
334 
335     struct side_data side;
336     int16_t xq[LC3_NE(encoder->dt, encoder->sr)];
337 
338     load[fmt](encoder, pcm, stride);
339 
340     analyze(encoder, nbytes, &side, xq);
341 
342     encode(encoder, &side, xq, nbytes, out);
343 
344     return 0;
345 }
346 
347 
348 /* ----------------------------------------------------------------------------
349  *  Decoder
350  * -------------------------------------------------------------------------- */
351 
352 /**
353  * Output PCM Samples to signed 16 bits
354  * decoder         Decoder state
355  * pcm, stride     Output PCM samples, and count between two consecutives
356  */
357 static void store_s16(
358     struct lc3_decoder *decoder, void *_pcm, int stride)
359 {
360     int16_t *pcm = _pcm;
361 
362     enum lc3_dt dt = decoder->dt;
363     enum lc3_srate sr = decoder->sr_pcm;
364     float *xs = decoder->xs;
365     int ns = LC3_NS(dt, sr);
366 
367     for ( ; ns > 0; ns--, xs++, pcm += stride) {
368         int s = *xs >= 0 ? (int)(*xs + 0.5f) : (int)(*xs - 0.5f);
369         *pcm = LC3_CLIP(s, INT16_MIN, INT16_MAX);
370     }
371 }
372 
373 /**
374  * Output PCM Samples to signed 24 bits
375  * decoder         Decoder state
376  * pcm, stride     Output PCM samples, and count between two consecutives
377  */
378 static void store_s24(
379     struct lc3_decoder *decoder, void *_pcm, int stride)
380 {
381     int32_t *pcm = _pcm;
382     const int32_t int24_max =  (1 << 23) - 1;
383     const int32_t int24_min = -(1 << 23);
384 
385     enum lc3_dt dt = decoder->dt;
386     enum lc3_srate sr = decoder->sr_pcm;
387     float *xs = decoder->xs;
388     int ns = LC3_NS(dt, sr);
389 
390     for ( ; ns > 0; ns--, xs++, pcm += stride) {
391         int32_t s = *xs >= 0 ? (int32_t)(ldexpf(*xs, 8) + 0.5f)
392                              : (int32_t)(ldexpf(*xs, 8) - 0.5f);
393         *pcm = LC3_CLIP(s, int24_min, int24_max);
394     }
395 }
396 
/**
 * Decode bitstream
 * decoder         Decoder state
 * data, nbytes    Input bitstream buffer
 * side            Return the side data
 * return          0: Ok  < 0: Bitstream error detected
 *
 * Fields are read in the order `encode()` writes them: bandwidth, spectral
 * side information, TNS data, pitch bit, SNS data, LTPF data (only when the
 * pitch bit is set), then the spectrum. The spectrum is decoded into
 * `decoder->xs`, the buffer that `synthesize()` then works on in place.
 */
static int decode(struct lc3_decoder *decoder,
    const void *data, int nbytes, struct side_data *side)
{
    enum lc3_dt dt = decoder->dt;
    enum lc3_srate sr = decoder->sr;
    float *xf = decoder->xs;
    int ns = LC3_NS(dt, sr);
    int ne = LC3_NE(dt, sr);

    lc3_bits_t bits;
    int ret = 0;

    /* The const qualifier is cast away only to fit `lc3_setup_bits()`;
     * in read mode the buffer is presumably not written -- TODO confirm */
    lc3_setup_bits(&bits, LC3_BITS_MODE_READ, (void *)data, nbytes);

    if ((ret = lc3_bwdet_get_bw(&bits, sr, &side->bw)) < 0)
        return ret;

    if ((ret = lc3_spec_get_side(&bits, dt, sr, &side->spec)) < 0)
        return ret;

    /* NOTE(review): no error is checked for the TNS (and below, LTPF)
     * parsing steps */
    lc3_tns_get_data(&bits, dt, side->bw, nbytes, &side->tns);

    side->pitch_present = lc3_get_bit(&bits);

    if ((ret = lc3_sns_get_data(&bits, &side->sns)) < 0)
        return ret;

    if (side->pitch_present)
        lc3_ltpf_get_data(&bits, &side->ltpf);

    if ((ret = lc3_spec_decode(&bits, dt, sr,
                    side->bw, nbytes, &side->spec, xf)) < 0)
        return ret;

    /* Clear the coefficients from `ne` up to `ns`, which the spectrum
     * decoding above is not given to fill */
    memset(xf + ne, 0, (ns - ne) * sizeof(float));

    /* Final verdict of the bit reader on the whole frame */
    return lc3_check_bits(&bits);
}
442 
/**
 * Frame synthesis
 * decoder         Decoder state
 * side            Frame data, NULL performs PLC
 * nbytes          Size in bytes of the frame
 *
 * Turns the spectrum held in `decoder->xs` into time samples, written back
 * to the same buffer (`xf` and `xs` below alias each other). On a valid
 * frame, TNS and SNS synthesis are applied before the inverse MDCT. On a
 * missing frame, a concealment spectrum is produced by the PLC instead.
 * The LTPF is then run on the output samples in both cases.
 */
static void synthesize(struct lc3_decoder *decoder,
    const struct side_data *side, int nbytes)
{
    enum lc3_dt dt = decoder->dt;
    enum lc3_srate sr = decoder->sr;
    enum lc3_srate sr_pcm = decoder->sr_pcm;
    int ns = LC3_NS(dt, sr_pcm);
    int ne = LC3_NE(dt, sr);
    int nh = LC3_NH(sr_pcm);

    float *xf = decoder->xs;
    float *xg = decoder->xg;
    float *xd = decoder->xd;
    float *xs = xf;

    if (side) {
        enum lc3_bandwidth bw = side->bw;

        /* A valid frame was received: notify the PLC state */
        lc3_plc_suspend(&decoder->plc);

        lc3_tns_synthesize(dt, bw, &side->tns, xf);

        /* Shaped spectrum goes to `xg`, presumably also kept there as the
         * reference used by the PLC on a later missing frame */
        lc3_sns_synthesize(dt, sr, &side->sns, xf, xg);

        lc3_mdct_inverse(dt, sr_pcm, sr, xg, xd, xs);

    } else {
        /* Concealment spectrum is built in `xf`, from the PLC state and `xg` */
        lc3_plc_synthesize(dt, sr, &decoder->plc, xg, xf);

        memset(xf + ne, 0, (ns - ne) * sizeof(float));

        lc3_mdct_inverse(dt, sr_pcm, sr, xf, xd, xs);
    }

    /* LTPF parameters are only given for a valid frame with pitch data */
    lc3_ltpf_synthesize(dt, sr_pcm, nbytes, &decoder->ltpf,
        side && side->pitch_present ? &side->ltpf : NULL, xs);

    /* Move the last `nh` samples of the window [xs - nh, xs + ns) into the
     * `nh` slots placed in front of `xs` by `lc3_setup_decoder()`, presumably
     * as the history needed on the next frame -- TODO confirm */
    memmove(xs - nh, xs - nh+ns, nh * sizeof(*xs));
}
488 
489 /**
490  * Return size needed for a decoder
491  */
492 unsigned lc3_decoder_size(int dt_us, int sr_hz)
493 {
494     if (resolve_dt(dt_us) >= LC3_NUM_DT ||
495         resolve_sr(sr_hz) >= LC3_NUM_SRATE)
496         return 0;
497 
498     return sizeof(struct lc3_decoder) +
499         LC3_DECODER_BUFFER_COUNT(dt_us, sr_hz) * sizeof(float);
500 }
501 
502 /**
503  * Setup decoder
504  */
505 struct lc3_decoder *lc3_setup_decoder(
506     int dt_us, int sr_hz, int sr_pcm_hz, void *mem)
507 {
508     if (sr_pcm_hz <= 0)
509         sr_pcm_hz = sr_hz;
510 
511     enum lc3_dt dt = resolve_dt(dt_us);
512     enum lc3_srate sr = resolve_sr(sr_hz);
513     enum lc3_srate sr_pcm = resolve_sr(sr_pcm_hz);
514 
515     if (dt >= LC3_NUM_DT || sr_pcm >= LC3_NUM_SRATE || sr > sr_pcm || !mem)
516         return NULL;
517 
518     struct lc3_decoder *decoder = mem;
519     int nh = LC3_NH(sr_pcm);
520     int ns = LC3_NS(dt, sr_pcm);
521     int nd = LC3_ND(dt, sr_pcm);
522 
523     *decoder = (struct lc3_decoder){
524         .dt = dt, .sr = sr,
525         .sr_pcm = sr_pcm,
526         .xs = decoder->s + nh,
527         .xd = decoder->s + nh+ns,
528         .xg = decoder->s + nh+ns+nd,
529     };
530 
531     lc3_plc_reset(&decoder->plc);
532 
533     memset(decoder->s, 0,
534         LC3_DECODER_BUFFER_COUNT(dt_us, sr_pcm_hz) * sizeof(float));
535 
536     return decoder;
537 }
538 
539 /**
540  * Decode a frame
541  */
542 int lc3_decode(struct lc3_decoder *decoder, const void *in, int nbytes,
543     enum lc3_pcm_format fmt, void *pcm, int stride)
544 {
545     static void (* const store[])(struct lc3_decoder *, void *, int) = {
546         [LC3_PCM_FORMAT_S16] = store_s16,
547         [LC3_PCM_FORMAT_S24] = store_s24,
548     };
549 
550     /* --- Check parameters --- */
551 
552     if (!decoder)
553         return -1;
554 
555     if (in && (nbytes < LC3_MIN_FRAME_BYTES ||
556                nbytes > LC3_MAX_FRAME_BYTES   ))
557         return -1;
558 
559     /* --- Processing --- */
560 
561     struct side_data side;
562 
563     int ret = !in || (decode(decoder, in, nbytes, &side) < 0);
564 
565     synthesize(decoder, ret ? NULL : &side, nbytes);
566 
567     store[fmt](decoder, pcm, stride);
568 
569     return ret;
570 }
571