xref: /aosp_15_r20/external/libopus/dnn/dred_encoder.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1 /* Copyright (c) 2022 Amazon
2    Written by Jan Buethe */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7 
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10 
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14 
15    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
19    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27 
28 #ifdef HAVE_CONFIG_H
29 #include "config.h"
30 #endif
31 
32 #include <string.h>
33 
34 #if 0
35 #include <stdio.h>
36 #include <math.h>
37 #endif
38 
39 #include "dred_encoder.h"
40 #include "dred_coding.h"
41 #include "celt/entenc.h"
42 
43 #include "dred_decoder.h"
44 #include "float_cast.h"
45 #include "os_support.h"
46 #include "celt/laplace.h"
47 #include "dred_rdovae_stats_data.h"
48 
49 
/* Bring the RDOVAE encoder state to its initial configuration by zeroing
   every byte of the structure. */
static void DRED_rdovae_init_encoder(RDOVAEEncState *enc_state)
{
    memset(enc_state, 0, sizeof(RDOVAEEncState));
}
54 
/* Load the DRED encoder weights from a serialized weight blob.
 *
 * enc:  encoder whose RDOVAE model and LPCNet encoder state get initialized.
 * data: start of the weight blob.
 * len:  size of the blob in bytes.
 *
 * Returns OPUS_OK and sets enc->loaded on success. Returns OPUS_BAD_ARG if the
 * blob cannot be parsed or either model fails to initialize; enc->loaded is
 * left unchanged in that case. */
int dred_encoder_load_model(DREDEnc* enc, const void *data, int len)
{
    WeightArray *list = NULL;
    int ret;
    /* Do not hand an unusable list to the model initializer when parsing
       fails. NOTE(review): assumes parse_weights() returns a negative value
       on a malformed blob and releases any partially built list itself --
       confirm against its implementation. */
    if (parse_weights(&list, data, len) < 0 || list == NULL) {
        return OPUS_BAD_ARG;
    }
    ret = init_rdovaeenc(&enc->model, list);
    opus_free(list);
    if (ret == 0) {
      /* The LPCNet encoder parses the same blob on its own. */
      ret = lpcnet_encoder_load_model(&enc->lpcnet_enc_state, data, len);
    }
    if (ret != 0) {
      return OPUS_BAD_ARG;
    }
    enc->loaded = 1;
    return OPUS_OK;
}
68 
/* Reset the running state of the encoder. Everything from the
   DREDENC_RESET_START member to the end of the structure is zeroed; members
   laid out before it are preserved. */
void dred_encoder_reset(DREDEnc* enc)
{
    char *reset_begin = (char*)&enc->DREDENC_RESET_START;

    /* Clear the tail of the structure, i.e. sizeof(DREDEnc) minus the bytes
       that precede DREDENC_RESET_START. */
    OPUS_CLEAR(reset_begin, sizeof(DREDEnc) - (reset_begin - (char*)enc));

    /* Re-initialize the sub-encoders after the bulk clear. */
    lpcnet_encoder_init(&enc->lpcnet_enc_state);
    DRED_rdovae_init_encoder(&enc->rdovae_enc);

    /* The input buffer starts out pre-filled by DRED_SILK_ENCODER_DELAY samples. */
    enc->input_buffer_fill = DRED_SILK_ENCODER_DELAY;
}
78 
/* Initialize a DRED encoder for the given input sampling rate and channel
   count, then reset its running state. Unless USE_WEIGHTS_FILE is defined,
   the model is initialized from the rdovaeenc_arrays table and enc->loaded
   reflects whether that succeeded; otherwise enc->loaded stays 0 here. */
void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels)
{
    enc->loaded = 0;
    enc->channels = channels;
    enc->Fs = Fs;
#ifndef USE_WEIGHTS_FILE
    if (init_rdovaeenc(&enc->model, rdovaeenc_arrays) == 0) {
        enc->loaded = 1;
    }
#endif
    dred_encoder_reset(enc);
}
89 
/* Process the first 2*DRED_FRAME_SIZE samples of enc->input_buffer: compute
   the LPCNet features of both halves, run the RDOVAE encoder on them and
   account for one more entry in the latents/state history. The model must be
   loaded. */
