xref: /btstack/3rd-party/lc3-google/src/lc3.c (revision 4930cef6e21e6da2d7571b9259c7f0fb8bed3d01)
1 /******************************************************************************
2  *
3  *  Copyright 2022 Google LLC
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
19 #include <lc3.h>
20 
21 #include "common.h"
22 #include "bits.h"
23 
24 #include "attdet.h"
25 #include "bwdet.h"
26 #include "ltpf.h"
27 #include "mdct.h"
28 #include "energy.h"
29 #include "sns.h"
30 #include "tns.h"
31 #include "spec.h"
32 #include "plc.h"
33 
34 
/**
 * Frame side data
 *
 * Per-frame parameters shared between the signal processing stages and
 * the bitstream. Encoder side: filled by `analyze()`, written by
 * `encode()`. Decoder side: read by `decode()`, consumed by `synthesize()`.
 */

struct side_data {
    enum lc3_bandwidth bw;      /* Detected (encoder) or signaled (decoder) bandwidth */
    bool pitch_present;         /* Pitch flag; `ltpf` is only written/read when set */
    lc3_ltpf_data_t ltpf;       /* LTPF (long-term postfilter) parameters */
    lc3_sns_data_t sns;         /* SNS (spectral noise shaping) parameters */
    lc3_tns_data_t tns;         /* TNS (temporal noise shaping) parameters */
    lc3_spec_side_t spec;       /* Side information of the quantized spectrum */
};
47 
48 
49 /* ----------------------------------------------------------------------------
50  *  General
51  * -------------------------------------------------------------------------- */
52 
53 /**
54  * Resolve frame duration in us
55  * us              Frame duration in us
56  * return          Frame duration identifier, or LC3_NUM_DT
57  */
58 static enum lc3_dt resolve_dt(int us)
59 {
60     return us ==  7500 ? LC3_DT_7M5 :
61            us == 10000 ? LC3_DT_10M : LC3_NUM_DT;
62 }
63 
64 /**
65  * Resolve samplerate in Hz
66  * hz              Samplerate in Hz
67  * return          Sample rate identifier, or LC3_NUM_SRATE
68  */
69 static enum lc3_srate resolve_sr(int hz)
70 {
71     return hz ==  8000 ? LC3_SRATE_8K  : hz == 16000 ? LC3_SRATE_16K :
72            hz == 24000 ? LC3_SRATE_24K : hz == 32000 ? LC3_SRATE_32K :
73            hz == 48000 ? LC3_SRATE_48K : LC3_NUM_SRATE;
74 }
75 
76 /**
77  * Return the number of PCM samples in a frame
78  */
79 int lc3_frame_samples(int dt_us, int sr_hz)
80 {
81     enum lc3_dt dt = resolve_dt(dt_us);
82     enum lc3_srate sr = resolve_sr(sr_hz);
83 
84     if (dt >= LC3_NUM_DT || sr >= LC3_NUM_SRATE)
85         return -1;
86 
87     return LC3_NS(dt, sr);
88 }
89 
90 /**
91  * Return the size of frames, from bitrate
92  */
93 int lc3_frame_bytes(int dt_us, int bitrate)
94 {
95     if (resolve_dt(dt_us) >= LC3_NUM_DT)
96         return -1;
97 
98     if (bitrate < LC3_MIN_BITRATE)
99         return LC3_MIN_FRAME_BYTES;
100 
101     if (bitrate > LC3_MAX_BITRATE)
102         return LC3_MAX_FRAME_BYTES;
103 
104     int nbytes = ((unsigned)bitrate * dt_us) / (1000*1000*8);
105 
106     return LC3_CLIP(nbytes, LC3_MIN_FRAME_BYTES, LC3_MAX_FRAME_BYTES);
107 }
108 
109 /**
110  * Resolve the bitrate, from the size of frames
111  */
112 int lc3_resolve_bitrate(int dt_us, int nbytes)
113 {
114     if (resolve_dt(dt_us) >= LC3_NUM_DT)
115         return -1;
116 
117     if (nbytes < LC3_MIN_FRAME_BYTES)
118         return LC3_MIN_BITRATE;
119 
120     if (nbytes > LC3_MAX_FRAME_BYTES)
121         return LC3_MAX_BITRATE;
122 
123     int bitrate = ((unsigned)nbytes * (1000*1000*8) + dt_us/2) / dt_us;
124 
125     return LC3_CLIP(bitrate, LC3_MIN_BITRATE, LC3_MAX_BITRATE);
126 }
127 
128 /**
129  * Return algorithmic delay, as a number of samples
130  */
131 int lc3_delay_samples(int dt_us, int sr_hz)
132 {
133     enum lc3_dt dt = resolve_dt(dt_us);
134     enum lc3_srate sr = resolve_sr(sr_hz);
135 
136     if (dt >= LC3_NUM_DT || sr >= LC3_NUM_SRATE)
137         return -1;
138 
139     return (dt == LC3_DT_7M5 ? 8 : 5) * (LC3_SRATE_KHZ(sr) / 2);
140 }
141 
142 
143 /* ----------------------------------------------------------------------------
144  *  Encoder
145  * -------------------------------------------------------------------------- */
146 
147 /**
148  * Input PCM Samples from signed 16 bits
149  * encoder         Encoder state
150  * pcm, stride     Input PCM samples, and count between two consecutives
151  */
152 static void load_s16(
153     struct lc3_encoder *encoder, const void *_pcm, int stride)
154 {
155     const int16_t *pcm = _pcm;
156 
157     enum lc3_dt dt = encoder->dt;
158     enum lc3_srate sr = encoder->sr_pcm;
159 
160     int16_t *xt = encoder->xt;
161     float *xs = encoder->xs;
162     int ns = LC3_NS(dt, sr);
163 
164     for (int i = 0; i < ns; i++) {
165         int16_t in = pcm[i*stride];
166         xt[i] = in, xs[i] = in;
167     }
168 }
169 
170 /**
171  * Input PCM Samples from signed 24 bits
172  * encoder         Encoder state
173  * pcm, stride     Input PCM samples, and count between two consecutives
174  */
175 static void load_s24(
176     struct lc3_encoder *encoder, const void *_pcm, int stride)
177 {
178     const int32_t *pcm = _pcm;
179 
180     enum lc3_dt dt = encoder->dt;
181     enum lc3_srate sr = encoder->sr_pcm;
182 
183     int16_t *xt = encoder->xt;
184     float *xs = encoder->xs;
185     int ns = LC3_NS(dt, sr);
186 
187     for (int i = 0; i < ns; i++) {
188         int32_t in = pcm[i*stride];
189 
190         xt[i] = in >> 8;
191         xs[i] = ldexpf(in, -8);
192     }
193 }
194 
/**
 * Frame Analysis
 * encoder         Encoder state
 * nbytes          Size in bytes of the frame
 * side, xq        Return frame data
 *
 * The temporal analyses (attack detection, LTPF pitch analysis) run on
 * the 16-bit samples `xt`. The spectral pipeline (forward MDCT, band
 * energies, bandwidth detection, SNS, TNS, quantization) then runs on the
 * float samples `xs`. The spectrum is written over `xs` (`xf` aliases it).
 */
