xref: /aosp_15_r20/external/libaom/av1/encoder/x86/ml_avx2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker  *
4*77c1e3ccSAndroid Build Coastguard Worker  * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker  * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker  * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker  */
11*77c1e3ccSAndroid Build Coastguard Worker 
12*77c1e3ccSAndroid Build Coastguard Worker #include <stdbool.h>
13*77c1e3ccSAndroid Build Coastguard Worker #include <assert.h>
14*77c1e3ccSAndroid Build Coastguard Worker #include <immintrin.h>
15*77c1e3ccSAndroid Build Coastguard Worker 
16*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h"
17*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/ml.h"
18*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/x86/ml_sse3.h"
19*77c1e3ccSAndroid Build Coastguard Worker 
// For the current 8-input chunk, multiplies the inputs by the weights of two
// consecutive output rows and combines the two products with one horizontal
// add.
// The macro reads these names from the enclosing scope:
//   i              - pair index; the rows at offsets (2 * i) and (2 * i + 1)
//                    (in units of tot_num_inputs) from weight_idx are used
//   weight_idx     - offset of the current 8-input chunk in the first row
//   tot_num_inputs - row stride of weights[]
//   weights        - row-major weight matrix
//   inputs256      - the 8 inputs of the current chunk
//   hadd           - __m256 array; hadd[i] receives the result
// The body is a compound statement so its temporaries stay local to each
// expansion: the macro can be expanded more than once in a scope and used as
// the sole body of a loop without leaking declarations.
#define CALC_OUTPUT_FOR_2ROWS                                      \
  {                                                                \
    const int row0_idx = weight_idx + (2 * i * tot_num_inputs);    \
    const __m256 weight0 = _mm256_loadu_ps(&weights[row0_idx]);    \
    const __m256 weight1 =                                         \
        _mm256_loadu_ps(&weights[row0_idx + tot_num_inputs]);      \
    const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);         \
    const __m256 mul1 = _mm256_mul_ps(inputs256, weight1);         \
    hadd[i] = _mm256_hadd_ps(mul0, mul1);                          \
  }
