xref: /aosp_15_r20/external/libaom/av1/encoder/x86/cnn_avx2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker  *
4*77c1e3ccSAndroid Build Coastguard Worker  * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker  * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker  * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker  */
11*77c1e3ccSAndroid Build Coastguard Worker 
12*77c1e3ccSAndroid Build Coastguard Worker #include <assert.h>
13*77c1e3ccSAndroid Build Coastguard Worker #include <immintrin.h>
14*77c1e3ccSAndroid Build Coastguard Worker #include <math.h>
15*77c1e3ccSAndroid Build Coastguard Worker 
16*77c1e3ccSAndroid Build Coastguard Worker #include "aom_dsp/aom_dsp_common.h"
17*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/av1_common_int.h"
18*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/cnn.h"
19*77c1e3ccSAndroid Build Coastguard Worker 
20*77c1e3ccSAndroid Build Coastguard Worker // This mask rearranges source pixels in the order shown below.
21*77c1e3ccSAndroid Build Coastguard Worker // shuffle_src_layer0[0][8]: applied on source pixels 0 to 7.
22*77c1e3ccSAndroid Build Coastguard Worker // shuffle_src_layer0[1][8]: applied on source pixels 7 to 14.
23*77c1e3ccSAndroid Build Coastguard Worker // This shuffling is needed to process 3 5x5 blocks which need
24*77c1e3ccSAndroid Build Coastguard Worker // source pixels in the following order.
25*77c1e3ccSAndroid Build Coastguard Worker // 1st 5x5 block: source pixels needed are 0 to 4,
26*77c1e3ccSAndroid Build Coastguard Worker // 2nd 5x5 block: source pixels needed are 4 to 8,
27*77c1e3ccSAndroid Build Coastguard Worker // 3rd 5x5 block: source pixels needed are 8 to 12.
28*77c1e3ccSAndroid Build Coastguard Worker // Source pixels are loaded like mentioned below.
29*77c1e3ccSAndroid Build Coastguard Worker // load_src0 : 0, 1, 2, 3, 4, 5, 6, 7
30*77c1e3ccSAndroid Build Coastguard Worker // load_src1 : 7, 8, 9, 10, 11, 12, 13, 14
31*77c1e3ccSAndroid Build Coastguard Worker // After applying masks, source bytes will be in the order:
32*77c1e3ccSAndroid Build Coastguard Worker // load_src0 : 0, 1, 2, 3, 4, 4, 5, 6
33*77c1e3ccSAndroid Build Coastguard Worker //             consists 5 pixels needed for 1st 5x5 block and
34*77c1e3ccSAndroid Build Coastguard Worker //             first 3 pixels needed for 2nd 5x5 block.
35*77c1e3ccSAndroid Build Coastguard Worker // load_src1 : 7, 8, 8, 9, 10, 11, 12, x
36*77c1e3ccSAndroid Build Coastguard Worker //             consists last 2 pixels needed for 2nd 5x5 block and
37*77c1e3ccSAndroid Build Coastguard Worker //             5 pixels needed for 3rd 5x5 block.
38*77c1e3ccSAndroid Build Coastguard Worker DECLARE_ALIGNED(32, static const uint32_t,
39*77c1e3ccSAndroid Build Coastguard Worker                 shuffle_src_layer0[2][8]) = { { 0, 1, 2, 3, 4, 4, 5, 6 },
40*77c1e3ccSAndroid Build Coastguard Worker                                               { 0, 1, 1, 2, 3, 4, 5, 0 } };
41*77c1e3ccSAndroid Build Coastguard Worker 
42*77c1e3ccSAndroid Build Coastguard Worker // This mask rearrange the weights to match shuffled source pixels order.
43*77c1e3ccSAndroid Build Coastguard Worker DECLARE_ALIGNED(32, static const uint32_t,
44*77c1e3ccSAndroid Build Coastguard Worker                 shuffle_weight_layer0[2][8]) = { { 0, 1, 2, 3, 4, 0, 1, 2 },
45*77c1e3ccSAndroid Build Coastguard Worker                                                  { 3, 4, 0, 1, 2, 3, 4, 0 } };
46*77c1e3ccSAndroid Build Coastguard Worker 
47*77c1e3ccSAndroid Build Coastguard Worker // Shuffle mask used to rearrange weights corresponding to layer 1 and layer 2.
48*77c1e3ccSAndroid Build Coastguard Worker // For layer 1 and layer 2, convolution happens at 2x2 as filter_width and
49*77c1e3ccSAndroid Build Coastguard Worker // filter_height are equal to 2. So rearranging the weights in the
50*77c1e3ccSAndroid Build Coastguard Worker // order shown below to match source pixels. Basically this mask replicates
51*77c1e3ccSAndroid Build Coastguard Worker // the weights across the width of 2.
52*77c1e3ccSAndroid Build Coastguard Worker DECLARE_ALIGNED(32, static const uint32_t,
53*77c1e3ccSAndroid Build Coastguard Worker                 shuffle_weight_layer_1_and_2[2][8]) = {
54*77c1e3ccSAndroid Build Coastguard Worker   { 0, 1, 0, 1, 0, 1, 0, 1 }, { 2, 3, 2, 3, 2, 3, 2, 3 }
55*77c1e3ccSAndroid Build Coastguard Worker };
56*77c1e3ccSAndroid Build Coastguard Worker 
57*77c1e3ccSAndroid Build Coastguard Worker // After the stages of multiplication and accumulation, the output values
58*77c1e3ccSAndroid Build Coastguard Worker // in the register will be jumbled. In order to store register into
59*77c1e3ccSAndroid Build Coastguard Worker // output buffer in a proper way, the following mask is applied on output
60*77c1e3ccSAndroid Build Coastguard Worker // register.
61*77c1e3ccSAndroid Build Coastguard Worker DECLARE_ALIGNED(32, static const uint32_t,
62*77c1e3ccSAndroid Build Coastguard Worker                 shuffle_output_layer_1_and_2[8]) = { 0, 1, 4, 5, 2, 3, 6, 7 };
63*77c1e3ccSAndroid Build Coastguard Worker 
// Gathers the 5x5 filter for one (in-channel, out-channel) pair and builds the
// ten permuted weight vectors used by the layer-0 kernels.
//
// layer_config_weights: flat weight array; filter tap (row, col) is read from
//                       index off + (row * 5 + col) * cstep.
// off:                  index of tap (0, 0).
// weight:               5x8 scratch buffer; columns 0..4 of each row receive
//                       the taps. All 8 columns of a row are loaded into a
//                       vector, so columns 5..7 must hold finite scratch
//                       values (the caller zero-initializes them).
// cstep:                distance between consecutive taps in the flat array.
// shuffle_weight:       receives 10 vectors: entry r (r = 0..4) is filter
//                       row r permuted by weight_mask_0, and entry r + 5 is
//                       entry r further permuted by weight_mask_1.
static inline void prepare_weights_for_5x5_convolve(
    const float *layer_config_weights, int off, float weight[5][8],
    const int cstep, __m256 *shuffle_weight, const __m256i weight_mask_0,
    const __m256i weight_mask_1) {
  // Copy the strided taps into a dense, vector-loadable buffer.
  for (int r = 0; r < 5; ++r) {
    float *const dst_row = weight[r];
    for (int c = 0; c < 5; ++c, off += cstep) {
      dst_row[c] = layer_config_weights[off];
    }
  }

  // Build both permutations of every filter row.
  for (int r = 0; r < 5; ++r) {
    const __m256 row_vec = _mm256_loadu_ps(weight[r]);
    shuffle_weight[r] = _mm256_permutevar8x32_ps(row_vec, weight_mask_0);
    shuffle_weight[r + 5] =
        _mm256_permutevar8x32_ps(shuffle_weight[r], weight_mask_1);
  }
}
103*77c1e3ccSAndroid Build Coastguard Worker 
// Convolves three horizontally adjacent 5x5 blocks, one filter row per
// iteration. For each row it loads pixels 0..7 and 7..14 starting at
// input_ptr, permutes them into block order, multiplies them by the matching
// weight vectors and adds the products into accum_src_0 / accum_src_1.
//
// Not hygienic: it reads the caller's locals input_ptr, in_stride, block0_1,
// block1_2 and shuffle_weight, and writes load_src_0, load_src_1,
// accum_src_0, accum_src_1 and input_ptr. On exit input_ptr has advanced by
// 5 * in_stride.
#define PERFORM_CONVOLVE_FOR_3_5X5_BLOCKS()                                 \
  do {                                                                      \
    for (int conv_row = 0; conv_row < 5; ++conv_row) {                      \
      load_src_0 =                                                          \
          _mm256_permutevar8x32_ps(_mm256_loadu_ps(input_ptr), block0_1);   \
      load_src_1 = _mm256_permutevar8x32_ps(_mm256_loadu_ps(input_ptr + 7), \
                                            block1_2);                      \
      load_src_0 = _mm256_mul_ps(load_src_0, shuffle_weight[conv_row]);     \
      load_src_1 = _mm256_mul_ps(load_src_1, shuffle_weight[conv_row + 5]); \
      accum_src_0 = _mm256_add_ps(load_src_0, accum_src_0);                 \
      accum_src_1 = _mm256_add_ps(load_src_1, accum_src_1);                 \
      input_ptr += in_stride;                                               \
    }                                                                       \
  } while (0)