static void analyze(struct lc3_encoder *encoder,
    int nbytes, struct side_data *side, uint16_t *xq)
{
    enum lc3_dt dt = encoder->dt;
    enum lc3_srate sr = encoder->sr;
    enum lc3_srate sr_pcm = encoder->sr_pcm;
    int ns = LC3_NS(dt, sr_pcm);
    int nt = LC3_NT(sr_pcm);

    int16_t *xt = encoder->xt;
    float *xs = encoder->xs;
    float *xd = encoder->xd;
    float *xf = xs;            /* The spectrum overwrites the input samples */

    /* --- Temporal --- */

    bool att = lc3_attdet_run(dt, sr_pcm, nbytes, &encoder->attdet, xt);

    side->pitch_present =
        lc3_ltpf_analyse(dt, sr_pcm, &encoder->ltpf, xt, &side->ltpf);

    /* Copy the last `nt` samples of this frame into the `nt` slots just
     * before `xt`, presumably as look-back history for the next frame's
     * temporal analysis. */
    memmove(xt - nt, xt + (ns-nt), nt * sizeof(*xt));

    /* --- Spectral --- */

    float e[LC3_NUM_BANDS];

    /* Forward MDCT from `xs` into `xf` (same buffer). `xd` lives in the
     * persistent encoder buffer, so it carries state across frames;
     * its exact role is internal to the MDCT module -- TODO confirm. */
    lc3_mdct_forward(dt, sr_pcm, sr, xs, xd, xf);

    /* Per-band energies; when the returned flag is set, LTPF is
     * disabled for this frame. */
    bool nn_flag = lc3_energy_compute(dt, sr, xf, e);
    if (nn_flag)
        lc3_ltpf_disable(&side->ltpf);

    side->bw = lc3_bwdet_run(dt, sr, e);

    /* SNS uses `xf` as both input and output */
    lc3_sns_analyze(dt, sr, e, att, &side->sns, xf, xf);

    /* TNS analysis, presumably filtering `xf` in place */
    lc3_tns_analyze(dt, side->bw, nn_flag, nbytes, &side->tns, xf);

    /* Quantize the spectrum into `xq` and fill the spectral side data */
    lc3_spec_analyze(dt, sr,
        nbytes, side->pitch_present, &side->tns,
        &encoder->spec, xf, xq, &side->spec);
}
244 
/**
 * Encode bitstream
 * encoder         Encoder state
 * side, xq        The frame data
 * nbytes          Target size of the frame (20 to 400)
 * buffer          Output bitstream buffer of `nbytes` size
 *
 * Side information is written in a fixed order that `decode()` reads back
 * identically: bandwidth, spectral side data, TNS data, pitch flag, SNS
 * data, then LTPF data only when the pitch flag is set. The quantized
 * spectrum follows.
 */