27*77c1e3ccSAndroid Build Coastguard Worker 
// Computes one output node per iteration as
//   output_nodes[row] = bias[row] + dot(inputs, weights row `row`)
// over the first num_inputs_to_process inputs, 8 lanes at a time.
//   inputs           : input vector; 8 floats are loaded per step, so
//                      num_inputs_to_process is expected to be a multiple of 8
//                      (the callers in this file only pass such counts).
//   weights          : row-major matrix, one row of tot_num_inputs floats per
//                      output node.
//   bias             : num_outputs values added to the dot products.
//   output_nodes     : receives num_outputs results.
//   is_clip_required : when non-zero, results are clamped at 0 (ReLU).
static inline void nn_propagate_8to1(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  for (int row = 0; row < num_outputs; ++row) {
    const float *const row_weights = weights + row * tot_num_inputs;

    // Eight independent partial sums, one per lane.
    __m256 acc = _mm256_setzero_ps();
    int col = 0;
    while (col < num_inputs_to_process) {
      const __m256 prod =
          _mm256_mul_ps(_mm256_loadu_ps(inputs + col),
                        _mm256_loadu_ps(row_weights + col));
      acc = _mm256_add_ps(acc, prod);
      col += 8;
    }

    // Fold 8 lanes -> 4 (add the two halves) -> pairwise sums (hadd), then
    // swap the two pair sums into lane 0 so it holds the full total.
    const __m128 quad = _mm_add_ps(_mm256_castps256_ps128(acc),
                                   _mm256_extractf128_ps(acc, 1));
    const __m128 pairs = _mm_hadd_ps(quad, quad);
    const __m128 total =
        _mm_add_ps(_mm_shuffle_ps(pairs, pairs, 0x99), pairs);

    float node = bias[row] + _mm_cvtss_f32(total);
    if (is_clip_required) node = AOMMAX(node, 0);
    output_nodes[row] = node;
  }
}
55*77c1e3ccSAndroid Build Coastguard Worker 
nn_propagate_8to4(const float * const inputs,const float * const weights,const float * const bias,int num_inputs_to_process,int tot_num_inputs,int num_outputs,float * const output_nodes,int is_clip_required)56*77c1e3ccSAndroid Build Coastguard Worker static inline void nn_propagate_8to4(
57*77c1e3ccSAndroid Build Coastguard Worker     const float *const inputs, const float *const weights,
58*77c1e3ccSAndroid Build Coastguard Worker     const float *const bias, int num_inputs_to_process, int tot_num_inputs,
59*77c1e3ccSAndroid Build Coastguard Worker     int num_outputs, float *const output_nodes, int is_clip_required) {
60*77c1e3ccSAndroid Build Coastguard Worker   __m256 hadd[2];
61*77c1e3ccSAndroid Build Coastguard Worker   for (int out = 0; out < num_outputs; out += 4) {
62*77c1e3ccSAndroid Build Coastguard Worker     __m128 bias_reg = _mm_loadu_ps(&bias[out]);
63*77c1e3ccSAndroid Build Coastguard Worker     __m128 in_result = _mm_setzero_ps();
64*77c1e3ccSAndroid Build Coastguard Worker     for (int in = 0; in < num_inputs_to_process; in += 8) {
65*77c1e3ccSAndroid Build Coastguard Worker       const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
66*77c1e3ccSAndroid Build Coastguard Worker       const int weight_idx = in + (out * tot_num_inputs);
67*77c1e3ccSAndroid Build Coastguard Worker       // Process two output row at a time.
68*77c1e3ccSAndroid Build Coastguard Worker       for (int i = 0; i < 2; i++) {
69*77c1e3ccSAndroid Build Coastguard Worker         CALC_OUTPUT_FOR_2ROWS
70*77c1e3ccSAndroid Build Coastguard Worker       }
71*77c1e3ccSAndroid Build Coastguard Worker 
72*77c1e3ccSAndroid Build Coastguard Worker       const __m256 sum_par = _mm256_hadd_ps(hadd[0], hadd[1]);
73*77c1e3ccSAndroid Build Coastguard Worker       const __m128 low_128 = _mm256_castps256_ps128(sum_par);
74*77c1e3ccSAndroid Build Coastguard Worker       const __m128 high_128 = _mm256_extractf128_ps(sum_par, 1);
75*77c1e3ccSAndroid Build Coastguard Worker       const __m128 result = _mm_add_ps(low_128, high_128);
76*77c1e3ccSAndroid Build Coastguard Worker 
77*77c1e3ccSAndroid Build Coastguard Worker       in_result = _mm_add_ps(in_result, result);
78*77c1e3ccSAndroid Build Coastguard Worker     }
79*77c1e3ccSAndroid Build Coastguard Worker 
80*77c1e3ccSAndroid Build Coastguard Worker     in_result = _mm_add_ps(in_result, bias_reg);
81*77c1e3ccSAndroid Build Coastguard Worker     if (is_clip_required) in_result = _mm_max_ps(in_result, _mm_setzero_ps());
82*77c1e3ccSAndroid Build Coastguard Worker     _mm_storeu_ps(&output_nodes[out], in_result);
83*77c1e3ccSAndroid Build Coastguard Worker   }
84*77c1e3ccSAndroid Build Coastguard Worker }
85*77c1e3ccSAndroid Build Coastguard Worker 
nn_propagate_8to8(const float * const inputs,const float * const weights,const float * const bias,int num_inputs_to_process,int tot_num_inputs,int num_outputs,float * const output_nodes,int is_clip_required)86*77c1e3ccSAndroid Build Coastguard Worker static inline void nn_propagate_8to8(
87*77c1e3ccSAndroid Build Coastguard Worker     const float *const inputs, const float *const weights,
88*77c1e3ccSAndroid Build Coastguard Worker     const float *const bias, int num_inputs_to_process, int tot_num_inputs,
89*77c1e3ccSAndroid Build Coastguard Worker     int num_outputs, float *const output_nodes, int is_clip_required) {
90*77c1e3ccSAndroid Build Coastguard Worker   __m256 hadd[4];
91*77c1e3ccSAndroid Build Coastguard Worker   for (int out = 0; out < num_outputs; out += 8) {
92*77c1e3ccSAndroid Build Coastguard Worker     __m256 bias_reg = _mm256_loadu_ps(&bias[out]);
93*77c1e3ccSAndroid Build Coastguard Worker     __m256 in_result = _mm256_setzero_ps();
94*77c1e3ccSAndroid Build Coastguard Worker     for (int in = 0; in < num_inputs_to_process; in += 8) {
95*77c1e3ccSAndroid Build Coastguard Worker       const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
96*77c1e3ccSAndroid Build Coastguard Worker       const int weight_idx = in + (out * tot_num_inputs);
97*77c1e3ccSAndroid Build Coastguard Worker       // Process two output rows at a time.
98*77c1e3ccSAndroid Build Coastguard Worker       for (int i = 0; i < 4; i++) {
99*77c1e3ccSAndroid Build Coastguard Worker         CALC_OUTPUT_FOR_2ROWS
100*77c1e3ccSAndroid Build Coastguard Worker       }
101*77c1e3ccSAndroid Build Coastguard Worker       const __m256 hh0 = _mm256_hadd_ps(hadd[0], hadd[1]);
102*77c1e3ccSAndroid Build Coastguard Worker       const __m256 hh1 = _mm256_hadd_ps(hadd[2], hadd[3]);
103*77c1e3ccSAndroid Build Coastguard Worker 
104*77c1e3ccSAndroid Build Coastguard Worker       __m256 ht_0 = _mm256_permute2f128_ps(hh0, hh1, 0x20);
105*77c1e3ccSAndroid Build Coastguard Worker       __m256 ht_1 = _mm256_permute2f128_ps(hh0, hh1, 0x31);
106*77c1e3ccSAndroid Build Coastguard Worker 
107*77c1e3ccSAndroid Build Coastguard Worker       __m256 result = _mm256_add_ps(ht_0, ht_1);
108*77c1e3ccSAndroid Build Coastguard Worker       in_result = _mm256_add_ps(in_result, result);
109*77c1e3ccSAndroid Build Coastguard Worker     }
110*77c1e3ccSAndroid Build Coastguard Worker     in_result = _mm256_add_ps(in_result, bias_reg);
111*77c1e3ccSAndroid Build Coastguard Worker     if (is_clip_required)
112*77c1e3ccSAndroid Build Coastguard Worker       in_result = _mm256_max_ps(in_result, _mm256_setzero_ps());
113*77c1e3ccSAndroid Build Coastguard Worker     _mm256_storeu_ps(&output_nodes[out], in_result);
114*77c1e3ccSAndroid Build Coastguard Worker   }
115*77c1e3ccSAndroid Build Coastguard Worker }
116*77c1e3ccSAndroid Build Coastguard Worker 
// Runs the AVX2 dot-product pass over the first num_inputs_to_process inputs
// (a multiple of 8) of a layer with tot_num_inputs inputs, choosing the
// widest kernel that num_outputs allows (8, then 4, then 1 output at a time).
//
// ReLU is applied here only for a hidden layer whose inputs are all covered
// by this pass. When num_inputs_to_process < tot_num_inputs, the results are
// partial sums: the caller still has to add the remaining inputs, and it
// applies the clamp after doing so.
static inline void nn_propagate_input_multiple_of_8(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    bool is_output_layer, int num_outputs, float *const output_nodes) {
  const bool covers_all_inputs = num_inputs_to_process == tot_num_inputs;
  const int is_clip_required = !is_output_layer && covers_all_inputs;

  if (num_outputs % 8 == 0) {
    nn_propagate_8to8(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
    return;
  }
  if (num_outputs % 4 == 0) {
    nn_propagate_8to4(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
    return;
  }
  nn_propagate_8to1(inputs, weights, bias, num_inputs_to_process,
                    tot_num_inputs, num_outputs, output_nodes,
                    is_clip_required);
}
139*77c1e3ccSAndroid Build Coastguard Worker 
av1_nn_predict_avx2(const float * input_nodes,const NN_CONFIG * const nn_config,int reduce_prec,float * const output)140*77c1e3ccSAndroid Build Coastguard Worker void av1_nn_predict_avx2(const float *input_nodes,
141*77c1e3ccSAndroid Build Coastguard Worker                          const NN_CONFIG *const nn_config, int reduce_prec,
142*77c1e3ccSAndroid Build Coastguard Worker                          float *const output) {
143*77c1e3ccSAndroid Build Coastguard Worker   float buf[2][NN_MAX_NODES_PER_LAYER];
144*77c1e3ccSAndroid Build Coastguard Worker   int buf_index = 0;
145*77c1e3ccSAndroid Build Coastguard Worker   int num_inputs = nn_config->num_inputs;
146*77c1e3ccSAndroid Build Coastguard Worker   assert(num_inputs > 0 && num_inputs <= NN_MAX_NODES_PER_LAYER);
147*77c1e3ccSAndroid Build Coastguard Worker 
148*77c1e3ccSAndroid Build Coastguard Worker   for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
149*77c1e3ccSAndroid Build Coastguard Worker     const float *layer_weights = nn_config->weights[layer];
150*77c1e3ccSAndroid Build Coastguard Worker     const float *layer_bias = nn_config->bias[layer];
151*77c1e3ccSAndroid Build Coastguard Worker     bool is_output_layer = layer == nn_config->num_hidden_layers;
152*77c1e3ccSAndroid Build Coastguard Worker     float *const output_nodes = is_output_layer ? output : &buf[buf_index][0];
153*77c1e3ccSAndroid Build Coastguard Worker     const int num_outputs = is_output_layer
154*77c1e3ccSAndroid Build Coastguard Worker                                 ? nn_config->num_outputs
155*77c1e3ccSAndroid Build Coastguard Worker                                 : nn_config->num_hidden_nodes[layer];
156*77c1e3ccSAndroid Build Coastguard Worker     assert(num_outputs > 0 && num_outputs <= NN_MAX_NODES_PER_LAYER);
157*77c1e3ccSAndroid Build Coastguard Worker 
158*77c1e3ccSAndroid Build Coastguard Worker     // Process input multiple of 8 using AVX2 intrinsic.
159*77c1e3ccSAndroid Build Coastguard Worker     if (num_inputs % 8 == 0) {
160*77c1e3ccSAndroid Build Coastguard Worker       nn_propagate_input_multiple_of_8(input_nodes, layer_weights, layer_bias,
161*77c1e3ccSAndroid Build Coastguard Worker                                        num_inputs, num_inputs, is_output_layer,
162*77c1e3ccSAndroid Build Coastguard Worker                                        num_outputs, output_nodes);
163*77c1e3ccSAndroid Build Coastguard Worker     } else {
164*77c1e3ccSAndroid Build Coastguard Worker       // When number of inputs is not multiple of 8, use hybrid approach of AVX2
165*77c1e3ccSAndroid Build Coastguard Worker       // and SSE3 based on the need.
166*77c1e3ccSAndroid Build Coastguard Worker       const int in_mul_8 = num_inputs / 8;
167*77c1e3ccSAndroid Build Coastguard Worker       const int num_inputs_to_process = in_mul_8 * 8;
168*77c1e3ccSAndroid Build Coastguard Worker       int bias_is_considered = 0;
169*77c1e3ccSAndroid Build Coastguard Worker       if (in_mul_8) {
170*77c1e3ccSAndroid Build Coastguard Worker         nn_propagate_input_multiple_of_8(
171*77c1e3ccSAndroid Build Coastguard Worker             input_nodes, layer_weights, layer_bias, num_inputs_to_process,
172*77c1e3ccSAndroid Build Coastguard Worker             num_inputs, is_output_layer, num_outputs, output_nodes);
173*77c1e3ccSAndroid Build Coastguard Worker         bias_is_considered = 1;
174*77c1e3ccSAndroid Build Coastguard Worker       }
175*77c1e3ccSAndroid Build Coastguard Worker 
176*77c1e3ccSAndroid Build Coastguard Worker       const float *out_temp = bias_is_considered ? output_nodes : layer_bias;
177*77c1e3ccSAndroid Build Coastguard Worker       const int input_remaining = num_inputs % 8;
178*77c1e3ccSAndroid Build Coastguard Worker       if (input_remaining % 4 == 0 && num_outputs % 8 == 0) {
179*77c1e3ccSAndroid Build Coastguard Worker         for (int out = 0; out < num_outputs; out += 8) {
180*77c1e3ccSAndroid Build Coastguard Worker           __m128 out_h = _mm_loadu_ps(&out_temp[out + 4]);
181*77c1e3ccSAndroid Build Coastguard Worker           __m128 out_l = _mm_loadu_ps(&out_temp[out]);
182*77c1e3ccSAndroid Build Coastguard Worker           for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
183*77c1e3ccSAndroid Build Coastguard Worker             av1_nn_propagate_4to8_sse3(&input_nodes[in],
184*77c1e3ccSAndroid Build Coastguard Worker                                        &layer_weights[out * num_inputs + in],
185*77c1e3ccSAndroid Build Coastguard Worker                                        &out_h, &out_l, num_inputs);
186*77c1e3ccSAndroid Build Coastguard Worker           }
187*77c1e3ccSAndroid Build Coastguard Worker           if (!is_output_layer) {
188*77c1e3ccSAndroid Build Coastguard Worker             const __m128 zero = _mm_setzero_ps();
189*77c1e3ccSAndroid Build Coastguard Worker             out_h = _mm_max_ps(out_h, zero);
190*77c1e3ccSAndroid Build Coastguard Worker             out_l = _mm_max_ps(out_l, zero);
191*77c1e3ccSAndroid Build Coastguard Worker           }
192*77c1e3ccSAndroid Build Coastguard Worker           _mm_storeu_ps(&output_nodes[out + 4], out_h);
193*77c1e3ccSAndroid Build Coastguard Worker           _mm_storeu_ps(&output_nodes[out], out_l);
194*77c1e3ccSAndroid Build Coastguard Worker         }
195*77c1e3ccSAndroid Build Coastguard Worker       } else if (input_remaining % 4 == 0 && num_outputs % 4 == 0) {
196*77c1e3ccSAndroid Build Coastguard Worker         for (int out = 0; out < num_outputs; out += 4) {
197*77c1e3ccSAndroid Build Coastguard Worker           __m128 outputs = _mm_loadu_ps(&out_temp[out]);
198*77c1e3ccSAndroid Build Coastguard Worker           for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
199*77c1e3ccSAndroid Build Coastguard Worker             av1_nn_propagate_4to4_sse3(&input_nodes[in],
200*77c1e3ccSAndroid Build Coastguard Worker                                        &layer_weights[out * num_inputs + in],
201*77c1e3ccSAndroid Build Coastguard Worker                                        &outputs, num_inputs);
202*77c1e3ccSAndroid Build Coastguard Worker           }
203*77c1e3ccSAndroid Build Coastguard Worker           if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
204*77c1e3ccSAndroid Build Coastguard Worker           _mm_storeu_ps(&output_nodes[out], outputs);
205*77c1e3ccSAndroid Build Coastguard Worker         }
206*77c1e3ccSAndroid Build Coastguard Worker       } else if (input_remaining % 4 == 0) {
207*77c1e3ccSAndroid Build Coastguard Worker         for (int out = 0; out < num_outputs; out++) {
208*77c1e3ccSAndroid Build Coastguard Worker           __m128 outputs = _mm_load1_ps(&out_temp[out]);
209*77c1e3ccSAndroid Build Coastguard Worker           for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
210*77c1e3ccSAndroid Build Coastguard Worker             av1_nn_propagate_4to1_sse3(&input_nodes[in],
211*77c1e3ccSAndroid Build Coastguard Worker                                        &layer_weights[out * num_inputs + in],
212*77c1e3ccSAndroid Build Coastguard Worker                                        &outputs);
213*77c1e3ccSAndroid Build Coastguard Worker           }
214*77c1e3ccSAndroid Build Coastguard Worker           if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
215*77c1e3ccSAndroid Build Coastguard Worker           output_nodes[out] = _mm_cvtss_f32(outputs);
216*77c1e3ccSAndroid Build Coastguard Worker         }
217*77c1e3ccSAndroid Build Coastguard Worker       } else {
218*77c1e3ccSAndroid Build Coastguard Worker         // Use SSE instructions for scalar operations to avoid the latency
219*77c1e3ccSAndroid Build Coastguard Worker         // of swapping between SIMD and FPU modes.
220*77c1e3ccSAndroid Build Coastguard Worker         for (int out = 0; out < num_outputs; out++) {
221*77c1e3ccSAndroid Build Coastguard Worker           __m128 outputs = _mm_load1_ps(&out_temp[out]);
222*77c1e3ccSAndroid Build Coastguard Worker           for (int in_node = in_mul_8 * 8; in_node < num_inputs; in_node++) {
223*77c1e3ccSAndroid Build Coastguard Worker             __m128 input = _mm_load1_ps(&input_nodes[in_node]);
224*77c1e3ccSAndroid Build Coastguard Worker             __m128 weight =
225*77c1e3ccSAndroid Build Coastguard Worker                 _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
226*77c1e3ccSAndroid Build Coastguard Worker             outputs = _mm_add_ps(outputs, _mm_mul_ps(input, weight));
227*77c1e3ccSAndroid Build Coastguard Worker           }
228*77c1e3ccSAndroid Build Coastguard Worker           if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
229*77c1e3ccSAndroid Build Coastguard Worker           output_nodes[out] = _mm_cvtss_f32(outputs);
230*77c1e3ccSAndroid Build Coastguard Worker         }
231*77c1e3ccSAndroid Build Coastguard Worker       }
232*77c1e3ccSAndroid Build Coastguard Worker     }
233*77c1e3ccSAndroid Build Coastguard Worker     // Before processing the next layer, treat the output of current layer as
234*77c1e3ccSAndroid Build Coastguard Worker     // input to next layer.
235*77c1e3ccSAndroid Build Coastguard Worker     input_nodes = output_nodes;
236*77c1e3ccSAndroid Build Coastguard Worker     num_inputs = num_outputs;
237*77c1e3ccSAndroid Build Coastguard Worker     buf_index = 1 - buf_index;
238*77c1e3ccSAndroid Build Coastguard Worker   }
239*77c1e3ccSAndroid Build Coastguard Worker   if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
240*77c1e3ccSAndroid Build Coastguard Worker }
241