120*77c1e3ccSAndroid Build Coastguard Worker 
121*77c1e3ccSAndroid Build Coastguard Worker // Load masks needed for shuffling of output and weights.
load_shuffle_masks_for_2x2_convolve(__m256i * output_mask,__m256i * weight_mask)122*77c1e3ccSAndroid Build Coastguard Worker static inline void load_shuffle_masks_for_2x2_convolve(__m256i *output_mask,
123*77c1e3ccSAndroid Build Coastguard Worker                                                        __m256i *weight_mask) {
124*77c1e3ccSAndroid Build Coastguard Worker   // Load shuffle buffer needed to sort the output.
125*77c1e3ccSAndroid Build Coastguard Worker   *output_mask =
126*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_output_layer_1_and_2);
127*77c1e3ccSAndroid Build Coastguard Worker 
128*77c1e3ccSAndroid Build Coastguard Worker   // Load shuffle buffers needed for weight.
129*77c1e3ccSAndroid Build Coastguard Worker   weight_mask[0] =
130*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_weight_layer_1_and_2[0]);
131*77c1e3ccSAndroid Build Coastguard Worker   weight_mask[1] =
132*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_weight_layer_1_and_2[1]);
133*77c1e3ccSAndroid Build Coastguard Worker }
134*77c1e3ccSAndroid Build Coastguard Worker 
// Gathers the 2x2 filter taps for one (in-channel, out-channel) pair of
// layer 1 or layer 2 and replicates each filter row across a full vector.
//
// layer_config_weights: flat weight array; tap i (i = 0..3) is read from
//                       index off + i * cstep.
// off:                  index of the first tap.
// cstep:                distance between consecutive taps in the flat array.
// shuffle_weight:       receives 2 vectors: the taps permuted by
//                       weight_mask[0] and by weight_mask[1].
// weight_mask:          2 permutation masks; only read. They must select
//                       lanes 0..3 only (shuffle_weight_layer_1_and_2 does),
//                       because the upper 128 bits of the source vector are
//                       undefined.
static inline void prepare_weights_for_2x2_convolve(
    const float *layer_config_weights, int off, const int cstep,
    __m256 *shuffle_weight, const __m256i *weight_mask) {
  // Weights needed for one 2x2 block, in storage order.
  float weight[4] = { 0 };
  for (int i = 0; i < 4; ++i) {
    weight[i] = layer_config_weights[off];
    off += cstep;
  }

  // _mm256_castps128_ps256 leaves the upper 128 bits undefined; that is safe
  // only because the masks never index lanes 4..7.
  const __m256 weight_vec = _mm256_castps128_ps256(_mm_loadu_ps(weight));
  shuffle_weight[0] = _mm256_permutevar8x32_ps(weight_vec, weight_mask[0]);
  shuffle_weight[1] = _mm256_permutevar8x32_ps(weight_vec, weight_mask[1]);
}
151*77c1e3ccSAndroid Build Coastguard Worker 
// Convolves a single 5x5 block whose top-left pixel is at input_ptr.
//
// w:         array of at least 5 __m256 weight vectors; lanes 0..3 of w[r]
//            multiply pixels 0..3 of filter row r. Evaluated 5 times.
// accum0:    __m128 lvalue; the 4-lane partial sums of all five rows are
//            added into it.
// in_stride: distance, in floats, between consecutive source rows.
//
// Not hygienic: it also uses the caller's locals
//   input_ptr       - read, and advanced by 4 * in_stride (ends on row 4);
//   weight          - 5x8 float array; weight[r][4] is the fifth tap of row r;
//   last_column_sum - float; the fifth-column products are added into it.
// Arguments are parenthesized so that expressions such as `ws + 1` expand
// correctly.
#define PERFORM_CONVOLVE_FOR_1_5X5_BLOCK(w, accum0, in_stride)             \
  do {                                                                     \
    __m128 load_src[5];                                                    \
    load_src[0] = _mm_loadu_ps(input_ptr);                                 \
    last_column_sum += input_ptr[4] * weight[0][4];                        \
    input_ptr += (in_stride);                                              \
    load_src[1] = _mm_loadu_ps(input_ptr);                                 \
    last_column_sum += input_ptr[4] * weight[1][4];                        \
    input_ptr += (in_stride);                                              \
    load_src[2] = _mm_loadu_ps(input_ptr);                                 \
    last_column_sum += input_ptr[4] * weight[2][4];                        \
    input_ptr += (in_stride);                                              \
    load_src[3] = _mm_loadu_ps(input_ptr);                                 \
    last_column_sum += input_ptr[4] * weight[3][4];                        \
    input_ptr += (in_stride);                                              \
    load_src[4] = _mm_loadu_ps(input_ptr);                                 \
    last_column_sum += input_ptr[4] * weight[4][4];                        \
                                                                           \
    load_src[0] = _mm_mul_ps(load_src[0], _mm256_castps256_ps128((w)[0])); \
    load_src[1] = _mm_mul_ps(load_src[1], _mm256_castps256_ps128((w)[1])); \
    load_src[2] = _mm_mul_ps(load_src[2], _mm256_castps256_ps128((w)[2])); \
    load_src[3] = _mm_mul_ps(load_src[3], _mm256_castps256_ps128((w)[3])); \
    load_src[4] = _mm_mul_ps(load_src[4], _mm256_castps256_ps128((w)[4])); \
                                                                           \
    (accum0) = _mm_add_ps(load_src[0], (accum0));                          \
    load_src[1] = _mm_add_ps(load_src[1], load_src[2]);                    \
    load_src[3] = _mm_add_ps(load_src[3], load_src[4]);                    \
    load_src[1] = _mm_add_ps(load_src[1], load_src[3]);                    \
    (accum0) = _mm_add_ps((accum0), load_src[1]);                          \
  } while (0)
