xref: /aosp_15_r20/external/libopus/dnn/nndsp.h (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li /* Copyright (c) 2023 Amazon
2*a58d3d2aSXin Li    Written by Jan Buethe */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li    Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li    modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li    are met:
7*a58d3d2aSXin Li 
8*a58d3d2aSXin Li    - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li 
11*a58d3d2aSXin Li    - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li    documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li 
15*a58d3d2aSXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
19*a58d3d2aSXin Li    OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li 
28*a58d3d2aSXin Li #ifndef NNDSP_H
29*a58d3d2aSXin Li #define NNDSP_H
30*a58d3d2aSXin Li 
31*a58d3d2aSXin Li #include "opus_types.h"
32*a58d3d2aSXin Li #include "nnet.h"
33*a58d3d2aSXin Li #include <string.h>
34*a58d3d2aSXin Li 
35*a58d3d2aSXin Li 
/* Compile-time capacity limits for the state structs declared below.
 * Each one bounds the size of a fixed array member, so the run-time
 * size arguments passed to the *_process_frame() functions presumably
 * must not exceed them (NOTE(review): confirm against nndsp.c). */

/* AdaConv: adaptive multi-channel convolution. */
#define ADACONV_MAX_KERNEL_SIZE     16
#define ADACONV_MAX_INPUT_CHANNELS  2
#define ADACONV_MAX_OUTPUT_CHANNELS 2
#define ADACONV_MAX_FRAME_SIZE      80
#define ADACONV_MAX_OVERLAP_SIZE    40

/* AdaComb: adaptive comb filter driven by a pitch lag. */
#define ADACOMB_MAX_LAG             300
#define ADACOMB_MAX_KERNEL_SIZE     16
#define ADACOMB_MAX_FRAME_SIZE      80
#define ADACOMB_MAX_OVERLAP_SIZE    40

/* AdaShape: adaptive shaping. */
#define ADASHAPE_MAX_INPUT_DIM      512
#define ADASHAPE_MAX_FRAME_SIZE     160
49*a58d3d2aSXin Li 
50*a58d3d2aSXin Li /*#define DEBUG_NNDSP*/
51*a58d3d2aSXin Li #ifdef DEBUG_NNDSP
52*a58d3d2aSXin Li #include <stdio.h>
53*a58d3d2aSXin Li #endif
54*a58d3d2aSXin Li 
55*a58d3d2aSXin Li 
/* Debugging aid: presumably prints the first `length` elements of `vec`,
 * labelled with `name` (NOTE(review): behaviour is defined in nndsp.c --
 * confirm there). */
void print_float_vector(const char *name, const float *vec, int length);
57*a58d3d2aSXin Li 
58*a58d3d2aSXin Li typedef struct {
59*a58d3d2aSXin Li     float history[ADACONV_MAX_KERNEL_SIZE * ADACONV_MAX_INPUT_CHANNELS];
60*a58d3d2aSXin Li     float last_kernel[ADACONV_MAX_KERNEL_SIZE * ADACONV_MAX_INPUT_CHANNELS * ADACONV_MAX_OUTPUT_CHANNELS];
61*a58d3d2aSXin Li     float last_gain;
62*a58d3d2aSXin Li } AdaConvState;
63*a58d3d2aSXin Li 
64*a58d3d2aSXin Li 
65*a58d3d2aSXin Li typedef struct {
66*a58d3d2aSXin Li     float history[ADACOMB_MAX_KERNEL_SIZE + ADACOMB_MAX_LAG];
67*a58d3d2aSXin Li     float last_kernel[ADACOMB_MAX_KERNEL_SIZE];
68*a58d3d2aSXin Li     float last_global_gain;
69*a58d3d2aSXin Li     int last_pitch_lag;
70*a58d3d2aSXin Li } AdaCombState;
71*a58d3d2aSXin Li 
72*a58d3d2aSXin Li 
73*a58d3d2aSXin Li typedef struct {
74*a58d3d2aSXin Li     float conv_alpha1f_state[ADASHAPE_MAX_INPUT_DIM];
75*a58d3d2aSXin Li     float conv_alpha1t_state[ADASHAPE_MAX_INPUT_DIM];
76*a58d3d2aSXin Li     float conv_alpha2_state[ADASHAPE_MAX_FRAME_SIZE];
77*a58d3d2aSXin Li } AdaShapeState;
78*a58d3d2aSXin Li 
79*a58d3d2aSXin Li void init_adaconv_state(AdaConvState *hAdaConv);
80*a58d3d2aSXin Li 
81*a58d3d2aSXin Li void init_adacomb_state(AdaCombState *hAdaComb);
82*a58d3d2aSXin Li 
83*a58d3d2aSXin Li void init_adashape_state(AdaShapeState *hAdaShape);
84*a58d3d2aSXin Li 
/* Builds the cross-fade window that is later passed as `window` to
 * adaconv_process_frame() / adacomb_process_frame(); presumably writes
 * overlap_size coefficients into `window` (confirm in nndsp.c). */
void compute_overlap_window(float *window, int overlap_size);
86*a58d3d2aSXin Li 
87*a58d3d2aSXin Li void adaconv_process_frame(
88*a58d3d2aSXin Li     AdaConvState* hAdaConv,
89*a58d3d2aSXin Li     float *x_out,
90*a58d3d2aSXin Li     const float *x_in,
91*a58d3d2aSXin Li     const float *features,
92*a58d3d2aSXin Li     const LinearLayer *kernel_layer,
93*a58d3d2aSXin Li     const LinearLayer *gain_layer,
94*a58d3d2aSXin Li     int feature_dim, /* not strictly necessary */
95*a58d3d2aSXin Li     int frame_size,
96*a58d3d2aSXin Li     int overlap_size,
97*a58d3d2aSXin Li     int in_channels,
98*a58d3d2aSXin Li     int out_channels,
99*a58d3d2aSXin Li     int kernel_size,
100*a58d3d2aSXin Li     int left_padding,
101*a58d3d2aSXin Li     float filter_gain_a,
102*a58d3d2aSXin Li     float filter_gain_b,
103*a58d3d2aSXin Li     float shape_gain,
104*a58d3d2aSXin Li     float *window,
105*a58d3d2aSXin Li     int arch
106*a58d3d2aSXin Li );
107*a58d3d2aSXin Li 
108*a58d3d2aSXin Li void adacomb_process_frame(
109*a58d3d2aSXin Li     AdaCombState* hAdaComb,
110*a58d3d2aSXin Li     float *x_out,
111*a58d3d2aSXin Li     const float *x_in,
112*a58d3d2aSXin Li     const float *features,
113*a58d3d2aSXin Li     const LinearLayer *kernel_layer,
114*a58d3d2aSXin Li     const LinearLayer *gain_layer,
115*a58d3d2aSXin Li     const LinearLayer *global_gain_layer,
116*a58d3d2aSXin Li     int pitch_lag,
117*a58d3d2aSXin Li     int feature_dim,
118*a58d3d2aSXin Li     int frame_size,
119*a58d3d2aSXin Li     int overlap_size,
120*a58d3d2aSXin Li     int kernel_size,
121*a58d3d2aSXin Li     int left_padding,
122*a58d3d2aSXin Li     float filter_gain_a,
123*a58d3d2aSXin Li     float filter_gain_b,
124*a58d3d2aSXin Li     float log_gain_limit,
125*a58d3d2aSXin Li     float *window,
126*a58d3d2aSXin Li     int arch
127*a58d3d2aSXin Li );
128*a58d3d2aSXin Li 
129*a58d3d2aSXin Li void adashape_process_frame(
130*a58d3d2aSXin Li     AdaShapeState *hAdaShape,
131*a58d3d2aSXin Li     float *x_out,
132*a58d3d2aSXin Li     const float *x_in,
133*a58d3d2aSXin Li     const float *features,
134*a58d3d2aSXin Li     const LinearLayer *alpha1f,
135*a58d3d2aSXin Li     const LinearLayer *alpha1t,
136*a58d3d2aSXin Li     const LinearLayer *alpha2,
137*a58d3d2aSXin Li     int feature_dim,
138*a58d3d2aSXin Li     int frame_size,
139*a58d3d2aSXin Li     int avg_pool_k,
140*a58d3d2aSXin Li     int arch
141*a58d3d2aSXin Li );
142*a58d3d2aSXin Li 
143*a58d3d2aSXin Li #endif
144