/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_dsp_rtcd.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/bitdepth_conversion_avx2.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_ports/mem.h"

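// Sign-extends sixteen int16 values to two vectors of eight int32 values.
// The unpacks work per 128-bit lane, so *out_lo holds elements 0-3 and 8-11
// and *out_hi holds elements 4-7 and 12-15; the matching _mm256_packs_epi32
// in the caller restores the original order.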
static inline void sign_extend_16bit_to_32bit_avx2(__m256i in, __m256i zero,
                                                    __m256i *out_lo,
                                                    __m256i *out_hi) {
  const __m256i sign_bits = _mm256_cmpgt_epi16(zero, in);
  *out_lo = _mm256_unpacklo_epi16(in, sign_bits);
  *out_hi = _mm256_unpackhi_epi16(in, sign_bits);
}

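// One pass of an 8-point Hadamard transform over eight vectors of sixteen
// int16 values, i.e. two 8x8 blocks processed in parallel, one per 128-bit
// lane. The iter == 0 pass also transposes each block within its lane so
// that a second call with iter == 1 completes the 2-D transform.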
static void hadamard_col8x2_avx2(__m256i *in, int iter) {
  __m256i a0 = in[0];
  __m256i a1 = in[1];
  __m256i a2 = in[2];
  __m256i a3 = in[3];
  __m256i a4 = in[4];
  __m256i a5 = in[5];
  __m256i a6 = in[6];
  __m256i a7 = in[7];

  __m256i b0 = _mm256_add_epi16(a0, a1);
  __m256i b1 = _mm256_sub_epi16(a0, a1);
  __m256i b2 = _mm256_add_epi16(a2, a3);
  __m256i b3 = _mm256_sub_epi16(a2, a3);
  __m256i b4 = _mm256_add_epi16(a4, a5);
  __m256i b5 = _mm256_sub_epi16(a4, a5);
  __m256i b6 = _mm256_add_epi16(a6, a7);
  __m256i b7 = _mm256_sub_epi16(a6, a7);

  a0 = _mm256_add_epi16(b0, b2);
  a1 = _mm256_add_epi16(b1, b3);
  a2 = _mm256_sub_epi16(b0, b2);
  a3 = _mm256_sub_epi16(b1, b3);
  a4 = _mm256_add_epi16(b4, b6);
  a5 = _mm256_add_epi16(b5, b7);
  a6 = _mm256_sub_epi16(b4, b6);
  a7 = _mm256_sub_epi16(b5, b7);

  if (iter == 0) {
    b0 = _mm256_add_epi16(a0, a4);
    b7 = _mm256_add_epi16(a1, a5);
    b3 = _mm256_add_epi16(a2, a6);
    b4 = _mm256_add_epi16(a3, a7);
    b2 = _mm256_sub_epi16(a0, a4);
    b6 = _mm256_sub_epi16(a1, a5);
    b1 = _mm256_sub_epi16(a2, a6);
    b5 = _mm256_sub_epi16(a3, a7);

    a0 = _mm256_unpacklo_epi16(b0, b1);
    a1 = _mm256_unpacklo_epi16(b2, b3);
    a2 = _mm256_unpackhi_epi16(b0, b1);
    a3 = _mm256_unpackhi_epi16(b2, b3);
    a4 = _mm256_unpacklo_epi16(b4, b5);
    a5 = _mm256_unpacklo_epi16(b6, b7);
    a6 = _mm256_unpackhi_epi16(b4, b5);
    a7 = _mm256_unpackhi_epi16(b6, b7);

    b0 = _mm256_unpacklo_epi32(a0, a1);
    b1 = _mm256_unpacklo_epi32(a4, a5);
    b2 = _mm256_unpackhi_epi32(a0, a1);
    b3 = _mm256_unpackhi_epi32(a4, a5);
    b4 = _mm256_unpacklo_epi32(a2, a3);
    b5 = _mm256_unpacklo_epi32(a6, a7);
    b6 = _mm256_unpackhi_epi32(a2, a3);
    b7 = _mm256_unpackhi_epi32(a6, a7);

    in[0] = _mm256_unpacklo_epi64(b0, b1);
    in[1] = _mm256_unpackhi_epi64(b0, b1);
    in[2] = _mm256_unpacklo_epi64(b2, b3);
    in[3] = _mm256_unpackhi_epi64(b2, b3);
    in[4] = _mm256_unpacklo_epi64(b4, b5);
    in[5] = _mm256_unpackhi_epi64(b4, b5);
    in[6] = _mm256_unpacklo_epi64(b6, b7);
    in[7] = _mm256_unpackhi_epi64(b6, b7);
  } else {
    in[0] = _mm256_add_epi16(a0, a4);
    in[7] = _mm256_add_epi16(a1, a5);
    in[3] = _mm256_add_epi16(a2, a6);
    in[4] = _mm256_add_epi16(a3, a7);
    in[2] = _mm256_sub_epi16(a0, a4);
    in[6] = _mm256_sub_epi16(a1, a5);
    in[1] = _mm256_sub_epi16(a2, a6);
    in[5] = _mm256_sub_epi16(a3, a7);
  }
}

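// Low-precision 2-D 8x8 Hadamard transform of two horizontally adjacent 8x8
// blocks of residuals. After the column (iter 0) and row (iter 1) passes,
// the permutes below de-interleave the lanes so that coeff[0..63] holds the
// first block's coefficients and coeff[64..127] the second block's.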
void aom_hadamard_lp_8x8_dual_avx2(const int16_t *src_diff,
                                   ptrdiff_t src_stride, int16_t *coeff) {
  __m256i src[8];
  src[0] = _mm256_loadu_si256((const __m256i *)src_diff);
  src[1] = _mm256_loadu_si256((const __m256i *)(src_diff += src_stride));
  src[2] = _mm256_loadu_si256((const __m256i *)(src_diff += src_stride));
  src[3] = _mm256_loadu_si256((const __m256i *)(src_diff += src_stride));
  src[4] = _mm256_loadu_si256((const __m256i *)(src_diff += src_stride));
  src[5] = _mm256_loadu_si256((const __m256i *)(src_diff += src_stride));
  src[6] = _mm256_loadu_si256((const __m256i *)(src_diff += src_stride));
  src[7] = _mm256_loadu_si256((const __m256i *)(src_diff + src_stride));

  hadamard_col8x2_avx2(src, 0);
  hadamard_col8x2_avx2(src, 1);

  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[0], src[1], 0x20));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[2], src[3], 0x20));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[4], src[5], 0x20));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[6], src[7], 0x20));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[0], src[1], 0x31));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[2], src[3], 0x31));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[4], src[5], 0x31));
  coeff += 16;
  _mm256_storeu_si256((__m256i *)coeff,
                      _mm256_permute2x128_si256(src[6], src[7], 0x31));
}

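// 16x16 Hadamard transform assembled from four 8x8 transforms followed by a
// second butterfly stage whose sums are halved (>> 1) to stay within 16 bits.
// With is_final == 0 the result is written to the int16_t buffer aliased by
// coeff, for reuse by the 32x32 transform below.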
static inline void hadamard_16x16_avx2(const int16_t *src_diff,
                                       ptrdiff_t src_stride, tran_low_t *coeff,
                                       int is_final) {
  DECLARE_ALIGNED(32, int16_t, temp_coeff[16 * 16]);
  int16_t *t_coeff = temp_coeff;
  int16_t *coeff16 = (int16_t *)coeff;
  int idx;
  for (idx = 0; idx < 2; ++idx) {
    const int16_t *src_ptr = src_diff + idx * 8 * src_stride;
    aom_hadamard_lp_8x8_dual_avx2(src_ptr, src_stride,
                                  t_coeff + (idx * 64 * 2));
  }

  for (idx = 0; idx < 64; idx += 16) {
    const __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
    const __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 64));
    const __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 128));
    const __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 192));

    __m256i b0 = _mm256_add_epi16(coeff0, coeff1);
    __m256i b1 = _mm256_sub_epi16(coeff0, coeff1);
    __m256i b2 = _mm256_add_epi16(coeff2, coeff3);
    __m256i b3 = _mm256_sub_epi16(coeff2, coeff3);

    b0 = _mm256_srai_epi16(b0, 1);
    b1 = _mm256_srai_epi16(b1, 1);
    b2 = _mm256_srai_epi16(b2, 1);
    b3 = _mm256_srai_epi16(b3, 1);
    if (is_final) {
      store_tran_low(_mm256_add_epi16(b0, b2), coeff);
      store_tran_low(_mm256_add_epi16(b1, b3), coeff + 64);
      store_tran_low(_mm256_sub_epi16(b0, b2), coeff + 128);
      store_tran_low(_mm256_sub_epi16(b1, b3), coeff + 192);
      coeff += 16;
    } else {
      _mm256_storeu_si256((__m256i *)coeff16, _mm256_add_epi16(b0, b2));
      _mm256_storeu_si256((__m256i *)(coeff16 + 64), _mm256_add_epi16(b1, b3));
      _mm256_storeu_si256((__m256i *)(coeff16 + 128), _mm256_sub_epi16(b0, b2));
      _mm256_storeu_si256((__m256i *)(coeff16 + 192), _mm256_sub_epi16(b1, b3));
      coeff16 += 16;
    }
    t_coeff += 16;
  }
}

void aom_hadamard_16x16_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                             tran_low_t *coeff) {
  hadamard_16x16_avx2(src_diff, src_stride, coeff, 1);
}

void aom_hadamard_lp_16x16_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                                int16_t *coeff) {
  int16_t *t_coeff = coeff;
  for (int idx = 0; idx < 2; ++idx) {
    const int16_t *src_ptr = src_diff + idx * 8 * src_stride;
    aom_hadamard_lp_8x8_dual_avx2(src_ptr, src_stride,
                                  t_coeff + (idx * 64 * 2));
  }

  for (int idx = 0; idx < 64; idx += 16) {
    const __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
    const __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 64));
    const __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 128));
    const __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 192));

    __m256i b0 = _mm256_add_epi16(coeff0, coeff1);
    __m256i b1 = _mm256_sub_epi16(coeff0, coeff1);
    __m256i b2 = _mm256_add_epi16(coeff2, coeff3);
    __m256i b3 = _mm256_sub_epi16(coeff2, coeff3);

    b0 = _mm256_srai_epi16(b0, 1);
    b1 = _mm256_srai_epi16(b1, 1);
    b2 = _mm256_srai_epi16(b2, 1);
    b3 = _mm256_srai_epi16(b3, 1);
    _mm256_storeu_si256((__m256i *)coeff, _mm256_add_epi16(b0, b2));
    _mm256_storeu_si256((__m256i *)(coeff + 64), _mm256_add_epi16(b1, b3));
    _mm256_storeu_si256((__m256i *)(coeff + 128), _mm256_sub_epi16(b0, b2));
    _mm256_storeu_si256((__m256i *)(coeff + 192), _mm256_sub_epi16(b1, b3));
    coeff += 16;
    t_coeff += 16;
  }
}

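// 32x32 Hadamard transform: four 16x16 transforms into an int16_t scratch
// buffer, followed by a second stage that widens to 32 bits for the
// butterflies, scales by >> 2 and packs back to 16 bits before the final
// store_tran_low().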
void aom_hadamard_32x32_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                             tran_low_t *coeff) {
  // For high bitdepths, it is unnecessary to store_tran_low
  // (mult/unpack/store), then load_tran_low (load/pack) the same memory in the
  // next stage.  Output to an intermediate buffer first, then store_tran_low()
  // in the final stage.
  DECLARE_ALIGNED(32, int16_t, temp_coeff[32 * 32]);
  int16_t *t_coeff = temp_coeff;
  int idx;
  __m256i coeff0_lo, coeff1_lo, coeff2_lo, coeff3_lo, b0_lo, b1_lo, b2_lo,
      b3_lo;
  __m256i coeff0_hi, coeff1_hi, coeff2_hi, coeff3_hi, b0_hi, b1_hi, b2_hi,
      b3_hi;
  __m256i b0, b1, b2, b3;
  const __m256i zero = _mm256_setzero_si256();
  for (idx = 0; idx < 4; ++idx) {
    // src_diff: 9 bit, dynamic range [-255, 255]
    const int16_t *src_ptr =
        src_diff + (idx >> 1) * 16 * src_stride + (idx & 0x01) * 16;
    hadamard_16x16_avx2(src_ptr, src_stride,
                        (tran_low_t *)(t_coeff + idx * 256), 0);
  }

  for (idx = 0; idx < 256; idx += 16) {
    const __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
    const __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 256));
    const __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 512));
    const __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 768));

    // Sign extend 16 bit to 32 bit.
    sign_extend_16bit_to_32bit_avx2(coeff0, zero, &coeff0_lo, &coeff0_hi);
    sign_extend_16bit_to_32bit_avx2(coeff1, zero, &coeff1_lo, &coeff1_hi);
    sign_extend_16bit_to_32bit_avx2(coeff2, zero, &coeff2_lo, &coeff2_hi);
    sign_extend_16bit_to_32bit_avx2(coeff3, zero, &coeff3_lo, &coeff3_hi);

    b0_lo = _mm256_add_epi32(coeff0_lo, coeff1_lo);
    b0_hi = _mm256_add_epi32(coeff0_hi, coeff1_hi);

    b1_lo = _mm256_sub_epi32(coeff0_lo, coeff1_lo);
    b1_hi = _mm256_sub_epi32(coeff0_hi, coeff1_hi);

    b2_lo = _mm256_add_epi32(coeff2_lo, coeff3_lo);
    b2_hi = _mm256_add_epi32(coeff2_hi, coeff3_hi);

    b3_lo = _mm256_sub_epi32(coeff2_lo, coeff3_lo);
    b3_hi = _mm256_sub_epi32(coeff2_hi, coeff3_hi);

    b0_lo = _mm256_srai_epi32(b0_lo, 2);
    b1_lo = _mm256_srai_epi32(b1_lo, 2);
    b2_lo = _mm256_srai_epi32(b2_lo, 2);
    b3_lo = _mm256_srai_epi32(b3_lo, 2);

    b0_hi = _mm256_srai_epi32(b0_hi, 2);
    b1_hi = _mm256_srai_epi32(b1_hi, 2);
    b2_hi = _mm256_srai_epi32(b2_hi, 2);
    b3_hi = _mm256_srai_epi32(b3_hi, 2);

    b0 = _mm256_packs_epi32(b0_lo, b0_hi);
    b1 = _mm256_packs_epi32(b1_lo, b1_hi);
    b2 = _mm256_packs_epi32(b2_lo, b2_hi);
    b3 = _mm256_packs_epi32(b3_lo, b3_hi);

    store_tran_low(_mm256_add_epi16(b0, b2), coeff);
    store_tran_low(_mm256_add_epi16(b1, b3), coeff + 256);
    store_tran_low(_mm256_sub_epi16(b0, b2), coeff + 512);
    store_tran_low(_mm256_sub_epi16(b1, b3), coeff + 768);

    coeff += 16;
    t_coeff += 16;
  }
}

#if CONFIG_AV1_HIGHBITDEPTH
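// 32-bit counterpart of hadamard_col8x2_avx2, operating on a single 8x8
// block (eight int32 values per register). The iter == 0 pass also
// transposes the block so that a second call with iter == 1 completes the
// 2-D transform.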
static void highbd_hadamard_col8_avx2(__m256i *in, int iter) {
  __m256i a0 = in[0];
  __m256i a1 = in[1];
  __m256i a2 = in[2];
  __m256i a3 = in[3];
  __m256i a4 = in[4];
  __m256i a5 = in[5];
  __m256i a6 = in[6];
  __m256i a7 = in[7];

  __m256i b0 = _mm256_add_epi32(a0, a1);
  __m256i b1 = _mm256_sub_epi32(a0, a1);
  __m256i b2 = _mm256_add_epi32(a2, a3);
  __m256i b3 = _mm256_sub_epi32(a2, a3);
  __m256i b4 = _mm256_add_epi32(a4, a5);
  __m256i b5 = _mm256_sub_epi32(a4, a5);
  __m256i b6 = _mm256_add_epi32(a6, a7);
  __m256i b7 = _mm256_sub_epi32(a6, a7);

  a0 = _mm256_add_epi32(b0, b2);
  a1 = _mm256_add_epi32(b1, b3);
  a2 = _mm256_sub_epi32(b0, b2);
  a3 = _mm256_sub_epi32(b1, b3);
  a4 = _mm256_add_epi32(b4, b6);
  a5 = _mm256_add_epi32(b5, b7);
  a6 = _mm256_sub_epi32(b4, b6);
  a7 = _mm256_sub_epi32(b5, b7);

  if (iter == 0) {
    b0 = _mm256_add_epi32(a0, a4);
    b7 = _mm256_add_epi32(a1, a5);
    b3 = _mm256_add_epi32(a2, a6);
    b4 = _mm256_add_epi32(a3, a7);
    b2 = _mm256_sub_epi32(a0, a4);
    b6 = _mm256_sub_epi32(a1, a5);
    b1 = _mm256_sub_epi32(a2, a6);
    b5 = _mm256_sub_epi32(a3, a7);

    a0 = _mm256_unpacklo_epi32(b0, b1);
    a1 = _mm256_unpacklo_epi32(b2, b3);
    a2 = _mm256_unpackhi_epi32(b0, b1);
    a3 = _mm256_unpackhi_epi32(b2, b3);
    a4 = _mm256_unpacklo_epi32(b4, b5);
    a5 = _mm256_unpacklo_epi32(b6, b7);
    a6 = _mm256_unpackhi_epi32(b4, b5);
    a7 = _mm256_unpackhi_epi32(b6, b7);

    b0 = _mm256_unpacklo_epi64(a0, a1);
    b1 = _mm256_unpacklo_epi64(a4, a5);
    b2 = _mm256_unpackhi_epi64(a0, a1);
    b3 = _mm256_unpackhi_epi64(a4, a5);
    b4 = _mm256_unpacklo_epi64(a2, a3);
    b5 = _mm256_unpacklo_epi64(a6, a7);
    b6 = _mm256_unpackhi_epi64(a2, a3);
    b7 = _mm256_unpackhi_epi64(a6, a7);

    in[0] = _mm256_permute2x128_si256(b0, b1, 0x20);
    in[1] = _mm256_permute2x128_si256(b0, b1, 0x31);
    in[2] = _mm256_permute2x128_si256(b2, b3, 0x20);
    in[3] = _mm256_permute2x128_si256(b2, b3, 0x31);
    in[4] = _mm256_permute2x128_si256(b4, b5, 0x20);
    in[5] = _mm256_permute2x128_si256(b4, b5, 0x31);
    in[6] = _mm256_permute2x128_si256(b6, b7, 0x20);
    in[7] = _mm256_permute2x128_si256(b6, b7, 0x31);
  } else {
    in[0] = _mm256_add_epi32(a0, a4);
    in[7] = _mm256_add_epi32(a1, a5);
    in[3] = _mm256_add_epi32(a2, a6);
    in[4] = _mm256_add_epi32(a3, a7);
    in[2] = _mm256_sub_epi32(a0, a4);
    in[6] = _mm256_sub_epi32(a1, a5);
    in[1] = _mm256_sub_epi32(a2, a6);
    in[5] = _mm256_sub_epi32(a3, a7);
  }
}

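// High-bitdepth 8x8 Hadamard transform: the int16 residuals are widened to
// 32 bits up front, so both transform passes run in 32-bit arithmetic and
// the coefficients are stored directly as tran_low_t.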
void aom_highbd_hadamard_8x8_avx2(const int16_t *src_diff, ptrdiff_t src_stride,
                                  tran_low_t *coeff) {
  __m128i src16[8];
  __m256i src32[8];

  src16[0] = _mm_loadu_si128((const __m128i *)src_diff);
  src16[1] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
  src16[2] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
  src16[3] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
  src16[4] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
  src16[5] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
  src16[6] = _mm_loadu_si128((const __m128i *)(src_diff += src_stride));
  src16[7] = _mm_loadu_si128((const __m128i *)(src_diff + src_stride));

  src32[0] = _mm256_cvtepi16_epi32(src16[0]);
  src32[1] = _mm256_cvtepi16_epi32(src16[1]);
  src32[2] = _mm256_cvtepi16_epi32(src16[2]);
  src32[3] = _mm256_cvtepi16_epi32(src16[3]);
  src32[4] = _mm256_cvtepi16_epi32(src16[4]);
  src32[5] = _mm256_cvtepi16_epi32(src16[5]);
  src32[6] = _mm256_cvtepi16_epi32(src16[6]);
  src32[7] = _mm256_cvtepi16_epi32(src16[7]);

  highbd_hadamard_col8_avx2(src32, 0);
  highbd_hadamard_col8_avx2(src32, 1);

  _mm256_storeu_si256((__m256i *)coeff, src32[0]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[1]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[2]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[3]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[4]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[5]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[6]);
  coeff += 8;
  _mm256_storeu_si256((__m256i *)coeff, src32[7]);
}

void aom_highbd_hadamard_16x16_avx2(const int16_t *src_diff,
                                    ptrdiff_t src_stride, tran_low_t *coeff) {
  int idx;
  tran_low_t *t_coeff = coeff;
  for (idx = 0; idx < 4; ++idx) {
    const int16_t *src_ptr =
        src_diff + (idx >> 1) * 8 * src_stride + (idx & 0x01) * 8;
    aom_highbd_hadamard_8x8_avx2(src_ptr, src_stride, t_coeff + idx * 64);
  }

  for (idx = 0; idx < 64; idx += 8) {
    __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
    __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 64));
    __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 128));
    __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 192));

    __m256i b0 = _mm256_add_epi32(coeff0, coeff1);
    __m256i b1 = _mm256_sub_epi32(coeff0, coeff1);
    __m256i b2 = _mm256_add_epi32(coeff2, coeff3);
    __m256i b3 = _mm256_sub_epi32(coeff2, coeff3);

    b0 = _mm256_srai_epi32(b0, 1);
    b1 = _mm256_srai_epi32(b1, 1);
    b2 = _mm256_srai_epi32(b2, 1);
    b3 = _mm256_srai_epi32(b3, 1);

    coeff0 = _mm256_add_epi32(b0, b2);
    coeff1 = _mm256_add_epi32(b1, b3);
    coeff2 = _mm256_sub_epi32(b0, b2);
    coeff3 = _mm256_sub_epi32(b1, b3);

    _mm256_storeu_si256((__m256i *)coeff, coeff0);
    _mm256_storeu_si256((__m256i *)(coeff + 64), coeff1);
    _mm256_storeu_si256((__m256i *)(coeff + 128), coeff2);
    _mm256_storeu_si256((__m256i *)(coeff + 192), coeff3);

    coeff += 8;
    t_coeff += 8;
  }
}

void aom_highbd_hadamard_32x32_avx2(const int16_t *src_diff,
                                    ptrdiff_t src_stride, tran_low_t *coeff) {
  int idx;
  tran_low_t *t_coeff = coeff;
  for (idx = 0; idx < 4; ++idx) {
    const int16_t *src_ptr =
        src_diff + (idx >> 1) * 16 * src_stride + (idx & 0x01) * 16;
    aom_highbd_hadamard_16x16_avx2(src_ptr, src_stride, t_coeff + idx * 256);
  }

  for (idx = 0; idx < 256; idx += 8) {
    __m256i coeff0 = _mm256_loadu_si256((const __m256i *)t_coeff);
    __m256i coeff1 = _mm256_loadu_si256((const __m256i *)(t_coeff + 256));
    __m256i coeff2 = _mm256_loadu_si256((const __m256i *)(t_coeff + 512));
    __m256i coeff3 = _mm256_loadu_si256((const __m256i *)(t_coeff + 768));

    __m256i b0 = _mm256_add_epi32(coeff0, coeff1);
    __m256i b1 = _mm256_sub_epi32(coeff0, coeff1);
    __m256i b2 = _mm256_add_epi32(coeff2, coeff3);
    __m256i b3 = _mm256_sub_epi32(coeff2, coeff3);

    b0 = _mm256_srai_epi32(b0, 2);
    b1 = _mm256_srai_epi32(b1, 2);
    b2 = _mm256_srai_epi32(b2, 2);
    b3 = _mm256_srai_epi32(b3, 2);

    coeff0 = _mm256_add_epi32(b0, b2);
    coeff1 = _mm256_add_epi32(b1, b3);
    coeff2 = _mm256_sub_epi32(b0, b2);
    coeff3 = _mm256_sub_epi32(b1, b3);

    _mm256_storeu_si256((__m256i *)coeff, coeff0);
    _mm256_storeu_si256((__m256i *)(coeff + 256), coeff1);
    _mm256_storeu_si256((__m256i *)(coeff + 512), coeff2);
    _mm256_storeu_si256((__m256i *)(coeff + 768), coeff3);

    coeff += 8;
    t_coeff += 8;
  }
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

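// Returns the sum of absolute values of the 32-bit transform coefficients.
// length is expected to be a multiple of 8.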
int aom_satd_avx2(const tran_low_t *coeff, int length) {
  __m256i accum = _mm256_setzero_si256();
  int i;

  for (i = 0; i < length; i += 8, coeff += 8) {
    const __m256i src_line = _mm256_loadu_si256((const __m256i *)coeff);
    const __m256i abs = _mm256_abs_epi32(src_line);
    accum = _mm256_add_epi32(accum, abs);
  }

  {  // 32 bit horizontal add
    const __m256i a = _mm256_srli_si256(accum, 8);
    const __m256i b = _mm256_add_epi32(accum, a);
    const __m256i c = _mm256_srli_epi64(b, 32);
    const __m256i d = _mm256_add_epi32(b, c);
    const __m128i accum_128 = _mm_add_epi32(_mm256_castsi256_si128(d),
                                            _mm256_extractf128_si256(d, 1));
    return _mm_cvtsi128_si32(accum_128);
  }
}

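// Low-precision variant of aom_satd: sums the absolute values of 16-bit
// coefficients, widening to 32 bits with _mm256_madd_epi16 against a vector
// of ones. length is expected to be a multiple of 16.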
int aom_satd_lp_avx2(const int16_t *coeff, int length) {
  const __m256i one = _mm256_set1_epi16(1);
  __m256i accum = _mm256_setzero_si256();

  for (int i = 0; i < length; i += 16) {
    const __m256i src_line = _mm256_loadu_si256((const __m256i *)coeff);
    const __m256i abs = _mm256_abs_epi16(src_line);
    const __m256i sum = _mm256_madd_epi16(abs, one);
    accum = _mm256_add_epi32(accum, sum);
    coeff += 16;
  }

  {  // 32 bit horizontal add
    const __m256i a = _mm256_srli_si256(accum, 8);
    const __m256i b = _mm256_add_epi32(accum, a);
    const __m256i c = _mm256_srli_epi64(b, 32);
    const __m256i d = _mm256_add_epi32(b, c);
    const __m128i accum_128 = _mm_add_epi32(_mm256_castsi256_si128(d),
                                            _mm256_extractf128_si256(d, 1));
    return _mm_cvtsi128_si32(accum_128);
  }
}

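// Computes the rounded average, (sum + 32) >> 6, of each of the four 8x8
// quadrants of the 16x16 block at offset (x16_idx, y16_idx) within s, in
// raster order: avg[0] top-left, avg[1] top-right, avg[2] bottom-left,
// avg[3] bottom-right.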
void aom_avg_8x8_quad_avx2(const uint8_t *s, int p, int x16_idx, int y16_idx,
                           int *avg) {
  const uint8_t *s_y0 = s + y16_idx * p + x16_idx;
  const uint8_t *s_y1 = s_y0 + 8 * p;
  __m256i sum0, sum1, s0, s1, s2, s3, u0;
  u0 = _mm256_setzero_si256();
  s0 = _mm256_sad_epu8(yy_loadu2_128(s_y1, s_y0), u0);
  s1 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + p, s_y0 + p), u0);
  s2 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + 2 * p, s_y0 + 2 * p), u0);
  s3 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + 3 * p, s_y0 + 3 * p), u0);
  sum0 = _mm256_add_epi16(s0, s1);
  sum1 = _mm256_add_epi16(s2, s3);
  s0 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + 4 * p, s_y0 + 4 * p), u0);
  s1 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + 5 * p, s_y0 + 5 * p), u0);
  s2 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + 6 * p, s_y0 + 6 * p), u0);
  s3 = _mm256_sad_epu8(yy_loadu2_128(s_y1 + 7 * p, s_y0 + 7 * p), u0);
  sum0 = _mm256_add_epi16(sum0, _mm256_add_epi16(s0, s1));
  sum1 = _mm256_add_epi16(sum1, _mm256_add_epi16(s2, s3));
  sum0 = _mm256_add_epi16(sum0, sum1);

  // avg = (sum + 32) >> 6
  __m256i rounding = _mm256_set1_epi32(32);
  sum0 = _mm256_add_epi32(sum0, rounding);
  sum0 = _mm256_srli_epi32(sum0, 6);
  __m128i lo = _mm256_castsi256_si128(sum0);
  __m128i hi = _mm256_extracti128_si256(sum0, 1);
  avg[0] = _mm_cvtsi128_si32(lo);
  avg[1] = _mm_extract_epi32(lo, 2);
  avg[2] = _mm_cvtsi128_si32(hi);
  avg[3] = _mm_extract_epi32(hi, 2);
}

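// Vertical projection: hbuf[x] is the sum of column x over all height rows,
// arithmetically shifted right by norm_factor. Widths that are a multiple of
// 16 but not of 32 are handled by the SSE2 implementation.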
void aom_int_pro_row_avx2(int16_t *hbuf, const uint8_t *ref,
                          const int ref_stride, const int width,
                          const int height, int norm_factor) {
  // The SIMD implementation assumes width is a multiple of 16 and height a
  // multiple of 2; other sizes would need additional SIMD support.
  assert(width % 16 == 0 && height % 2 == 0);

  if (width % 32 == 0) {
    const __m256i zero = _mm256_setzero_si256();
    for (int wd = 0; wd < width; wd += 32) {
      const uint8_t *ref_tmp = ref + wd;
      int16_t *hbuf_tmp = hbuf + wd;
      __m256i s0 = zero;
      __m256i s1 = zero;
      int idx = 0;
      do {
        __m256i src_line = _mm256_loadu_si256((const __m256i *)ref_tmp);
        __m256i t0 = _mm256_unpacklo_epi8(src_line, zero);
        __m256i t1 = _mm256_unpackhi_epi8(src_line, zero);
        s0 = _mm256_add_epi16(s0, t0);
        s1 = _mm256_add_epi16(s1, t1);
        ref_tmp += ref_stride;

        src_line = _mm256_loadu_si256((const __m256i *)ref_tmp);
        t0 = _mm256_unpacklo_epi8(src_line, zero);
        t1 = _mm256_unpackhi_epi8(src_line, zero);
        s0 = _mm256_add_epi16(s0, t0);
        s1 = _mm256_add_epi16(s1, t1);
        ref_tmp += ref_stride;
        idx += 2;
      } while (idx < height);
      s0 = _mm256_srai_epi16(s0, norm_factor);
      s1 = _mm256_srai_epi16(s1, norm_factor);
      _mm_storeu_si128((__m128i *)(hbuf_tmp), _mm256_castsi256_si128(s0));
      _mm_storeu_si128((__m128i *)(hbuf_tmp + 8), _mm256_castsi256_si128(s1));
      _mm_storeu_si128((__m128i *)(hbuf_tmp + 16),
                       _mm256_extractf128_si256(s0, 1));
      _mm_storeu_si128((__m128i *)(hbuf_tmp + 24),
                       _mm256_extractf128_si256(s1, 1));
    }
  } else if (width % 16 == 0) {
    aom_int_pro_row_sse2(hbuf, ref, ref_stride, width, height, norm_factor);
  }
}

static inline void load_from_src_buf(const uint8_t *ref1, __m256i *src,
                                     const int stride) {
  src[0] = _mm256_loadu_si256((const __m256i *)ref1);
  src[1] = _mm256_loadu_si256((const __m256i *)(ref1 + stride));
  src[2] = _mm256_loadu_si256((const __m256i *)(ref1 + (2 * stride)));
  src[3] = _mm256_loadu_si256((const __m256i *)(ref1 + (3 * stride)));
}

#define CALC_TOT_SAD_AND_STORE                                                \
  /* r00 r10 x x r01 r11 x x | r02 r12 x x r03 r13 x x */                     \
  const __m256i r01 = _mm256_add_epi16(_mm256_slli_si256(r1, 2), r0);         \
  /* r00 r10 r20 x r01 r11 r21 x | r02 r12 r22 x r03 r13 r23 x */             \
  const __m256i r012 = _mm256_add_epi16(_mm256_slli_si256(r2, 4), r01);       \
  /* r00 r10 r20 r30 r01 r11 r21 r31 | r02 r12 r22 r32 r03 r13 r23 r33 */     \
  const __m256i result0 = _mm256_add_epi16(_mm256_slli_si256(r3, 6), r012);   \
                                                                              \
  const __m128i results0 = _mm_add_epi16(                                     \
      _mm256_castsi256_si128(result0), _mm256_extractf128_si256(result0, 1)); \
  const __m128i results1 =                                                    \
      _mm_add_epi16(results0, _mm_srli_si128(results0, 8));                   \
  _mm_storel_epi64((__m128i *)vbuf, _mm_srli_epi16(results1, norm_factor));

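// Horizontal projection for width == 16: vbuf[y] is the sum of the 16 pixels
// of row y, logically shifted right by norm_factor.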
static inline void aom_int_pro_col_16wd_avx2(int16_t *vbuf, const uint8_t *ref,
                                             const int ref_stride,
                                             const int height,
                                             int norm_factor) {
  const __m256i zero = _mm256_setzero_si256();
  int ht = 0;
  // After the SAD operation, each 64-bit lane holds its sum in the lower
  // 16 bits and zeros in the upper bits. Eight rows are processed at a time
  // so that those upper 16 bits can be used efficiently.
  do {
    __m256i src_00 =
        _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(ref)));
    src_00 = _mm256_inserti128_si256(
        src_00, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 4)), 1);
    __m256i src_01 = _mm256_castsi128_si256(
        _mm_loadu_si128((const __m128i *)(ref + ref_stride * 1)));
    src_01 = _mm256_inserti128_si256(
        src_01, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 5)), 1);
    __m256i src_10 = _mm256_castsi128_si256(
        _mm_loadu_si128((const __m128i *)(ref + ref_stride * 2)));
    src_10 = _mm256_inserti128_si256(
        src_10, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 6)), 1);
    __m256i src_11 = _mm256_castsi128_si256(
        _mm_loadu_si128((const __m128i *)(ref + ref_stride * 3)));
    src_11 = _mm256_inserti128_si256(
        src_11, _mm_loadu_si128((const __m128i *)(ref + ref_stride * 7)), 1);

    // s00 x x x s01 x x x | s40 x x x s41 x x x
    const __m256i s0 = _mm256_sad_epu8(src_00, zero);
    // s10 x x x s11 x x x | s50 x x x s51 x x x
    const __m256i s1 = _mm256_sad_epu8(src_01, zero);
    // s20 x x x s21 x x x | s60 x x x s61 x x x
    const __m256i s2 = _mm256_sad_epu8(src_10, zero);
    // s30 x x x s31 x x x | s70 x x x s71 x x x
    const __m256i s3 = _mm256_sad_epu8(src_11, zero);

    // s00 s10 x x x x x x | s40 s50 x x x x x x
    const __m256i s0_lo = _mm256_unpacklo_epi16(s0, s1);
    // s01 s11 x x x x x x | s41 s51 x x x x x x
    const __m256i s0_hi = _mm256_unpackhi_epi16(s0, s1);
    // s20 s30 x x x x x x | s60 s70 x x x x x x
    const __m256i s1_lo = _mm256_unpacklo_epi16(s2, s3);
    // s21 s31 x x x x x x | s61 s71 x x x x x x
    const __m256i s1_hi = _mm256_unpackhi_epi16(s2, s3);

    // s0 s1 x x x x x x | s4 s5 x x x x x x
    const __m256i s0_add = _mm256_add_epi16(s0_lo, s0_hi);
    // s2 s3 x x x x x x | s6 s7 x x x x x x
    const __m256i s1_add = _mm256_add_epi16(s1_lo, s1_hi);

    // s0 s1 s2 s3 s4 s5 s6 s7
    const __m128i results = _mm256_castsi256_si128(
        _mm256_permute4x64_epi64(_mm256_unpacklo_epi32(s0_add, s1_add), 0x08));
    _mm_storeu_si128((__m128i *)vbuf, _mm_srli_epi16(results, norm_factor));
    vbuf += 8;
    ref += (ref_stride << 3);
    ht += 8;
  } while (ht < height);
}

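// Horizontal projection: vbuf[y] is the sum of row y over the full width,
// shifted right by norm_factor. The wide cases reduce four rows per
// iteration with CALC_TOT_SAD_AND_STORE; width == 16 uses the helper above.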
void aom_int_pro_col_avx2(int16_t *vbuf, const uint8_t *ref,
                          const int ref_stride, const int width,
                          const int height, int norm_factor) {
  assert(width % 16 == 0);
  if (width == 128) {
    const __m256i zero = _mm256_setzero_si256();
    for (int ht = 0; ht < height; ht += 4) {
      __m256i src[16];
      // Load source data.
      load_from_src_buf(ref, &src[0], ref_stride);
      load_from_src_buf(ref + 32, &src[4], ref_stride);
      load_from_src_buf(ref + 64, &src[8], ref_stride);
      load_from_src_buf(ref + 96, &src[12], ref_stride);

      // Row0 output: r00 x x x r01 x x x | r02 x x x r03 x x x
      const __m256i s0 = _mm256_add_epi16(_mm256_sad_epu8(src[0], zero),
                                          _mm256_sad_epu8(src[4], zero));
      const __m256i s1 = _mm256_add_epi16(_mm256_sad_epu8(src[8], zero),
                                          _mm256_sad_epu8(src[12], zero));
      const __m256i r0 = _mm256_add_epi16(s0, s1);
      // Row1 output: r10 x x x r11 x x x | r12 x x x r13 x x x
      const __m256i s2 = _mm256_add_epi16(_mm256_sad_epu8(src[1], zero),
                                          _mm256_sad_epu8(src[5], zero));
      const __m256i s3 = _mm256_add_epi16(_mm256_sad_epu8(src[9], zero),
                                          _mm256_sad_epu8(src[13], zero));
      const __m256i r1 = _mm256_add_epi16(s2, s3);
      // Row2 output: r20 x x x r21 x x x | r22 x x x r23 x x x
      const __m256i s4 = _mm256_add_epi16(_mm256_sad_epu8(src[2], zero),
                                          _mm256_sad_epu8(src[6], zero));
      const __m256i s5 = _mm256_add_epi16(_mm256_sad_epu8(src[10], zero),
                                          _mm256_sad_epu8(src[14], zero));
      const __m256i r2 = _mm256_add_epi16(s4, s5);
      // Row3 output: r30 x x x r31 x x x | r32 x x x r33 x x x
      const __m256i s6 = _mm256_add_epi16(_mm256_sad_epu8(src[3], zero),
                                          _mm256_sad_epu8(src[7], zero));
      const __m256i s7 = _mm256_add_epi16(_mm256_sad_epu8(src[11], zero),
                                          _mm256_sad_epu8(src[15], zero));
      const __m256i r3 = _mm256_add_epi16(s6, s7);

      CALC_TOT_SAD_AND_STORE
      vbuf += 4;
      ref += ref_stride << 2;
    }
  } else if (width == 64) {
    const __m256i zero = _mm256_setzero_si256();
    for (int ht = 0; ht < height; ht += 4) {
      __m256i src[8];
      // Load source data.
      load_from_src_buf(ref, &src[0], ref_stride);
      load_from_src_buf(ref + 32, &src[4], ref_stride);

      // Row0 output: r00 x x x r01 x x x | r02 x x x r03 x x x
      const __m256i s0 = _mm256_sad_epu8(src[0], zero);
      const __m256i s1 = _mm256_sad_epu8(src[4], zero);
      const __m256i r0 = _mm256_add_epi16(s0, s1);
      // Row1 output: r10 x x x r11 x x x | r12 x x x r13 x x x
      const __m256i s2 = _mm256_sad_epu8(src[1], zero);
      const __m256i s3 = _mm256_sad_epu8(src[5], zero);
      const __m256i r1 = _mm256_add_epi16(s2, s3);
      // Row2 output: r20 x x x r21 x x x | r22 x x x r23 x x x
      const __m256i s4 = _mm256_sad_epu8(src[2], zero);
      const __m256i s5 = _mm256_sad_epu8(src[6], zero);
      const __m256i r2 = _mm256_add_epi16(s4, s5);
      // Row3 output: r30 x x x r31 x x x | r32 x x x r33 x x x
      const __m256i s6 = _mm256_sad_epu8(src[3], zero);
      const __m256i s7 = _mm256_sad_epu8(src[7], zero);
      const __m256i r3 = _mm256_add_epi16(s6, s7);

      CALC_TOT_SAD_AND_STORE
      vbuf += 4;
      ref += ref_stride << 2;
    }
  } else if (width == 32) {
    assert(height % 2 == 0);
    const __m256i zero = _mm256_setzero_si256();
    for (int ht = 0; ht < height; ht += 4) {
      __m256i src[4];
      // Load source data.
      load_from_src_buf(ref, &src[0], ref_stride);

      // s00 x x x s01 x x x s02 x x x s03 x x x
      const __m256i r0 = _mm256_sad_epu8(src[0], zero);
      // s10 x x x s11 x x x s12 x x x s13 x x x
      const __m256i r1 = _mm256_sad_epu8(src[1], zero);
      // s20 x x x s21 x x x s22 x x x s23 x x x
      const __m256i r2 = _mm256_sad_epu8(src[2], zero);
      // s30 x x x s31 x x x s32 x x x s33 x x x
      const __m256i r3 = _mm256_sad_epu8(src[3], zero);

      CALC_TOT_SAD_AND_STORE
      vbuf += 4;
      ref += ref_stride << 2;
    }
  } else if (width == 16) {
    aom_int_pro_col_16wd_avx2(vbuf, ref, ref_stride, height, norm_factor);
  }
}

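// Accumulates, over 64 int16 elements, the sum of the differences
// (ref - src) into *mean and the sum of squared differences into *sse.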
static inline void calc_vector_mean_sse_64wd(const int16_t *ref,
                                             const int16_t *src, __m256i *mean,
                                             __m256i *sse) {
  const __m256i src_line0 = _mm256_loadu_si256((const __m256i *)src);
  const __m256i src_line1 = _mm256_loadu_si256((const __m256i *)(src + 16));
  const __m256i src_line2 = _mm256_loadu_si256((const __m256i *)(src + 32));
  const __m256i src_line3 = _mm256_loadu_si256((const __m256i *)(src + 48));
  const __m256i ref_line0 = _mm256_loadu_si256((const __m256i *)ref);
  const __m256i ref_line1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
  const __m256i ref_line2 = _mm256_loadu_si256((const __m256i *)(ref + 32));
  const __m256i ref_line3 = _mm256_loadu_si256((const __m256i *)(ref + 48));

  const __m256i diff0 = _mm256_sub_epi16(ref_line0, src_line0);
  const __m256i diff1 = _mm256_sub_epi16(ref_line1, src_line1);
  const __m256i diff2 = _mm256_sub_epi16(ref_line2, src_line2);
  const __m256i diff3 = _mm256_sub_epi16(ref_line3, src_line3);
  const __m256i diff_sqr0 = _mm256_madd_epi16(diff0, diff0);
  const __m256i diff_sqr1 = _mm256_madd_epi16(diff1, diff1);
  const __m256i diff_sqr2 = _mm256_madd_epi16(diff2, diff2);
  const __m256i diff_sqr3 = _mm256_madd_epi16(diff3, diff3);

  *mean = _mm256_add_epi16(*mean, _mm256_add_epi16(diff0, diff1));
  *mean = _mm256_add_epi16(*mean, diff2);
  *mean = _mm256_add_epi16(*mean, diff3);
  *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(diff_sqr0, diff_sqr1));
  *sse = _mm256_add_epi32(*sse, diff_sqr2);
  *sse = _mm256_add_epi32(*sse, diff_sqr3);
}

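// Horizontally reduces the mean and sse accumulators and computes
// var = sse - mean^2 / width, where width == 4 << bwl.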
#define CALC_VAR_FROM_MEAN_SSE(mean, sse)                                    \
  {                                                                          \
    mean = _mm256_madd_epi16(mean, _mm256_set1_epi16(1));                    \
    mean = _mm256_hadd_epi32(mean, sse);                                     \
    mean = _mm256_add_epi32(mean, _mm256_bsrli_epi128(mean, 4));             \
    const __m128i result = _mm_add_epi32(_mm256_castsi256_si128(mean),       \
                                         _mm256_extractf128_si256(mean, 1)); \
    /*(mean * mean): dynamic range 31 bits.*/                                \
    const int mean_int = _mm_extract_epi32(result, 0);                       \
    const int sse_int = _mm_extract_epi32(result, 2);                        \
    const unsigned int mean_abs = abs(mean_int);                             \
    var = sse_int - ((mean_abs * mean_abs) >> (bwl + 2));                    \
  }

// ref: [0 - 510]
// src: [0 - 510]
// bwl: {2, 3, 4, 5}
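// Scalar sketch of the computation below (for clarity only; the SIMD code is
// the authoritative implementation):
//   int mean = 0, sse = 0;
//   for (int i = 0; i < (4 << bwl); ++i) {
//     const int diff = ref[i] - src[i];
//     mean += diff;
//     sse += diff * diff;
//   }
//   return sse - ((mean * mean) >> (bwl + 2));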
int aom_vector_var_avx2(const int16_t *ref, const int16_t *src, int bwl) {
  const int width = 4 << bwl;
  assert(width % 16 == 0 && width <= 128);
  int var = 0;

  // The width cases are unrolled instead of looping over 16-wide chunks, to
  // avoid some addition operations.
  if (width == 128) {
    __m256i mean = _mm256_setzero_si256();
    __m256i sse = _mm256_setzero_si256();

    calc_vector_mean_sse_64wd(src, ref, &mean, &sse);
    calc_vector_mean_sse_64wd(src + 64, ref + 64, &mean, &sse);
    CALC_VAR_FROM_MEAN_SSE(mean, sse)
  } else if (width == 64) {
    __m256i mean = _mm256_setzero_si256();
    __m256i sse = _mm256_setzero_si256();

    calc_vector_mean_sse_64wd(src, ref, &mean, &sse);
    CALC_VAR_FROM_MEAN_SSE(mean, sse)
  } else if (width == 32) {
    const __m256i src_line0 = _mm256_loadu_si256((const __m256i *)src);
    const __m256i ref_line0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i src_line1 = _mm256_loadu_si256((const __m256i *)(src + 16));
    const __m256i ref_line1 = _mm256_loadu_si256((const __m256i *)(ref + 16));

    const __m256i diff0 = _mm256_sub_epi16(ref_line0, src_line0);
    const __m256i diff1 = _mm256_sub_epi16(ref_line1, src_line1);
    const __m256i diff_sqr0 = _mm256_madd_epi16(diff0, diff0);
    const __m256i diff_sqr1 = _mm256_madd_epi16(diff1, diff1);
    const __m256i sse = _mm256_add_epi32(diff_sqr0, diff_sqr1);
    __m256i mean = _mm256_add_epi16(diff0, diff1);

    CALC_VAR_FROM_MEAN_SSE(mean, sse)
  } else if (width == 16) {
    const __m256i src_line = _mm256_loadu_si256((const __m256i *)src);
    const __m256i ref_line = _mm256_loadu_si256((const __m256i *)ref);
    __m256i mean = _mm256_sub_epi16(ref_line, src_line);
    const __m256i sse = _mm256_madd_epi16(mean, mean);

    CALC_VAR_FROM_MEAN_SSE(mean, sse)
  }
  return var;
}