static void encode(struct lc3_encoder *encoder,
    const struct side_data *side, uint16_t *xq, int nbytes, void *buffer)
{
    enum lc3_dt dt = encoder->dt;
    enum lc3_srate sr = encoder->sr;
    enum lc3_bandwidth bw = side->bw;
    float *xf = encoder->xs;        /* Spectrum left there by `analyze()` */

    lc3_bits_t bits;

    lc3_setup_bits(&bits, LC3_BITS_MODE_WRITE, buffer, nbytes);

    /* --- Side information --- */

    lc3_bwdet_put_bw(&bits, sr, bw);

    lc3_spec_put_side(&bits, dt, sr, &side->spec);

    lc3_tns_put_data(&bits, &side->tns);

    lc3_put_bit(&bits, side->pitch_present);

    lc3_sns_put_data(&bits, &side->sns);

    if (side->pitch_present)
        lc3_ltpf_put_data(&bits, &side->ltpf);

    /* --- Spectral data --- */

    lc3_spec_encode(&bits,
        dt, sr, bw, nbytes, xq, &side->spec, xf);

    /* Presumably writes out any bits still buffered by the writer */
    lc3_flush_bits(&bits);
}
282 
283 /**
284  * Return size needed for an encoder
285  */
286 unsigned lc3_encoder_size(int dt_us, int sr_hz)
287 {
288     if (resolve_dt(dt_us) >= LC3_NUM_DT ||
289         resolve_sr(sr_hz) >= LC3_NUM_SRATE)
290         return 0;
291 
292     return sizeof(struct lc3_encoder) +
293         LC3_ENCODER_BUFFER_COUNT(dt_us, sr_hz) * sizeof(float);
294 }
295 
296 /**
297  * Setup encoder
298  */
299 struct lc3_encoder *lc3_setup_encoder(
300     int dt_us, int sr_hz, int sr_pcm_hz, void *mem)
301 {
302     if (sr_pcm_hz <= 0)
303         sr_pcm_hz = sr_hz;
304 
305     enum lc3_dt dt = resolve_dt(dt_us);
306     enum lc3_srate sr = resolve_sr(sr_hz);
307     enum lc3_srate sr_pcm = resolve_sr(sr_pcm_hz);
308 
309     if (dt >= LC3_NUM_DT || sr_pcm >= LC3_NUM_SRATE || sr > sr_pcm || !mem)
310         return NULL;
311 
312     struct lc3_encoder *encoder = mem;
313     int ns = LC3_NS(dt, sr_pcm);
314     int nt = LC3_NT(sr_pcm);
315 
316     *encoder = (struct lc3_encoder){
317         .dt = dt, .sr = sr,
318         .sr_pcm = sr_pcm,
319 
320         .xt = (int16_t *)encoder->s + nt,
321         .xs = encoder->s + (nt+ns)/2,
322         .xd = encoder->s + (nt+ns)/2 + ns,
323     };
324 
325     memset(encoder->s, 0,
326         LC3_ENCODER_BUFFER_COUNT(dt_us, sr_pcm_hz) * sizeof(float));
327 
328     return encoder;
329 }
330 
331 /**
332  * Encode a frame
333  */
334 int lc3_encode(struct lc3_encoder *encoder, enum lc3_pcm_format fmt,
335     const void *pcm, int stride, int nbytes, void *out)
336 {
337     static void (* const load[])(struct lc3_encoder *, const void *, int) = {
338         [LC3_PCM_FORMAT_S16] = load_s16,
339         [LC3_PCM_FORMAT_S24] = load_s24,
340     };
341 
342     /* --- Check parameters --- */
343 
344     if (!encoder || nbytes < LC3_MIN_FRAME_BYTES
345                  || nbytes > LC3_MAX_FRAME_BYTES)
346         return -1;
347 
348     /* --- Processing --- */
349 
350     struct side_data side;
351     uint16_t xq[LC3_NE(encoder->dt, encoder->sr)];
352 
353     load[fmt](encoder, pcm, stride);
354 
355     analyze(encoder, nbytes, &side, xq);
356 
357     encode(encoder, &side, xq, nbytes, out);
358 
359     return 0;
360 }
361 
362 
363 /* ----------------------------------------------------------------------------
364  *  Decoder
365  * -------------------------------------------------------------------------- */
366 
367 /**
368  * Output PCM Samples to signed 16 bits
369  * decoder         Decoder state
370  * pcm, stride     Output PCM samples, and count between two consecutives
371  */
372 static void store_s16(
373     struct lc3_decoder *decoder, void *_pcm, int stride)
374 {
375     int16_t *pcm = _pcm;
376 
377     enum lc3_dt dt = decoder->dt;
378     enum lc3_srate sr = decoder->sr_pcm;
379 
380     float *xs = decoder->xs;
381     int ns = LC3_NS(dt, sr);
382 
383     for ( ; ns > 0; ns--, xs++, pcm += stride) {
384         int32_t s = *xs >= 0 ? (int)(*xs + 0.5f) : (int)(*xs - 0.5f);
385         *pcm = LC3_SAT16(s);
386     }
387 }
388 
389 /**
390  * Output PCM Samples to signed 24 bits
391  * decoder         Decoder state
392  * pcm, stride     Output PCM samples, and count between two consecutives
393  */
394 static void store_s24(
395     struct lc3_decoder *decoder, void *_pcm, int stride)
396 {
397     int32_t *pcm = _pcm;
398 
399     enum lc3_dt dt = decoder->dt;
400     enum lc3_srate sr = decoder->sr_pcm;
401 
402     float *xs = decoder->xs;
403     int ns = LC3_NS(dt, sr);
404 
405     for ( ; ns > 0; ns--, xs++, pcm += stride) {
406         int32_t s = *xs >= 0 ? (int32_t)(ldexpf(*xs, 8) + 0.5f)
407                              : (int32_t)(ldexpf(*xs, 8) - 0.5f);
408         *pcm = LC3_SAT24(s);
409     }
410 }
411 
/**
 * Decode bitstream
 * decoder         Decoder state
 * data, nbytes    Input bitstream buffer
 * side            Return the side data
 * return          0: Ok  < 0: Bitstream error detected
 *
 * Side information is read in the order written by `encode()`. The
 * decoded spectrum is stored in `decoder->xs`.
 */