static void dred_process_frame(DREDEnc *enc, int arch)
{
    float lpcnet_features[2 * 36];
    float rdovae_input[2*DRED_NUM_FEATURES] = {0};
    int half;
    int filled;

    celt_assert(enc->loaded);

    /* Shift both history buffers by one entry, dropping the last one, so the
       first slot can receive the new frame. */
    OPUS_MOVE(&enc->latents_buffer[DRED_LATENT_DIM], &enc->latents_buffer[0], (DRED_MAX_FRAMES - 1) * DRED_LATENT_DIM);
    OPUS_MOVE(&enc->state_buffer[DRED_STATE_DIM], &enc->state_buffer[0], (DRED_MAX_FRAMES - 1) * DRED_STATE_DIM);

    /* One LPCNet feature vector (36 values each) per DRED_FRAME_SIZE samples. */
    for (half = 0; half < 2; half++) {
        lpcnet_compute_single_frame_features_float(&enc->lpcnet_enc_state, enc->input_buffer + half*DRED_FRAME_SIZE, lpcnet_features + half*36, arch);
    }

    /* Keep only the first DRED_NUM_FEATURES values of each vector (this
       discards the LPC coefficients). */
    for (half = 0; half < 2; half++) {
        OPUS_COPY(rdovae_input + half*DRED_NUM_FEATURES, lpcnet_features + half*36, DRED_NUM_FEATURES);
    }

    /* Run the RDOVAE encoder on the pair of feature vectors. */
    dred_rdovae_encode_dframe(&enc->rdovae_enc, &enc->model, enc->latents_buffer, enc->state_buffer, rdovae_input, arch);

    /* Count the new entry, saturating at DRED_NUM_REDUNDANCY_FRAMES. */
    filled = enc->latents_buffer_fill + 1;
    enc->latents_buffer_fill = IMIN(filled, DRED_NUM_REDUNDANCY_FRAMES);
}
112 
/* Filter `len` samples with an IIR filter in transposed direct-form II.
 *
 * in/out: input and output sample arrays; they may be the same array since
 *         each input sample is read before its output slot is written.
 * b0:     feed-forward coefficient applied to the current input sample.
 * b, a:   `order` feed-forward and feedback coefficients for the state update.
 * mem:    filter state carried across calls. mem[0..order-1] are updated;
 *         mem[order] is read but never written, so the array must hold
 *         order+1 values.
 */