183*77c1e3ccSAndroid Build Coastguard Worker 
// Convolves 8 horizontally adjacent, non-overlapping 2x2 blocks and adds the
// 8 results to *out_accum, lane b receiving block b.
//
// input_ptr:           top-left pixel of the first block; 16 floats are read
//                      from each of the two source rows.
// in_stride:           distance, in floats, between the two source rows.
// weight:              2 vectors, only read; weight[0] multiplies the upper
//                      row and weight[1] the lower row (expected in the
//                      replicated layout built by
//                      prepare_weights_for_2x2_convolve).
// out_accum:           accumulator the 8 block sums are added into.
// shuffle_output_mask: permutation undoing the per-128-bit-half interleaving
//                      of _mm256_hadd_ps (shuffle_output_layer_1_and_2).
static inline void perform_convolve_for_8h_2x2_blocks(
    const float *input_ptr, int in_stride, const __m256 *weight,
    __m256 *out_accum, __m256i shuffle_output_mask) {
  __m256 load_src[4];
  // Load input into source registers.
  load_src[0] = _mm256_loadu_ps(input_ptr);
  load_src[1] = _mm256_loadu_ps(input_ptr + 8);
  load_src[2] = _mm256_loadu_ps(input_ptr + in_stride);
  load_src[3] = _mm256_loadu_ps(input_ptr + in_stride + 8);

  // Multiply the loaded input with corresponding weights.
  load_src[0] = _mm256_mul_ps(load_src[0], weight[0]);
  load_src[1] = _mm256_mul_ps(load_src[1], weight[0]);
  load_src[2] = _mm256_mul_ps(load_src[2], weight[1]);
  load_src[3] = _mm256_mul_ps(load_src[3], weight[1]);

  // Sum the two rows, then add horizontal pairs to finish each 2x2 block.
  load_src[0] = _mm256_add_ps(load_src[0], load_src[2]);
  load_src[1] = _mm256_add_ps(load_src[1], load_src[3]);
  load_src[0] = _mm256_hadd_ps(load_src[0], load_src[1]);

  // hadd leaves the sums in block order 0,1,4,5,2,3,6,7; restore 0..7.
  load_src[0] = _mm256_permutevar8x32_ps(load_src[0], shuffle_output_mask);
  *out_accum = _mm256_add_ps(*out_accum, load_src[0]);
}
210*77c1e3ccSAndroid Build Coastguard Worker 
// Convolves 8 non-overlapping 2x2 blocks laid out 4 wide by 2 tall and adds
// the results to *out_accum: lanes 0..3 receive the upper row of blocks and
// lanes 4..7 the lower row.
//
// input_ptr:           top-left pixel of the first block; 8 floats are read
//                      from each of 4 consecutive source rows.
// in_stride:           distance, in floats, between consecutive source rows.
// weight:              2 vectors, only read; weight[0] multiplies the upper
//                      row of each block and weight[1] the lower row
//                      (expected in the replicated layout built by
//                      prepare_weights_for_2x2_convolve).
// out_accum:           accumulator the 8 block sums are added into.
// shuffle_output_mask: permutation undoing the per-128-bit-half interleaving
//                      of _mm256_hadd_ps (shuffle_output_layer_1_and_2).
static inline void perform_convolve_for_4hx2v_2x2_blocks(
    const float *input_ptr, int in_stride, const __m256 *weight,
    __m256 *out_accum, __m256i shuffle_output_mask) {
  __m256 load_src[4];
  // Load input into source registers.
  load_src[0] = _mm256_loadu_ps(input_ptr);
  load_src[1] = _mm256_loadu_ps(input_ptr + in_stride);
  load_src[2] = _mm256_loadu_ps(input_ptr + (in_stride * 2));
  load_src[3] = _mm256_loadu_ps(input_ptr + (in_stride * 3));

  // Multiply the loaded input with corresponding weights.
  load_src[0] = _mm256_mul_ps(load_src[0], weight[0]);
  load_src[1] = _mm256_mul_ps(load_src[1], weight[1]);
  load_src[2] = _mm256_mul_ps(load_src[2], weight[0]);
  load_src[3] = _mm256_mul_ps(load_src[3], weight[1]);

  // Sum the two rows of each block row, then add horizontal pairs.
  load_src[0] = _mm256_add_ps(load_src[0], load_src[1]);
  load_src[2] = _mm256_add_ps(load_src[2], load_src[3]);
  load_src[0] = _mm256_hadd_ps(load_src[0], load_src[2]);

  // hadd interleaves the two block rows per 128-bit half; restore the order.
  load_src[0] = _mm256_permutevar8x32_ps(load_src[0], shuffle_output_mask);
  *out_accum = _mm256_add_ps(*out_accum, load_src[0]);
}
237*77c1e3ccSAndroid Build Coastguard Worker 
238*77c1e3ccSAndroid Build Coastguard Worker // AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c(), when
239*77c1e3ccSAndroid Build Coastguard Worker // filter_width and filter_height are equal to 5.
240*77c1e3ccSAndroid Build Coastguard Worker // CNN convolve parsing is based on av1_intra_mode_cnn_partition_cnn_config.
241*77c1e3ccSAndroid Build Coastguard Worker // Based on the configuration set for each layer, the current encoder
242*77c1e3ccSAndroid Build Coastguard Worker // always chooses the case of no_maxpool_padding_valid.
243*77c1e3ccSAndroid Build Coastguard Worker // And also for layer 0 convolution happens at 5x5 level as the
244*77c1e3ccSAndroid Build Coastguard Worker // filter_width and filter_height are set as 5.
cnn_convolve_no_maxpool_padding_valid_5x5_avx2(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,const int cstep,const int channel_step)245*77c1e3ccSAndroid Build Coastguard Worker static void cnn_convolve_no_maxpool_padding_valid_5x5_avx2(
246*77c1e3ccSAndroid Build Coastguard Worker     const float **input, int in_width, int in_height, int in_stride,
247*77c1e3ccSAndroid Build Coastguard Worker     const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
248*77c1e3ccSAndroid Build Coastguard Worker     int start_idx, const int cstep, const int channel_step) {
249*77c1e3ccSAndroid Build Coastguard Worker   const int kFilterWidth = 5;
250*77c1e3ccSAndroid Build Coastguard Worker   const int kFilterHeight = 5;
251*77c1e3ccSAndroid Build Coastguard Worker   const int kSkipWidth = 4;
252*77c1e3ccSAndroid Build Coastguard Worker   const int kSkipHeight = 4;
253*77c1e3ccSAndroid Build Coastguard Worker   assert(layer_config->filter_width == kFilterWidth &&
254*77c1e3ccSAndroid Build Coastguard Worker          layer_config->filter_height == kFilterHeight);
255*77c1e3ccSAndroid Build Coastguard Worker   assert(layer_config->skip_width == kSkipWidth &&
256*77c1e3ccSAndroid Build Coastguard Worker          layer_config->skip_height == kSkipHeight);
257*77c1e3ccSAndroid Build Coastguard Worker 
258*77c1e3ccSAndroid Build Coastguard Worker   // Load shuffle buffers needed for source.
259*77c1e3ccSAndroid Build Coastguard Worker   const __m256i block0_1 =
260*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_src_layer0[0]);
261*77c1e3ccSAndroid Build Coastguard Worker   const __m256i block1_2 =
262*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_src_layer0[1]);
263*77c1e3ccSAndroid Build Coastguard Worker 
264*77c1e3ccSAndroid Build Coastguard Worker   // Load shuffle buffers needed for weight.
265*77c1e3ccSAndroid Build Coastguard Worker   const __m256i weight_mask_0 =
266*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_weight_layer0[0]);
267*77c1e3ccSAndroid Build Coastguard Worker   const __m256i weight_mask_1 =
268*77c1e3ccSAndroid Build Coastguard Worker       _mm256_load_si256((const __m256i *)shuffle_weight_layer0[1]);
269*77c1e3ccSAndroid Build Coastguard Worker 
270*77c1e3ccSAndroid Build Coastguard Worker   // Width needs to be moved to go to next iteration of processing 3 5x5 blocks.
271*77c1e3ccSAndroid Build Coastguard Worker   const int kSkipWidthForNextIter = kSkipWidth * 3;
272*77c1e3ccSAndroid Build Coastguard Worker 
273*77c1e3ccSAndroid Build Coastguard Worker   // Minimum width required to process 3 5x5 blocks at a time.
274*77c1e3ccSAndroid Build Coastguard Worker   // min width (for processing 3 5x5 block) = 2*skip_width + filter_width
275*77c1e3ccSAndroid Build Coastguard Worker   // Here, skip_width specifies how much width we should move while processing
276*77c1e3ccSAndroid Build Coastguard Worker   // next block convolution and filter_width specifies for how many pixels
277*77c1e3ccSAndroid Build Coastguard Worker   // filter needs to be applied.
278*77c1e3ccSAndroid Build Coastguard Worker   const int kMinWidthFor3_5x5Blocks = (kSkipWidth * 2) + kFilterWidth;
279*77c1e3ccSAndroid Build Coastguard Worker   for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
280*77c1e3ccSAndroid Build Coastguard Worker     const float out_ch_bias = layer_config->bias[i];
281*77c1e3ccSAndroid Build Coastguard Worker     for (int k = 0; k < layer_config->in_channels; ++k) {
282*77c1e3ccSAndroid Build Coastguard Worker       __m256 shuffle_weight[10];
283*77c1e3ccSAndroid Build Coastguard Worker 
284*77c1e3ccSAndroid Build Coastguard Worker       // Weights needed are 5x5, for SIMD purpose made this array as 5x8.
285*77c1e3ccSAndroid Build Coastguard Worker       float weight[5][8] = { { 0 } };
286*77c1e3ccSAndroid Build Coastguard Worker       int off = k * layer_config->out_channels + i;
287*77c1e3ccSAndroid Build Coastguard Worker 
288*77c1e3ccSAndroid Build Coastguard Worker       // In layer 0, the convolution process happens at 5x5.
289*77c1e3ccSAndroid Build Coastguard Worker       // The weights needed for 5x5 block are same across the in-channels,
290*77c1e3ccSAndroid Build Coastguard Worker       // which is why the load of weights happens once for each in-channel.
291*77c1e3ccSAndroid Build Coastguard Worker       prepare_weights_for_5x5_convolve(layer_config->weights, off, weight,
292*77c1e3ccSAndroid Build Coastguard Worker                                        cstep, shuffle_weight, weight_mask_0,
293*77c1e3ccSAndroid Build Coastguard Worker                                        weight_mask_1);
294*77c1e3ccSAndroid Build Coastguard Worker 
295*77c1e3ccSAndroid Build Coastguard Worker       for (int h = 0, u = 0; h < in_height - kFilterHeight + 1;
296*77c1e3ccSAndroid Build Coastguard Worker            h += kSkipHeight, ++u) {
297*77c1e3ccSAndroid Build Coastguard Worker         const int out_h = u * out_stride;
298*77c1e3ccSAndroid Build Coastguard Worker         int v = 0;
299*77c1e3ccSAndroid Build Coastguard Worker         int w = 0;
300*77c1e3ccSAndroid Build Coastguard Worker         int rem_width = in_width;
301*77c1e3ccSAndroid Build Coastguard Worker         // Processing 3 5x5 blocks at a time, if sufficient width is present.
302*77c1e3ccSAndroid Build Coastguard Worker         while (rem_width >= kMinWidthFor3_5x5Blocks) {
303*77c1e3ccSAndroid Build Coastguard Worker           __m256 load_src_0, load_src_1;
304*77c1e3ccSAndroid Build Coastguard Worker           __m256 accum_src_0 = _mm256_setzero_ps();
305*77c1e3ccSAndroid Build Coastguard Worker           __m256 accum_src_1 = _mm256_setzero_ps();
306*77c1e3ccSAndroid Build Coastguard Worker           const float *input_ptr = &input[k][h * in_stride + w];
307*77c1e3ccSAndroid Build Coastguard Worker           PERFORM_CONVOLVE_FOR_3_5X5_BLOCKS();
308*77c1e3ccSAndroid Build Coastguard Worker 
309*77c1e3ccSAndroid Build Coastguard Worker           // Accumulate across column.
310*77c1e3ccSAndroid Build Coastguard Worker           __m256 accum = _mm256_hadd_ps(accum_src_0, accum_src_1);
311*77c1e3ccSAndroid Build Coastguard Worker           __m128 tmp_reg_0 = _mm256_extractf128_ps(accum_src_0, 1);
312*77c1e3ccSAndroid Build Coastguard Worker           __m128 tmp_reg_1 = _mm256_extractf128_ps(accum_src_1, 1);
313*77c1e3ccSAndroid Build Coastguard Worker 
314*77c1e3ccSAndroid Build Coastguard Worker           __m128 accum_l = _mm256_castps256_ps128(accum);
315*77c1e3ccSAndroid Build Coastguard Worker           __m128 accum_h = _mm256_extractf128_ps(accum, 1);
316*77c1e3ccSAndroid Build Coastguard Worker 
317*77c1e3ccSAndroid Build Coastguard Worker           __m128 tmp_reg_2 = _mm_add_ps(accum_l, tmp_reg_0);
318*77c1e3ccSAndroid Build Coastguard Worker           __m128 tmp_reg_3 = _mm_add_ps(tmp_reg_0, accum_h);
319*77c1e3ccSAndroid Build Coastguard Worker           __m128 tmp_reg_4 = _mm_add_ps(tmp_reg_1, accum_h);
320*77c1e3ccSAndroid Build Coastguard Worker 
321*77c1e3ccSAndroid Build Coastguard Worker           // 1st 5x5 block output.
322*77c1e3ccSAndroid Build Coastguard Worker           output[i][out_h + v] =
323*77c1e3ccSAndroid Build Coastguard Worker               out_ch_bias + _mm_cvtss_f32(tmp_reg_2) +
324*77c1e3ccSAndroid Build Coastguard Worker               _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 1));
325*77c1e3ccSAndroid Build Coastguard Worker 
326*77c1e3ccSAndroid Build Coastguard Worker           // 2nd 5x5 block output.
327*77c1e3ccSAndroid Build Coastguard Worker           output[i][out_h + v + 1] =
328*77c1e3ccSAndroid Build Coastguard Worker               out_ch_bias +
329*77c1e3ccSAndroid Build Coastguard Worker               _mm_cvtss_f32(_mm_shuffle_ps(tmp_reg_3, tmp_reg_3, 1)) +
330*77c1e3ccSAndroid Build Coastguard Worker               _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 2));
331*77c1e3ccSAndroid Build Coastguard Worker 
332*77c1e3ccSAndroid Build Coastguard Worker           // 3rd 5x5 block output.
333*77c1e3ccSAndroid Build Coastguard Worker           output[i][out_h + v + 2] =
334*77c1e3ccSAndroid Build Coastguard Worker               out_ch_bias +
335*77c1e3ccSAndroid Build Coastguard Worker               _mm_cvtss_f32(_mm_shuffle_ps(tmp_reg_4, tmp_reg_4, 2)) +
336*77c1e3ccSAndroid Build Coastguard Worker               _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 3));
337*77c1e3ccSAndroid Build Coastguard Worker 
338*77c1e3ccSAndroid Build Coastguard Worker           v += 3;
339*77c1e3ccSAndroid Build Coastguard Worker           w += kSkipWidthForNextIter;
340*77c1e3ccSAndroid Build Coastguard Worker           rem_width -= kSkipWidthForNextIter;
341*77c1e3ccSAndroid Build Coastguard Worker         }
342*77c1e3ccSAndroid Build Coastguard Worker 
343*77c1e3ccSAndroid Build Coastguard Worker         // Process remaining blocks as single 5x5 block at a time.
344*77c1e3ccSAndroid Build Coastguard Worker         while (rem_width >= kFilterWidth) {
345*77c1e3ccSAndroid Build Coastguard Worker           float last_column_sum = 0;
346*77c1e3ccSAndroid Build Coastguard Worker           __m128 accum = _mm_setzero_ps();
347*77c1e3ccSAndroid Build Coastguard Worker           const float *input_ptr = &input[k][h * in_stride + w];
348*77c1e3ccSAndroid Build Coastguard Worker           PERFORM_CONVOLVE_FOR_1_5X5_BLOCK(shuffle_weight, accum, in_stride);
349*77c1e3ccSAndroid Build Coastguard Worker 
350*77c1e3ccSAndroid Build Coastguard Worker           // Accumulate across column.
351*77c1e3ccSAndroid Build Coastguard Worker           accum = _mm_hadd_ps(accum, accum);
352*77c1e3ccSAndroid Build Coastguard Worker           output[i][out_h + v] = out_ch_bias + last_column_sum +
353*77c1e3ccSAndroid Build Coastguard Worker                                  _mm_cvtss_f32(accum) +
354*77c1e3ccSAndroid Build Coastguard Worker                                  _mm_cvtss_f32(_mm_shuffle_ps(accum, accum, 1));
355*77c1e3ccSAndroid Build Coastguard Worker 
356*77c1e3ccSAndroid Build Coastguard Worker           v += 1;
357*77c1e3ccSAndroid Build Coastguard Worker           w += kSkipWidth;
358*77c1e3ccSAndroid Build Coastguard Worker           rem_width -= kSkipWidth;
359*77c1e3ccSAndroid Build Coastguard Worker         }
360*77c1e3ccSAndroid Build Coastguard Worker       }
361*77c1e3ccSAndroid Build Coastguard Worker     }
362*77c1e3ccSAndroid Build Coastguard Worker   }
363*77c1e3ccSAndroid Build Coastguard Worker }
364*77c1e3ccSAndroid Build Coastguard Worker 
365*77c1e3ccSAndroid Build Coastguard Worker // AVX2 implementation for layer 1.
cnn_convolve_no_maxpool_padding_valid_layer1_avx2(const float ** input,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,const int cstep,const int channel_step)366*77c1e3ccSAndroid Build Coastguard Worker static inline void cnn_convolve_no_maxpool_padding_valid_layer1_avx2(
367*77c1e3ccSAndroid Build Coastguard Worker     const float **input, int in_stride,
368*77c1e3ccSAndroid Build Coastguard Worker     const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
369*77c1e3ccSAndroid Build Coastguard Worker     int start_idx, const int cstep, const int channel_step) {
370*77c1e3ccSAndroid Build Coastguard Worker   __m256i weight_mask[2];
371*77c1e3ccSAndroid Build Coastguard Worker   __m256i shuffle_output_mask;
372*77c1e3ccSAndroid Build Coastguard Worker   load_shuffle_masks_for_2x2_convolve(&shuffle_output_mask, weight_mask);
373*77c1e3ccSAndroid Build Coastguard Worker 
374*77c1e3ccSAndroid Build Coastguard Worker   const int kInHeight = 16;
375*77c1e3ccSAndroid Build Coastguard Worker   const int kFilterHeight = 2;
376*77c1e3ccSAndroid Build Coastguard Worker   const int kSkipHeight = 2;
377*77c1e3ccSAndroid Build Coastguard Worker   for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
378*77c1e3ccSAndroid Build Coastguard Worker     __m256 bias_reg = _mm256_set1_ps(layer_config->bias[i]);
379*77c1e3ccSAndroid Build Coastguard Worker     // out_accum registers are used to store the 2x2 convolve outputs
380*77c1e3ccSAndroid Build Coastguard Worker     // (calculated over input block size), which are accumulated across the
381*77c1e3ccSAndroid Build Coastguard Worker     // in_channels. As per the design, each iteration of for loop processes 8
382*77c1e3ccSAndroid Build Coastguard Worker     // (horizontal) 2x2 blocks and stores in corresponding out_accum register
383*77c1e3ccSAndroid Build Coastguard Worker     // (as input size is 16x16, a total of 64 2x2 blocks are present and 8
384*77c1e3ccSAndroid Build Coastguard Worker     // out_accum registers are enough to store the outputs).
385*77c1e3ccSAndroid Build Coastguard Worker     // Hence for loops corresponding to 'j' and 'h', below, run over the number
386*77c1e3ccSAndroid Build Coastguard Worker     // of out_accum registers.
387*77c1e3ccSAndroid Build Coastguard Worker     __m256 out_accum[8];
388*77c1e3ccSAndroid Build Coastguard Worker     for (int j = 0; j < 8; ++j) out_accum[j] = bias_reg;
389*77c1e3ccSAndroid Build Coastguard Worker     for (int k = 0; k < layer_config->in_channels; ++k) {
390*77c1e3ccSAndroid Build Coastguard Worker       __m256 shuffle_weight[2];
391*77c1e3ccSAndroid Build Coastguard Worker       int off = k * layer_config->out_channels + i;
392*77c1e3ccSAndroid Build Coastguard Worker       // In layer 1, the convolution process happens at 2x2.
393*77c1e3ccSAndroid Build Coastguard Worker       // The weights needed for 2x2 block are same across the in-channels,
394*77c1e3ccSAndroid Build Coastguard Worker       // which is why the load of weights happens once for each in-channel.
395*77c1e3ccSAndroid Build Coastguard Worker       prepare_weights_for_2x2_convolve(layer_config->weights, off, cstep,
396*77c1e3ccSAndroid Build Coastguard Worker                                        shuffle_weight, weight_mask);
397*77c1e3ccSAndroid Build Coastguard Worker 
398*77c1e3ccSAndroid Build Coastguard Worker       for (int h = 0, u = 0; h < kInHeight - kFilterHeight + 1;
399*77c1e3ccSAndroid Build Coastguard Worker            h += kSkipHeight, ++u) {
400*77c1e3ccSAndroid Build Coastguard Worker         const float *input_ptr = &input[k][h * in_stride];
401*77c1e3ccSAndroid Build Coastguard Worker         perform_convolve_for_8h_2x2_blocks(input_ptr, in_stride, shuffle_weight,
402*77c1e3ccSAndroid Build Coastguard Worker                                            &out_accum[u], shuffle_output_mask);
403*77c1e3ccSAndroid Build Coastguard Worker       }
404*77c1e3ccSAndroid Build Coastguard Worker     }
405*77c1e3ccSAndroid Build Coastguard Worker     // Store output of layer 1.
406*77c1e3ccSAndroid Build Coastguard Worker     for (int j = 0; j < 8; ++j) {
407*77c1e3ccSAndroid Build Coastguard Worker       _mm256_storeu_ps(&output[i][j * out_stride], out_accum[j]);
408*77c1e3ccSAndroid Build Coastguard Worker     }
409*77c1e3ccSAndroid Build Coastguard Worker   }
410*77c1e3ccSAndroid Build Coastguard Worker }
411*77c1e3ccSAndroid Build Coastguard Worker 
412*77c1e3ccSAndroid Build Coastguard Worker // AVX2 implementation for layer 2.
cnn_convolve_no_maxpool_padding_valid_layer2_avx2(const float ** input,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,const int cstep,const int channel_step)413*77c1e3ccSAndroid Build Coastguard Worker static inline void cnn_convolve_no_maxpool_padding_valid_layer2_avx2(
414*77c1e3ccSAndroid Build Coastguard Worker     const float **input, int in_stride,
415*77c1e3ccSAndroid Build Coastguard Worker     const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
416*77c1e3ccSAndroid Build Coastguard Worker     int start_idx, const int cstep, const int channel_step) {
417*77c1e3ccSAndroid Build Coastguard Worker   __m256i weight_mask[2];
418*77c1e3ccSAndroid Build Coastguard Worker   __m256i shuffle_output_mask;
419*77c1e3ccSAndroid Build Coastguard Worker   load_shuffle_masks_for_2x2_convolve(&shuffle_output_mask, weight_mask);
420*77c1e3ccSAndroid Build Coastguard Worker 
421*77c1e3ccSAndroid Build Coastguard Worker   const int kInHeight = 8;
422*77c1e3ccSAndroid Build Coastguard Worker   const int kFilterHeight = 2;
423*77c1e3ccSAndroid Build Coastguard Worker   const int kSkipHeight = 2;
424*77c1e3ccSAndroid Build Coastguard Worker   for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
425*77c1e3ccSAndroid Build Coastguard Worker     __m256 bias_reg = _mm256_set1_ps(layer_config->bias[i]);
426*77c1e3ccSAndroid Build Coastguard Worker     // out_accum registers are used to store the 2x2 convolve outputs
427*77c1e3ccSAndroid Build Coastguard Worker     // (calculated over input block size), which are accumulated across the
428*77c1e3ccSAndroid Build Coastguard Worker     // in_channels. As per the design, each iteration of for loop processes 8
429*77c1e3ccSAndroid Build Coastguard Worker     // (4 horizontal x 2 vertical) 2x2 blocks and stores in corresponding
430*77c1e3ccSAndroid Build Coastguard Worker     // out_accum register (as input size is 8x8, a total of 16 2x2 blocks are
431*77c1e3ccSAndroid Build Coastguard Worker     // present and 2 out_accum registers are enough to store the outputs).
432*77c1e3ccSAndroid Build Coastguard Worker     // Hence for loops corresponding to 'j' and 'h', below, run over the number
433*77c1e3ccSAndroid Build Coastguard Worker     // of out_accum registers.
434*77c1e3ccSAndroid Build Coastguard Worker     __m256 out_accum[2];
435*77c1e3ccSAndroid Build Coastguard Worker 
436*77c1e3ccSAndroid Build Coastguard Worker     // Height needs to be moved to go to next iteration of processing
437*77c1e3ccSAndroid Build Coastguard Worker     // while processing 2 2x2 blocks vertically.
438*77c1e3ccSAndroid Build Coastguard Worker     const int kSkipHeightForNextIter = kSkipHeight * 2;
439*77c1e3ccSAndroid Build Coastguard Worker     for (int j = 0; j < 2; ++j) out_accum[j] = bias_reg;
440*77c1e3ccSAndroid Build Coastguard Worker     for (int k = 0; k < layer_config->in_channels; ++k) {
441*77c1e3ccSAndroid Build Coastguard Worker       __m256 shuffle_weight[2];
442*77c1e3ccSAndroid Build Coastguard Worker       int off = k * layer_config->out_channels + i;
443*77c1e3ccSAndroid Build Coastguard Worker       // In layer 2, the convolution process happens at 2x2.
444*77c1e3ccSAndroid Build Coastguard Worker       // The weights needed for 2x2 block are same across the in-channels,
445*77c1e3ccSAndroid Build Coastguard Worker       // which is why the load of weights happens once for each in-channel.
446*77c1e3ccSAndroid Build Coastguard Worker       prepare_weights_for_2x2_convolve(layer_config->weights, off, cstep,
447*77c1e3ccSAndroid Build Coastguard Worker                                        shuffle_weight, weight_mask);
448*77c1e3ccSAndroid Build Coastguard Worker 
449*77c1e3ccSAndroid Build Coastguard Worker       for (int h = 0, u = 0; h < kInHeight - kFilterHeight + 1;
450*77c1e3ccSAndroid Build Coastguard Worker            h += kSkipHeightForNextIter, ++u) {
451*77c1e3ccSAndroid Build Coastguard Worker         const float *input_ptr = &input[k][h * in_stride];
452*77c1e3ccSAndroid Build Coastguard Worker         perform_convolve_for_4hx2v_2x2_blocks(input_ptr, in_stride,
453*77c1e3ccSAndroid Build Coastguard Worker                                               shuffle_weight, &out_accum[u],
454*77c1e3ccSAndroid Build Coastguard Worker                                               shuffle_output_mask);
455*77c1e3ccSAndroid Build Coastguard Worker       }
456*77c1e3ccSAndroid Build Coastguard Worker     }
457*77c1e3ccSAndroid Build Coastguard Worker     // Store output of layer 2.
458*77c1e3ccSAndroid Build Coastguard Worker     for (int j = 0; j < 2; ++j) {
459*77c1e3ccSAndroid Build Coastguard Worker       _mm256_storeu_ps(&output[i][j * out_stride * 2], out_accum[j]);
460*77c1e3ccSAndroid Build Coastguard Worker     }
461*77c1e3ccSAndroid Build Coastguard Worker   }
462*77c1e3ccSAndroid Build Coastguard Worker }
463*77c1e3ccSAndroid Build Coastguard Worker 
464*77c1e3ccSAndroid Build Coastguard Worker // AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c(), when
465*77c1e3ccSAndroid Build Coastguard Worker // filter_width and filter_height are equal to 2.
466*77c1e3ccSAndroid Build Coastguard Worker // As per the layer config set by av1_intra_mode_cnn_partition_cnn_config,
467*77c1e3ccSAndroid Build Coastguard Worker // the filter_width and filter_height are equal to 2 for layer >= 1. So
468*77c1e3ccSAndroid Build Coastguard Worker // convolution happens at 2x2 for layer >= 1.
cnn_convolve_no_maxpool_padding_valid_2x2_avx2(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * const layer_config,float ** output,int out_stride,int start_idx,const int cstep,const int channel_step)469*77c1e3ccSAndroid Build Coastguard Worker static void cnn_convolve_no_maxpool_padding_valid_2x2_avx2(
470*77c1e3ccSAndroid Build Coastguard Worker     const float **input, int in_width, int in_height, int in_stride,
471*77c1e3ccSAndroid Build Coastguard Worker     const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
472*77c1e3ccSAndroid Build Coastguard Worker     int start_idx, const int cstep, const int channel_step) {
473*77c1e3ccSAndroid Build Coastguard Worker   assert(layer_config->filter_width == 2 && layer_config->filter_height == 2);
474*77c1e3ccSAndroid Build Coastguard Worker   assert(layer_config->skip_width == 2 && layer_config->skip_height == 2);
475*77c1e3ccSAndroid Build Coastguard Worker 
476*77c1e3ccSAndroid Build Coastguard Worker   if (in_width == 16 && in_height == 16) {
477*77c1e3ccSAndroid Build Coastguard Worker     // This case of in_width and in_height equal to 16 corresponds to layer 1.
478*77c1e3ccSAndroid Build Coastguard Worker     // The output size of this layer is 8x8.
479*77c1e3ccSAndroid Build Coastguard Worker     cnn_convolve_no_maxpool_padding_valid_layer1_avx2(
480*77c1e3ccSAndroid Build Coastguard Worker         input, in_stride, layer_config, output, out_stride, start_idx, cstep,
481*77c1e3ccSAndroid Build Coastguard Worker         channel_step);
482*77c1e3ccSAndroid Build Coastguard Worker   } else if (in_width == 8 && in_height == 8) {
483*77c1e3ccSAndroid Build Coastguard Worker     // This case of in_width and in_height equal to 8 corresponds to layer 2.
484*77c1e3ccSAndroid Build Coastguard Worker     // The output size of this layer is 4x4.
485*77c1e3ccSAndroid Build Coastguard Worker     cnn_convolve_no_maxpool_padding_valid_layer2_avx2(
486*77c1e3ccSAndroid Build Coastguard Worker         input, in_stride, layer_config, output, out_stride, start_idx, cstep,
487*77c1e3ccSAndroid Build Coastguard Worker         channel_step);
488*77c1e3ccSAndroid Build Coastguard Worker   } else {
489*77c1e3ccSAndroid Build Coastguard Worker     // For layer equal to 3 and 4, the input is of size 4x4 and 2x2
490*77c1e3ccSAndroid Build Coastguard Worker     // respectively. Implementing SIMD for these cases might not be optimal,
491*77c1e3ccSAndroid Build Coastguard Worker     // which is why we call C path for layer >= 3.
492*77c1e3ccSAndroid Build Coastguard Worker     av1_cnn_convolve_no_maxpool_padding_valid_c(
493*77c1e3ccSAndroid Build Coastguard Worker         input, in_width, in_height, in_stride, layer_config, output, out_stride,
494*77c1e3ccSAndroid Build Coastguard Worker         start_idx, cstep, channel_step);
495*77c1e3ccSAndroid Build Coastguard Worker   }
496*77c1e3ccSAndroid Build Coastguard Worker }
497*77c1e3ccSAndroid Build Coastguard Worker 
498*77c1e3ccSAndroid Build Coastguard Worker // AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c().
499*77c1e3ccSAndroid Build Coastguard Worker // As per the current encoder, av1_cnn_convolve function gets called for
500*77c1e3ccSAndroid Build Coastguard Worker // block size equal to 64x64. av1_cnn_convolve() uses layer config values
501*77c1e3ccSAndroid Build Coastguard Worker // set by av1_intra_mode_cnn_partition_cnn_config. The following are a few
502*77c1e3ccSAndroid Build Coastguard Worker // details related to each layer's config parameters.
503*77c1e3ccSAndroid Build Coastguard Worker // Layer_Number in_size out_size filter_wd filter_ht skip_wd skip_ht
504*77c1e3ccSAndroid Build Coastguard Worker //     0         64x64    16x16      5         5         4       4
505*77c1e3ccSAndroid Build Coastguard Worker //     1         16x16    8x8        2         2         2       2
506*77c1e3ccSAndroid Build Coastguard Worker //     2         8x8      4x4        2         2         2       2
507*77c1e3ccSAndroid Build Coastguard Worker //     3         4x4      2x2        2         2         2       2
508*77c1e3ccSAndroid Build Coastguard Worker //     4         2x2      1x1        2         2         2       2
509*77c1e3ccSAndroid Build Coastguard Worker // Here,
510*77c1e3ccSAndroid Build Coastguard Worker // filter_wd = filter_width and filter_ht = filter_height,
511*77c1e3ccSAndroid Build Coastguard Worker // skip_wd = skip_width and skip_ht = skip_height.
av1_cnn_convolve_no_maxpool_padding_valid_avx2(const float ** input,int in_width,int in_height,int in_stride,const CNN_LAYER_CONFIG * layer_config,float ** output,int out_stride,int start_idx,int cstep,int channel_step)512*77c1e3ccSAndroid Build Coastguard Worker void av1_cnn_convolve_no_maxpool_padding_valid_avx2(
513*77c1e3ccSAndroid Build Coastguard Worker     const float **input, int in_width, int in_height, int in_stride,
514*77c1e3ccSAndroid Build Coastguard Worker     const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
515*77c1e3ccSAndroid Build Coastguard Worker     int start_idx, int cstep, int channel_step) {
516*77c1e3ccSAndroid Build Coastguard Worker   if (layer_config->filter_width == 5 && layer_config->filter_height == 5 &&
517*77c1e3ccSAndroid Build Coastguard Worker       layer_config->skip_width == 4 && layer_config->skip_height == 4) {
518*77c1e3ccSAndroid Build Coastguard Worker     cnn_convolve_no_maxpool_padding_valid_5x5_avx2(
519*77c1e3ccSAndroid Build Coastguard Worker         input, in_width, in_height, in_stride, layer_config, output, out_stride,
520*77c1e3ccSAndroid Build Coastguard Worker         start_idx, cstep, channel_step);
521*77c1e3ccSAndroid Build Coastguard Worker   } else if (layer_config->filter_width == 2 &&
522*77c1e3ccSAndroid Build Coastguard Worker              layer_config->filter_height == 2 &&
523*77c1e3ccSAndroid Build Coastguard Worker              layer_config->skip_width == 2 && layer_config->skip_height == 2) {
524*77c1e3ccSAndroid Build Coastguard Worker     cnn_convolve_no_maxpool_padding_valid_2x2_avx2(
525*77c1e3ccSAndroid Build Coastguard Worker         input, in_width, in_height, in_stride, layer_config, output, out_stride,
526*77c1e3ccSAndroid Build Coastguard Worker         start_idx, cstep, channel_step);
527*77c1e3ccSAndroid Build Coastguard Worker   } else {
528*77c1e3ccSAndroid Build Coastguard Worker     av1_cnn_convolve_no_maxpool_padding_valid_c(
529*77c1e3ccSAndroid Build Coastguard Worker         input, in_width, in_height, in_stride, layer_config, output, out_stride,
530*77c1e3ccSAndroid Build Coastguard Worker         start_idx, cstep, channel_step);
531*77c1e3ccSAndroid Build Coastguard Worker   }
532*77c1e3ccSAndroid Build Coastguard Worker }
533