static int decode(struct lc3_decoder *decoder,
    const void *data, int nbytes, struct side_data *side)
{
    enum lc3_dt dt = decoder->dt;
    enum lc3_srate sr = decoder->sr;

    float *xf = decoder->xs;    /* Receives the decoded spectrum */
    int ns = LC3_NS(dt, sr);
    int ne = LC3_NE(dt, sr);

    lc3_bits_t bits;
    int ret = 0;

    /* The setup API is shared with the writer and takes a non-const
     * buffer; in read mode `data` is presumably only read. */
    lc3_setup_bits(&bits, LC3_BITS_MODE_READ, (void *)data, nbytes);

    /* --- Side information --- */

    if ((ret = lc3_bwdet_get_bw(&bits, sr, &side->bw)) < 0)
        return ret;

    if ((ret = lc3_spec_get_side(&bits, dt, sr, &side->spec)) < 0)
        return ret;

    lc3_tns_get_data(&bits, dt, side->bw, nbytes, &side->tns);

    side->pitch_present = lc3_get_bit(&bits);

    if ((ret = lc3_sns_get_data(&bits, &side->sns)) < 0)
        return ret;

    /* LTPF data is present in the stream only when the pitch flag is set */
    if (side->pitch_present)
        lc3_ltpf_get_data(&bits, &side->ltpf);

    /* --- Spectral data --- */

    if ((ret = lc3_spec_decode(&bits, dt, sr,
                    side->bw, nbytes, &side->spec, xf)) < 0)
        return ret;

    /* Zero the coefficients from `ne` up to `ns` */
    memset(xf + ne, 0, (ns - ne) * sizeof(float));

    /* Presumably reports errors detected by the bit reader (e.g. overrun) */
    return lc3_check_bits(&bits);
}
458 
/**
 * Frame synthesis
 * decoder         Decoder state
 * side            Frame data, NULL performs PLC
 * nbytes          Size in bytes of the frame
 *
 * Valid frame: the spectrum in `decoder->xs` goes through TNS synthesis,
 * then SNS synthesis into `xg`, then the inverse MDCT from `xg`.
 * Missing or corrupted frame (`side == NULL`): packet loss concealment
 * builds a spectrum in `xs` from `xg`, followed by the inverse MDCT.
 * In both cases the time samples land in `decoder->xs`, and LTPF
 * synthesis is then applied to them.
 */
