1*a58d3d2aSXin Li /* Copyright (c) 2022 Amazon
2*a58d3d2aSXin Li Written by Jan Buethe */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li are met:
7*a58d3d2aSXin Li
8*a58d3d2aSXin Li - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li
11*a58d3d2aSXin Li - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li
15*a58d3d2aSXin Li THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
19*a58d3d2aSXin Li OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li
28*a58d3d2aSXin Li #include <math.h>
29*a58d3d2aSXin Li
30*a58d3d2aSXin Li #ifdef HAVE_CONFIG_H
31*a58d3d2aSXin Li #include "config.h"
32*a58d3d2aSXin Li #endif
33*a58d3d2aSXin Li
34*a58d3d2aSXin Li
35*a58d3d2aSXin Li #include "dred_rdovae_enc.h"
36*a58d3d2aSXin Li #include "os_support.h"
37*a58d3d2aSXin Li #include "dred_rdovae_constants.h"
38*a58d3d2aSXin Li
/* Zero a convolution history buffer the first time it is needed.
 *
 *   mem      : history buffer; its first dilation*len floats are cleared
 *   len      : number of floats in one history slot
 *   dilation : number of consecutive slots to clear
 *   init     : in/out flag; the buffer is cleared only if *init is 0 on
 *              entry, and *init is always 1 on return
 */
static void conv1_cond_init(float *mem, int len, int dilation, int *init)
{
    if (*init == 0 && dilation > 0) {
        /* The slots sit back to back, so a single clear covers all of them. */
        const int total = dilation*len;
        OPUS_CLEAR(mem, total);
    }
    *init = 1;
}
47*a58d3d2aSXin Li
/* Encode one double feature frame with the RDOVAE encoder.
 *
 * The input goes through a dense layer followed by five GRU + conv1d stages.
 * Every layer's output is appended to a shared buffer; each convolution is
 * given the start of that buffer together with the number of floats written
 * so far as its input size.  The completed buffer then feeds the latent
 * projection (enc_zdense) and the two-layer state projection (gdense1,
 * gdense2).  Only the first DRED_LATENT_DIM / DRED_STATE_DIM values of the
 * padded projections are copied to the caller.
 *
 * When enc_state->initialized is 0 on entry, the history of all five
 * convolutions is cleared before use and the flag is set to 1.  The flag is
 * sampled once on entry: conv1_cond_init() sets the flag it is handed, so
 * passing &enc_state->initialized to every call would let the first call
 * suppress the clearing of conv2..conv5.
 */
void dred_rdovae_encode_dframe(
    RDOVAEEncState *enc_state,      /* io: encoder state */
    const RDOVAEEnc *model,         /* i: encoder model (layer weights) */
    float *latents,                 /* o: latent vector */
    float *initial_state,           /* o: initial state */
    const float *input,             /* i: double feature frame (concatenated) */
    int arch                        /* i: forwarded unchanged to the compute_* routines */
    )
{
    float padded_latents[DRED_PADDED_LATENT_DIM];
    float padded_state[DRED_PADDED_STATE_DIM];
    float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE
               + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE];
    float state_hidden[GDENSE1_OUT_SIZE];
    int output_index = 0;
    /* Value of the flag before any conv1_cond_init() call can change it. */
    const int was_initialized = enc_state->initialized;
    int conv_init;

    /* run encoder stack and concatenate output in buffer*/
    compute_generic_dense(&model->enc_dense1, &buffer[output_index], input, ACTIVATION_TANH, arch);
    output_index += ENC_DENSE1_OUT_SIZE;

    /* Stage 1: GRU state is appended to the buffer, then a dilation-1 conv. */
    compute_generic_gru(&model->enc_gru1_input, &model->enc_gru1_recurrent, enc_state->gru1_state, buffer, arch);
    OPUS_COPY(&buffer[output_index], enc_state->gru1_state, ENC_GRU1_OUT_SIZE);
    output_index += ENC_GRU1_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(enc_state->conv1_state, output_index, 1, &conv_init);
    compute_generic_conv1d(&model->enc_conv1, &buffer[output_index], enc_state->conv1_state, buffer, output_index, ACTIVATION_TANH, arch);
    output_index += ENC_CONV1_OUT_SIZE;

    /* Stages 2-5: same pattern with dilation-2 convolutions.
       NOTE(review): the clears below touch 2*output_index floats of each
       convN_state, exactly what the original call arguments request --
       confirm the arrays in RDOVAEEncState are at least that large. */
    compute_generic_gru(&model->enc_gru2_input, &model->enc_gru2_recurrent, enc_state->gru2_state, buffer, arch);
    OPUS_COPY(&buffer[output_index], enc_state->gru2_state, ENC_GRU2_OUT_SIZE);
    output_index += ENC_GRU2_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(enc_state->conv2_state, output_index, 2, &conv_init);
    compute_generic_conv1d_dilation(&model->enc_conv2, &buffer[output_index], enc_state->conv2_state, buffer, output_index, 2, ACTIVATION_TANH, arch);
    output_index += ENC_CONV2_OUT_SIZE;

    compute_generic_gru(&model->enc_gru3_input, &model->enc_gru3_recurrent, enc_state->gru3_state, buffer, arch);
    OPUS_COPY(&buffer[output_index], enc_state->gru3_state, ENC_GRU3_OUT_SIZE);
    output_index += ENC_GRU3_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(enc_state->conv3_state, output_index, 2, &conv_init);
    compute_generic_conv1d_dilation(&model->enc_conv3, &buffer[output_index], enc_state->conv3_state, buffer, output_index, 2, ACTIVATION_TANH, arch);
    output_index += ENC_CONV3_OUT_SIZE;

    compute_generic_gru(&model->enc_gru4_input, &model->enc_gru4_recurrent, enc_state->gru4_state, buffer, arch);
    OPUS_COPY(&buffer[output_index], enc_state->gru4_state, ENC_GRU4_OUT_SIZE);
    output_index += ENC_GRU4_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(enc_state->conv4_state, output_index, 2, &conv_init);
    compute_generic_conv1d_dilation(&model->enc_conv4, &buffer[output_index], enc_state->conv4_state, buffer, output_index, 2, ACTIVATION_TANH, arch);
    output_index += ENC_CONV4_OUT_SIZE;

    compute_generic_gru(&model->enc_gru5_input, &model->enc_gru5_recurrent, enc_state->gru5_state, buffer, arch);
    OPUS_COPY(&buffer[output_index], enc_state->gru5_state, ENC_GRU5_OUT_SIZE);
    output_index += ENC_GRU5_OUT_SIZE;
    conv_init = was_initialized;
    conv1_cond_init(enc_state->conv5_state, output_index, 2, &conv_init);
    compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH, arch);
    output_index += ENC_CONV5_OUT_SIZE;

    /* Every convolution history has now been set up; record it once. */
    enc_state->initialized = 1;

    /* project the concatenated outputs onto the (padded) latent vector */
    compute_generic_dense(&model->enc_zdense, padded_latents, buffer, ACTIVATION_LINEAR, arch);
    OPUS_COPY(latents, padded_latents, DRED_LATENT_DIM);

    /* next, calculate initial state */
    compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH, arch);
    compute_generic_dense(&model->gdense2, padded_state, state_hidden, ACTIVATION_LINEAR, arch);
    OPUS_COPY(initial_state, padded_state, DRED_STATE_DIM);
}
111