void filter_df2t(const float *in, float *out, int len, float b0, const float *b, const float *a, int order, float *mem)
{
    int n;
    for (n = 0; n < len; n++) {
        int k;
        const float x = in[n];
        const float y = x*b0 + mem[0];
        const float neg_y = -y;
        /* Each state takes over its successor plus the contributions of the
           current input and (negated) output. */
        for (k = 0; k < order; k++) {
            mem[k] = mem[k+1] + b[k]*x + a[k]*neg_y;
        }
        out[n] = y;
    }
}
130 
131 #define MAX_DOWNMIX_BUFFER (960*2)
dred_convert_to_16k(DREDEnc * enc,const float * in,int in_len,float * out,int out_len)132 static void dred_convert_to_16k(DREDEnc *enc, const float *in, int in_len, float *out, int out_len)
133 {
134     float downmix[MAX_DOWNMIX_BUFFER];
135     int i;
136     int up;
137     celt_assert(enc->channels*in_len <= MAX_DOWNMIX_BUFFER);
138     celt_assert(in_len * (opus_int32)16000 == out_len * enc->Fs);
139     switch(enc->Fs) {
140         case 8000:
141             up = 2;
142             break;
143         case 12000:
144             up = 4;
145             break;
146         case 16000:
147             up = 1;
148             break;
149         case 24000:
150             up = 2;
151             break;
152         case 48000:
153             up = 1;
154             break;
155         default:
156             celt_assert(0);
157     }
158     OPUS_CLEAR(downmix, up*in_len);
159     if (enc->channels == 1) {
160         for (i=0;i<in_len;i++) downmix[up*i] = FLOAT2INT16(up*in[i]);
161     } else {
162         for (i=0;i<in_len;i++) downmix[up*i] = FLOAT2INT16(.5*up*(in[2*i]+in[2*i+1]));
163     }
164     if (enc->Fs == 16000) {
165         OPUS_COPY(out, downmix, out_len);
166     } else if (enc->Fs == 48000 || enc->Fs == 24000) {
167         /* ellip(7, .2, 70, 7750/24000) */
168 
169         static const float filter_b[8] = { 0.005873358047f,  0.012980854831f, 0.014531340042f,  0.014531340042f, 0.012980854831f,  0.005873358047f, 0.004523418224f, 0.f};
170         static const float filter_a[8] = {-3.878718597768f, 7.748834257468f, -9.653651699533f, 8.007342726666f, -4.379450178552f, 1.463182111810f, -0.231720677804f, 0.f};
171         float b0 = 0.004523418224f;
172         filter_df2t(downmix, downmix, up*in_len, b0, filter_b, filter_a, RESAMPLING_ORDER, enc->resample_mem);
173         for (i=0;i<out_len;i++) out[i] = downmix[3*i];
174     } else if (enc->Fs == 12000) {
175         /* ellip(7, .2, 70, 7750/24000) */
176         static const float filter_b[8] = {-0.001017101081f,  0.003673127243f,   0.001009165267f,  0.001009165267f,  0.003673127243f, -0.001017101081f,  0.002033596776f, 0.f};
177         static const float filter_a[8] = {-4.930414411612f, 11.291643096504f, -15.322037343815f, 13.216403930898f, -7.220409219553f,  2.310550142771f, -0.334338618782f, 0.f};
178         float b0 = 0.002033596776f;
179         filter_df2t(downmix, downmix, up*in_len, b0, filter_b, filter_a, RESAMPLING_ORDER, enc->resample_mem);
180         for (i=0;i<out_len;i++) out[i] = downmix[3*i];
181     } else if (enc->Fs == 8000) {
182         /* ellip(7, .2, 70, 3900/8000) */
183         static const float filter_b[8] = { 0.081670120929f, 0.180401598565f,  0.259391051971f, 0.259391051971f,  0.180401598565f, 0.081670120929f,  0.020109185709f, 0.f};
184         static const float filter_a[8] = {-1.393651933659f, 2.609789872676f, -2.403541968806f, 2.056814957331f, -1.148908574570f, 0.473001413788f, -0.110359852412f, 0.f};
185         float b0 = 0.020109185709f;
186         filter_df2t(downmix, out, out_len, b0, filter_b, filter_a, RESAMPLING_ORDER, enc->resample_mem);
187     } else {
188         celt_assert(0);
189     }
190 }
191 
/* Feed one frame of input audio to the DRED encoder and update the latents.
 *
 * enc:         DRED encoder; its model must be loaded.
 * pcm:         frame_size frames of input at enc->Fs, interleaved when the
 *              encoder has more than one channel.
 * frame_size:  number of frames (samples per channel) in pcm.
 * extra_delay: additional delay, in samples at enc->Fs, added to the offset.
 *
 * The input is converted to 16 kHz in pieces of at most 2*DRED_FRAME_SIZE
 * samples and appended to enc->input_buffer. Every time that buffer holds
 * 2*DRED_FRAME_SIZE samples one double frame is processed. On return
 * enc->dred_offset (in 2.5 ms units, i.e. 40 samples at 16 kHz) and
 * enc->latent_offset describe where the newest latents sit relative to this
 * frame. */