static void synthesize(struct lc3_decoder *decoder,
    const struct side_data *side, int nbytes)
{
    enum lc3_dt dt = decoder->dt;
    enum lc3_srate sr = decoder->sr;
    enum lc3_srate sr_pcm = decoder->sr_pcm;

    float *xf = decoder->xs;
    int ns = LC3_NS(dt, sr_pcm);    /* At the PCM samplerate */
    int ne = LC3_NE(dt, sr);

    float *xg = decoder->xg;
    float *xd = decoder->xd;
    float *xs = xf;                 /* Time samples overwrite the spectrum */

    if (side) {
        enum lc3_bandwidth bw = side->bw;

        lc3_plc_suspend(&decoder->plc);

        lc3_tns_synthesize(dt, bw, &side->tns, xf);

        /* SNS output goes to `xg`. It is the PLC input on a lost frame,
         * so presumably it serves as the last good spectrum. */
        lc3_sns_synthesize(dt, sr, &side->sns, xf, xg);

        lc3_mdct_inverse(dt, sr_pcm, sr, xg, xd, xs);

    } else {
        lc3_plc_synthesize(dt, sr, &decoder->plc, xg, xf);

        /* NOTE(review): `decode()` zeroes up to LC3_NS(dt, sr), while
         * this path zeroes up to LC3_NS(dt, sr_pcm); confirm how far
         * lc3_mdct_inverse() reads. */
        memset(xf + ne, 0, (ns - ne) * sizeof(float));

        lc3_mdct_inverse(dt, sr_pcm, sr, xf, xd, xs);
    }

    /* LTPF parameters are only passed when the frame carried pitch data */
    lc3_ltpf_synthesize(dt, sr_pcm, nbytes, &decoder->ltpf,
        side && side->pitch_present ? &side->ltpf : NULL, decoder->xh, xs);
}
502 
503 /**
504  * Update decoder state on decoding completion
505  * decoder         Decoder state
506  */
507 static void complete(struct lc3_decoder *decoder)
508 {
509     enum lc3_dt dt = decoder->dt;
510     enum lc3_srate sr_pcm = decoder->sr_pcm;
511     int nh = LC3_NH(dt, sr_pcm);
512     int ns = LC3_NS(dt, sr_pcm);
513 
514     decoder->xs = decoder->xs - decoder->xh < nh - ns ?
515         decoder->xs + ns : decoder->xh;
516 }
517 
518 /**
519  * Return size needed for a decoder
520  */
521 unsigned lc3_decoder_size(int dt_us, int sr_hz)
522 {
523     if (resolve_dt(dt_us) >= LC3_NUM_DT ||
524         resolve_sr(sr_hz) >= LC3_NUM_SRATE)
525         return 0;
526 
527     return sizeof(struct lc3_decoder) +
528         LC3_DECODER_BUFFER_COUNT(dt_us, sr_hz) * sizeof(float);
529 }
530 
531 /**
532  * Setup decoder
533  */
534 struct lc3_decoder *lc3_setup_decoder(
535     int dt_us, int sr_hz, int sr_pcm_hz, void *mem)
536 {
537     if (sr_pcm_hz <= 0)
538         sr_pcm_hz = sr_hz;
539 
540     enum lc3_dt dt = resolve_dt(dt_us);
541     enum lc3_srate sr = resolve_sr(sr_hz);
542     enum lc3_srate sr_pcm = resolve_sr(sr_pcm_hz);
543 
544     if (dt >= LC3_NUM_DT || sr_pcm >= LC3_NUM_SRATE || sr > sr_pcm || !mem)
545         return NULL;
546 
547     struct lc3_decoder *decoder = mem;
548     int nh = LC3_NH(dt, sr_pcm);
549     int ns = LC3_NS(dt, sr_pcm);
550     int nd = LC3_ND(dt, sr_pcm);
551 
552     *decoder = (struct lc3_decoder){
553         .dt = dt, .sr = sr,
554         .sr_pcm = sr_pcm,
555 
556         .xh = decoder->s,
557         .xs = decoder->s + nh-ns,
558         .xd = decoder->s + nh,
559         .xg = decoder->s + nh+nd,
560     };
561 
562     lc3_plc_reset(&decoder->plc);
563 
564     memset(decoder->s, 0,
565         LC3_DECODER_BUFFER_COUNT(dt_us, sr_pcm_hz) * sizeof(float));
566 
567     return decoder;
568 }
569 
570 /**
571  * Decode a frame
572  */
573 int lc3_decode(struct lc3_decoder *decoder, const void *in, int nbytes,
574     enum lc3_pcm_format fmt, void *pcm, int stride)
575 {
576     static void (* const store[])(struct lc3_decoder *, void *, int) = {
577         [LC3_PCM_FORMAT_S16] = store_s16,
578         [LC3_PCM_FORMAT_S24] = store_s24,
579     };
580 
581     /* --- Check parameters --- */
582 
583     if (!decoder)
584         return -1;
585 
586     if (in && (nbytes < LC3_MIN_FRAME_BYTES ||
587                nbytes > LC3_MAX_FRAME_BYTES   ))
588         return -1;
589 
590     /* --- Processing --- */
591 
592     struct side_data side;
593 
594     int ret = !in || (decode(decoder, in, nbytes, &side) < 0);
595 
596     synthesize(decoder, ret ? NULL : &side, nbytes);
597 
598     store[fmt](decoder, pcm, stride);
599 
600     complete(decoder);
601 
602     return ret;
603 }
604