void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int extra_delay, int arch)
{
    int curr_offset16k;
    int frame_size16k = frame_size * 16000 / enc->Fs;
    celt_assert(enc->loaded);
    curr_offset16k = 40 + extra_delay*16000/enc->Fs - enc->input_buffer_fill;
    /* Round the offset to the nearest multiple of 40 samples (2.5 ms). */
    enc->dred_offset = (int)floor((curr_offset16k+20.f)/40.f);
    enc->latent_offset = 0;
    while (frame_size16k > 0) {
        int process_size16k;
        int process_size;
        process_size16k = IMIN(2*DRED_FRAME_SIZE, frame_size16k);
        process_size = process_size16k * enc->Fs / 16000;
        dred_convert_to_16k(enc, pcm, process_size, &enc->input_buffer[enc->input_buffer_fill], process_size16k);
        enc->input_buffer_fill += process_size16k;
        if (enc->input_buffer_fill >= 2*DRED_FRAME_SIZE)
        {
            dred_process_frame(enc, arch);
            /* Drop the consumed double frame and keep the remainder. */
            enc->input_buffer_fill -= 2*DRED_FRAME_SIZE;
            OPUS_MOVE(&enc->input_buffer[0], &enc->input_buffer[2*DRED_FRAME_SIZE], enc->input_buffer_fill);
            /* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. */
            if (enc->dred_offset < 6) {
                enc->dred_offset += 8;
            } else {
                enc->latent_offset++;
            }
        }

        /* pcm is interleaved and dred_convert_to_16k() consumes
           enc->channels samples per frame, so advance by whole frames.
           Advancing by process_size alone made multi-channel input overlap
           on every iteration after the first. */
        pcm += process_size * enc->channels;
        frame_size16k -= process_size16k;
    }
}
225 
/* Quantize `dim` values of x and entropy-code them with a Laplace model.
 *
 * scale, dzone, r, p0: per-dimension tables. scale and dzone are multiplied
 *     by 1/256 before use; r and p0 are shifted left by 7 bits and passed to
 *     ec_laplace_encode_p0().
 * dim: number of values; the scratch arrays hold at most
 *     IMAX(DRED_LATENT_DIM, DRED_STATE_DIM) entries.
 *
 * Dimensions with r == 0 or p0 == 255 are not written to the bitstream. */
static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint8 *scale, const opus_uint8 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim, int arch) {
    int k;
    int symbols[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
    float scaled[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
    float dz_width[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
    float dz_shape[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
    const float q8 = 1.f/256.f;
    const float eps = .1f;
    /* The work is deliberately spread over several simple loops with
       temporary arrays: the compiler can vectorize each of them and the
       tanh() can be computed by the vector activation routine. */
    for (k=0;k<dim;k++) {
        dz_width[k] = dzone[k]*q8;
        scaled[k] = x[k]*scale[k]*q8;
        dz_shape[k] = scaled[k]/(dz_width[k]+eps);
    }
    compute_activation(dz_shape, dz_shape, dim, ACTIVATION_TANH, arch);
    for (k=0;k<dim;k++) {
        /* Pull the value towards zero by the soft dead zone, then round to
           the nearest integer. */
        scaled[k] -= dz_width[k]*dz_shape[k];
        symbols[k] = (int)floor(.5f+scaled[k]);
    }
    for (k=0;k<dim;k++) {
        /* Make the impossible actually impossible: nothing is coded for
           these dimensions. */
        if (r[k] == 0 || p0[k] == 255) continue;
        ec_laplace_encode_p0(enc, symbols[k], p0[k]<<7, r[k]<<7);
    }
}
251 
/* Return 1 if any of the 16 activity flags starting at index 8*offset equals
   1, and 0 otherwise. */
static int dred_voice_active(const unsigned char *activity_mem, int offset) {
    const unsigned char *window = activity_mem + 8*offset;
    int k = 0;
    while (k < 16) {
        if (window[k] == 1) {
            return 1;
        }
        k++;
    }
    return 0;
}
259 
/* Entropy-code the DRED payload for the current frame into buf.
 *
 * enc:          DRED encoder holding the state/latents history.
 * buf:          output buffer of at least max_bytes bytes.
 * max_chunks:   upper bound on the number of latent pairs to code.
 * max_bytes:    byte budget for the payload.
 * q0, dQ, qmax: quantizer parameters; q0 and dQ are written to the payload,
 *               qmax only when q0 < 14 and dQ > 0.
 * activity_mem: voice-activity flags examined through dred_voice_active().
 *
 * Returns the number of bytes written to buf, or 0 when no DRED payload is
 * produced (nothing fits in the budget, or nothing active would be sent). */
int dred_encode_silk_frame(DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes, int q0, int dQ, int qmax, unsigned char *activity_mem, int arch) {
    ec_enc coder;
    ec_enc checkpoint;
    int first_latent;
    int skipped = 0;
    int is_delayed = 0;
    int was_active = 0;
    int coded_latents = 0;
    int total_offset;
    int state_tables;
    int max_latents;
    int pos;
    int nbytes;
    const int bit_budget = 8*max_bytes;

    first_latent = enc->latent_offset;
    /* Delaying new DRED data when just out of silence because we already have the
       main Opus payload for that frame. */
    if (activity_mem[0] && enc->last_extra_dred_offset>0) {
        first_latent = enc->last_extra_dred_offset;
        is_delayed = 1;
        enc->last_extra_dred_offset = 0;
    }
    /* Skip leading entries that show no voice activity. */
    for (; first_latent < enc->latents_buffer_fill && !dred_voice_active(activity_mem, first_latent); first_latent++) {
        skipped++;
    }
    if (!is_delayed) {
        enc->last_extra_dred_offset = skipped;
    }

    /* entropy coding of state and latents */
    ec_enc_init(&coder, buf, max_bytes);
    ec_enc_uint(&coder, q0, 16);
    ec_enc_uint(&coder, dQ, 8);
    total_offset = 16 - (enc->dred_offset - skipped*8);
    celt_assert(total_offset>=0);
    /* Offsets up to 31 use the short form; larger ones are flagged and split
       into a high part and the low five bits. */
    if (total_offset <= 31) {
        ec_enc_uint(&coder, 0, 2);
        ec_enc_uint(&coder, total_offset, 32);
    } else {
        ec_enc_uint(&coder, 1, 2);
        ec_enc_uint(&coder, total_offset>>5, 256);
        ec_enc_uint(&coder, total_offset&31, 32);
    }
    celt_assert(qmax >= q0);
    if (q0 < 14 && dQ > 0) {
        int nvals = 15 - (q0 + 1);
        int fl;
        int fh;
        /* If you want to use qmax == q0, you should have set dQ = 0. */
        celt_assert(qmax > q0);
        if (qmax >= 15) {
            fl = 0;
            fh = nvals;
        } else {
            fl = nvals + qmax - (q0 + 1);
            fh = nvals + qmax - q0;
        }
        ec_encode(&coder, fl, fh, 2*nvals);
    }

    /* Initial state, quantized with the tables for level q0. */
    state_tables = q0*DRED_STATE_DIM;
    dred_encode_latents(
        &coder,
        &enc->state_buffer[first_latent*DRED_STATE_DIM],
        dred_state_quant_scales_q8 + state_tables,
        dred_state_dead_zone_q8 + state_tables,
        dred_state_r_q8 + state_tables,
        dred_state_p0_q8 + state_tables,
        DRED_STATE_DIM,
        arch);
    if (ec_tell(&coder) > bit_budget) {
        return 0;
    }
    checkpoint = coder;

    /* Latents are coded at every other position, each with its own quantizer
       level. The coder state is checkpointed only after a position that is
       active or that directly follows an active one, so trailing inactive
       data is dropped when the checkpoint is restored. */
    max_latents = IMIN(2*max_chunks, enc->latents_buffer_fill-first_latent-1);
    for (pos = 0; pos < max_latents; pos += 2) {
        int active;
        int latent_tables = compute_quantizer(q0, dQ, qmax, pos/2) * DRED_LATENT_DIM;

        dred_encode_latents(
            &coder,
            enc->latents_buffer + (pos+first_latent) * DRED_LATENT_DIM,
            dred_latent_quant_scales_q8 + latent_tables,
            dred_latent_dead_zone_q8 + latent_tables,
            dred_latent_r_q8 + latent_tables,
            dred_latent_p0_q8 + latent_tables,
            DRED_LATENT_DIM,
            arch
        );
        if (ec_tell(&coder) > bit_budget) {
            /* If we haven't been able to code one chunk, give up on DRED completely. */
            if (pos == 0) {
                return 0;
            }
            break;
        }
        active = dred_voice_active(activity_mem, pos+first_latent);
        if (active || was_active) {
            checkpoint = coder;
            coded_latents = pos+2;
        }
        was_active = active;
    }
    /* Avoid sending empty DRED packets. */
    if (coded_latents==0 || (coded_latents<=2 && skipped)) {
        return 0;
    }
    coder = checkpoint;

    /* Round the used bits up to whole bytes and finalize the stream. */
    nbytes = (ec_tell(&coder)+7)/8;
    ec_enc_shrink(&coder, nbytes);
    ec_enc_done(&coder);
    return nbytes;
}
364