/*
 *  Copyright (c) 2022 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#include <immintrin.h>
#include "./vpx_dsp_rtcd.h"
#include "vpx/vpx_integer.h"

static VPX_FORCE_INLINE unsigned int calc_final(const __m256i sums_32) {
  const __m256i t0 = _mm256_add_epi32(sums_32, _mm256_srli_si256(sums_32, 8));
  const __m256i t1 = _mm256_add_epi32(t0, _mm256_srli_si256(t0, 4));
  const __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(t1),
                                    _mm256_extractf128_si256(t1, 1));
  return (unsigned int)_mm_cvtsi128_si32(sum);
}
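
/* For reference: the two srli/add steps above fold the four 32-bit lanes of
 * each 128-bit half into that half's lowest lane, and the final cast/extract
 * add combines the two halves. A scalar sketch of the same reduction (kept
 * out of the build; lane order is irrelevant since addition commutes):
 */
#if 0
static unsigned int calc_final_c(const uint32_t lanes[8]) {
  unsigned int total = 0;
  int i;
  for (i = 0; i < 8; ++i) total += lanes[i];
  return total;
}
#endif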

static VPX_FORCE_INLINE void highbd_sad64xH(__m256i *sums_16,
                                            const uint16_t *src, int src_stride,
                                            uint16_t *ref, int ref_stride,
                                            int height) {
  int i;
  for (i = 0; i < height; ++i) {
    // load src and all ref[]
    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + 16));
    const __m256i s2 = _mm256_load_si256((const __m256i *)(src + 32));
    const __m256i s3 = _mm256_load_si256((const __m256i *)(src + 48));
    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
    const __m256i r2 = _mm256_loadu_si256((const __m256i *)(ref + 32));
    const __m256i r3 = _mm256_loadu_si256((const __m256i *)(ref + 48));
    // absolute differences between each ref[] vector and src
    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(r0, s0));
    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(r1, s1));
    const __m256i abs_diff2 = _mm256_abs_epi16(_mm256_sub_epi16(r2, s2));
    const __m256i abs_diff3 = _mm256_abs_epi16(_mm256_sub_epi16(r3, s3));
    // accumulate all abs diffs into 16-bit lanes
    *sums_16 =
        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff0, abs_diff1));
    *sums_16 =
        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff2, abs_diff3));

    src += src_stride;
    ref += ref_stride;
  }
}
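
/* A scalar model of one highbd_sad64xH() row, as a sketch (not compiled).
 * Each 16-bit lane of *sums_16 picks up four absolute differences per row;
 * with 12-bit input (values up to 4095, the worst case) two rows add at most
 * 8 * 4095 = 32760 per lane, so the callers conservatively flush sums_16
 * into 32-bit lanes every two rows.
 */
#if 0
static void highbd_sad64x1_c(uint32_t *sad, const uint16_t *src,
                             const uint16_t *ref) {
  int col;
  for (col = 0; col < 64; ++col) *sad += abs(ref[col] - src[col]);
}
#endif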

static VPX_FORCE_INLINE unsigned int highbd_sad64xN_avx2(const uint8_t *src_ptr,
                                                         int src_stride,
                                                         const uint8_t *ref_ptr,
                                                         int ref_stride,
                                                         int n) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  __m256i sums_32 = _mm256_setzero_si256();
  int i;

  for (i = 0; i < (n / 2); ++i) {
    __m256i sums_16 = _mm256_setzero_si256();

    highbd_sad64xH(&sums_16, src, src_stride, ref, ref_stride, 2);

    /* sums_16 may overflow beyond 2 rows, so add the current sums_16 into
     * sums_32 */
    sums_32 = _mm256_add_epi32(
        sums_32,
        _mm256_add_epi32(
            _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
            _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));

    src += src_stride << 1;
    ref += ref_stride << 1;
  }
  return calc_final(sums_32);
}
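
/* The flush above widens the sixteen uint16_t lanes of sums_16 into eight
 * uint32_t lanes before accumulating: lane i of the running sum receives
 * sums_16[i] + sums_16[i + 8]. A scalar sketch of the widening (hypothetical
 * helper, not compiled):
 */
#if 0
static void widen_and_add_c(uint32_t sums_32[8], const uint16_t sums_16[16]) {
  int i;
  for (i = 0; i < 8; ++i)
    sums_32[i] += (uint32_t)sums_16[i] + sums_16[i + 8];
}
#endif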

#define HIGHBD_SAD64XN(n)                                                      \
  unsigned int vpx_highbd_sad64x##n##_avx2(const uint8_t *src, int src_stride, \
                                           const uint8_t *ref,                 \
                                           int ref_stride) {                   \
    return highbd_sad64xN_avx2(src, src_stride, ref, ref_stride, n);           \
  }

#define HIGHBD_SADSKIP64xN(n)                                                \
  unsigned int vpx_highbd_sad_skip_64x##n##_avx2(                            \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * highbd_sad64xN_avx2(src, 2 * src_stride, ref, 2 * ref_stride, \
                                   n / 2);                                   \
  }
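
/* The _skip_ variants estimate the SAD from every other row: the strides are
 * doubled, only n / 2 rows are visited, and the result is scaled by 2. This
 * roughly halves the work when a full-resolution SAD is not needed, at some
 * cost in accuracy. */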

static VPX_FORCE_INLINE void highbd_sad32xH(__m256i *sums_16,
                                            const uint16_t *src, int src_stride,
                                            uint16_t *ref, int ref_stride,
                                            int height) {
  int i;
  for (i = 0; i < height; ++i) {
    // load src and all ref[]
    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + 16));
    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
    // absolute differences between each ref[] vector and src
    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(r0, s0));
    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(r1, s1));
    // accumulate all abs diffs into 16-bit lanes
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff0);
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff1);

    src += src_stride;
    ref += ref_stride;
  }
}

static VPX_FORCE_INLINE unsigned int highbd_sad32xN_avx2(const uint8_t *src_ptr,
                                                         int src_stride,
                                                         const uint8_t *ref_ptr,
                                                         int ref_stride,
                                                         int n) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  __m256i sums_32 = _mm256_setzero_si256();
  int i;

  for (i = 0; i < (n / 8); ++i) {
    __m256i sums_16 = _mm256_setzero_si256();

    highbd_sad32xH(&sums_16, src, src_stride, ref, ref_stride, 8);

    /* sums_16 may overflow beyond 8 rows, so add the current sums_16 into
     * sums_32 */
    sums_32 = _mm256_add_epi32(
        sums_32,
        _mm256_add_epi32(
            _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
            _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));

    src += src_stride << 3;
    ref += ref_stride << 3;
  }
  return calc_final(sums_32);
}
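
/* For 32-wide blocks each 16-bit lane picks up two absolute differences per
 * row, so with 12-bit input 8 rows contribute at most 16 * 4095 = 65520,
 * just inside the uint16_t range; hence the flush every 8 rows. */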

#define HIGHBD_SAD32XN(n)                                                      \
  unsigned int vpx_highbd_sad32x##n##_avx2(const uint8_t *src, int src_stride, \
                                           const uint8_t *ref,                 \
                                           int ref_stride) {                   \
    return highbd_sad32xN_avx2(src, src_stride, ref, ref_stride, n);           \
  }

#define HIGHBD_SADSKIP32xN(n)                                                \
  unsigned int vpx_highbd_sad_skip_32x##n##_avx2(                            \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * highbd_sad32xN_avx2(src, 2 * src_stride, ref, 2 * ref_stride, \
                                   n / 2);                                   \
  }

static VPX_FORCE_INLINE void highbd_sad16xH(__m256i *sums_16,
                                            const uint16_t *src, int src_stride,
                                            uint16_t *ref, int ref_stride,
                                            int height) {
  int i;
  for (i = 0; i < height; i += 2) {
    // load two rows of src and ref
    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + src_stride));
    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
    // absolute differences between each ref row and src row
    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(r0, s0));
    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(r1, s1));
    // accumulate all abs diffs into 16-bit lanes
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff0);
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff1);

    src += src_stride << 1;
    ref += ref_stride << 1;
  }
}

static VPX_FORCE_INLINE unsigned int highbd_sad16xN_avx2(const uint8_t *src_ptr,
                                                         int src_stride,
                                                         const uint8_t *ref_ptr,
                                                         int ref_stride,
                                                         int n) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  __m256i sums_32 = _mm256_setzero_si256();
  const int height = VPXMIN(16, n);
  const int num_iters = n / height;
  int i;

  for (i = 0; i < num_iters; ++i) {
    __m256i sums_16 = _mm256_setzero_si256();

    highbd_sad16xH(&sums_16, src, src_stride, ref, ref_stride, height);

    // sums_16 may overflow beyond 16 rows, so add the current sums_16 to
    // sums_32
    sums_32 = _mm256_add_epi32(
        sums_32,
        _mm256_add_epi32(
            _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
            _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));

    src += src_stride << 4;
    ref += ref_stride << 4;
  }
  return calc_final(sums_32);
}
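
/* For 16-wide blocks a lane accumulates one absolute difference per row, so
 * with 12-bit input up to 16 rows (16 * 4095 = 65520) fit in a uint16_t
 * lane; taller blocks are processed in 16-row chunks via VPXMIN above. */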

#define HIGHBD_SAD16XN(n)                                                      \
  unsigned int vpx_highbd_sad16x##n##_avx2(const uint8_t *src, int src_stride, \
                                           const uint8_t *ref,                 \
                                           int ref_stride) {                   \
    return highbd_sad16xN_avx2(src, src_stride, ref, ref_stride, n);           \
  }

#define HIGHBD_SADSKIP16xN(n)                                                \
  unsigned int vpx_highbd_sad_skip_16x##n##_avx2(                            \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * highbd_sad16xN_avx2(src, 2 * src_stride, ref, 2 * ref_stride, \
                                   n / 2);                                   \
  }

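/* 16x16 and 16x8 fit entirely within one 16-bit accumulation pass, so the
 * dedicated versions below skip the outer flush loop and widen only once. */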
unsigned int vpx_highbd_sad16x16_avx2(const uint8_t *src_ptr, int src_stride,
                                      const uint8_t *ref_ptr, int ref_stride) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  __m256i sums_16 = _mm256_setzero_si256();

  highbd_sad16xH(&sums_16, src, src_stride, ref, ref_stride, 16);

  {
    const __m256i sums_32 = _mm256_add_epi32(
        _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
        _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1)));
    return calc_final(sums_32);
  }
}

unsigned int vpx_highbd_sad16x8_avx2(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *ref_ptr, int ref_stride) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  __m256i sums_16 = _mm256_setzero_si256();

  highbd_sad16xH(&sums_16, src, src_stride, ref, ref_stride, 8);

  {
    const __m256i sums_32 = _mm256_add_epi32(
        _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
        _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1)));
    return calc_final(sums_32);
  }
}

// clang-format off
HIGHBD_SAD64XN(64)
HIGHBD_SADSKIP64xN(64)
HIGHBD_SAD64XN(32)
HIGHBD_SADSKIP64xN(32)
HIGHBD_SAD32XN(64)
HIGHBD_SADSKIP32xN(64)
HIGHBD_SAD32XN(32)
HIGHBD_SADSKIP32xN(32)
HIGHBD_SAD32XN(16)
HIGHBD_SADSKIP32xN(16)
HIGHBD_SAD16XN(32)
HIGHBD_SADSKIP16xN(32)
HIGHBD_SADSKIP16xN(16)
HIGHBD_SADSKIP16xN(8)
// clang-format on

// AVG -------------------------------------------------------------------------
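
/* The _avg_ variants average ref with a second prediction before taking the
 * SAD against src. _mm256_avg_epu16 computes (a + b + 1) >> 1 per lane
 * (rounding half up), matching this per-pixel scalar sketch (hypothetical
 * helper, not compiled):
 */
#if 0
static uint32_t highbd_avg_sad_px_c(uint16_t s, uint16_t r, uint16_t p) {
  const uint16_t avg = (uint16_t)((r + p + 1) >> 1);
  return (uint32_t)abs((int)avg - (int)s);
}
#endif
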
static VPX_FORCE_INLINE void highbd_sad64xH_avg(__m256i *sums_16,
                                                const uint16_t *src,
                                                int src_stride, uint16_t *ref,
                                                int ref_stride, uint16_t *sec,
                                                int height) {
  int i;
  for (i = 0; i < height; ++i) {
    // load src, ref[] and second pred
    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + 16));
    const __m256i s2 = _mm256_load_si256((const __m256i *)(src + 32));
    const __m256i s3 = _mm256_load_si256((const __m256i *)(src + 48));
    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
    const __m256i r2 = _mm256_loadu_si256((const __m256i *)(ref + 32));
    const __m256i r3 = _mm256_loadu_si256((const __m256i *)(ref + 48));
    const __m256i x0 = _mm256_loadu_si256((const __m256i *)sec);
    const __m256i x1 = _mm256_loadu_si256((const __m256i *)(sec + 16));
    const __m256i x2 = _mm256_loadu_si256((const __m256i *)(sec + 32));
    const __m256i x3 = _mm256_loadu_si256((const __m256i *)(sec + 48));
    const __m256i avg0 = _mm256_avg_epu16(r0, x0);
    const __m256i avg1 = _mm256_avg_epu16(r1, x1);
    const __m256i avg2 = _mm256_avg_epu16(r2, x2);
    const __m256i avg3 = _mm256_avg_epu16(r3, x3);
    // absolute differences between each ref/pred average and src
    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(avg0, s0));
    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(avg1, s1));
    const __m256i abs_diff2 = _mm256_abs_epi16(_mm256_sub_epi16(avg2, s2));
    const __m256i abs_diff3 = _mm256_abs_epi16(_mm256_sub_epi16(avg3, s3));
    // accumulate all abs diffs into 16-bit lanes
    *sums_16 =
        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff0, abs_diff1));
    *sums_16 =
        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff2, abs_diff3));

    src += src_stride;
    ref += ref_stride;
    sec += 64;
  }
}
#define HIGHBD_SAD64XN_AVG(n)                                                 \
  unsigned int vpx_highbd_sad64x##n##_avg_avx2(                               \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride, const uint8_t *second_pred) {                           \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                             \
    uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);                         \
    __m256i sums_32 = _mm256_setzero_si256();                                 \
    int i;                                                                    \
                                                                              \
    for (i = 0; i < (n / 2); ++i) {                                           \
      __m256i sums_16 = _mm256_setzero_si256();                               \
                                                                              \
      highbd_sad64xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 2); \
                                                                              \
      /* sums_16 may overflow beyond 2 rows, so add the current sums_16       \
       * into sums_32 */                                                      \
      sums_32 = _mm256_add_epi32(                                             \
          sums_32,                                                            \
          _mm256_add_epi32(                                                   \
              _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),         \
              _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));  \
                                                                              \
      src += src_stride << 1;                                                 \
      ref += ref_stride << 1;                                                 \
      sec += 64 << 1;                                                         \
    }                                                                         \
    return calc_final(sums_32);                                               \
  }

// 64x64
HIGHBD_SAD64XN_AVG(64)

// 64x32
HIGHBD_SAD64XN_AVG(32)

static VPX_FORCE_INLINE void highbd_sad32xH_avg(__m256i *sums_16,
                                                const uint16_t *src,
                                                int src_stride, uint16_t *ref,
                                                int ref_stride, uint16_t *sec,
                                                int height) {
  int i;
  for (i = 0; i < height; ++i) {
    // load src, ref[] and second pred
    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + 16));
    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
    const __m256i x0 = _mm256_loadu_si256((const __m256i *)sec);
    const __m256i x1 = _mm256_loadu_si256((const __m256i *)(sec + 16));
    const __m256i avg0 = _mm256_avg_epu16(r0, x0);
    const __m256i avg1 = _mm256_avg_epu16(r1, x1);
    // absolute differences between each ref/pred average and src
    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(avg0, s0));
    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(avg1, s1));
    // accumulate all abs diffs into 16-bit lanes
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff0);
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff1);

    src += src_stride;
    ref += ref_stride;
    sec += 32;
  }
}

#define HIGHBD_SAD32XN_AVG(n)                                                 \
  unsigned int vpx_highbd_sad32x##n##_avg_avx2(                               \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride, const uint8_t *second_pred) {                           \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                             \
    uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);                         \
    __m256i sums_32 = _mm256_setzero_si256();                                 \
    int i;                                                                    \
                                                                              \
    for (i = 0; i < (n / 8); ++i) {                                           \
      __m256i sums_16 = _mm256_setzero_si256();                               \
                                                                              \
      highbd_sad32xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 8); \
                                                                              \
      /* sums_16 may overflow beyond 8 rows, so add the current sums_16       \
       * into sums_32 */                                                      \
      sums_32 = _mm256_add_epi32(                                             \
          sums_32,                                                            \
          _mm256_add_epi32(                                                   \
              _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),         \
              _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));  \
                                                                              \
      src += src_stride << 3;                                                 \
      ref += ref_stride << 3;                                                 \
      sec += 32 << 3;                                                         \
    }                                                                         \
    return calc_final(sums_32);                                               \
  }

// 32x64
HIGHBD_SAD32XN_AVG(64)

// 32x32
HIGHBD_SAD32XN_AVG(32)

// 32x16
HIGHBD_SAD32XN_AVG(16)

static VPX_FORCE_INLINE void highbd_sad16xH_avg(__m256i *sums_16,
                                                const uint16_t *src,
                                                int src_stride, uint16_t *ref,
                                                int ref_stride, uint16_t *sec,
                                                int height) {
  int i;
  for (i = 0; i < height; i += 2) {
    // load two rows of src, ref and second pred
    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + src_stride));
    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
    const __m256i x0 = _mm256_loadu_si256((const __m256i *)sec);
    const __m256i x1 = _mm256_loadu_si256((const __m256i *)(sec + 16));
    const __m256i avg0 = _mm256_avg_epu16(r0, x0);
    const __m256i avg1 = _mm256_avg_epu16(r1, x1);
    // absolute differences between each ref/pred average and src
    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(avg0, s0));
    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(avg1, s1));
    // accumulate all abs diffs into 16-bit lanes
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff0);
    *sums_16 = _mm256_add_epi16(*sums_16, abs_diff1);

    src += src_stride << 1;
    ref += ref_stride << 1;
    sec += 32;
  }
}

unsigned int vpx_highbd_sad16x32_avg_avx2(const uint8_t *src_ptr,
                                          int src_stride,
                                          const uint8_t *ref_ptr,
                                          int ref_stride,
                                          const uint8_t *second_pred) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);
  __m256i sums_32 = _mm256_setzero_si256();
  int i;

  for (i = 0; i < 2; ++i) {
    __m256i sums_16 = _mm256_setzero_si256();

    highbd_sad16xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 16);

    // sums_16 may overflow beyond 16 rows, so add the current sums_16 to
    // sums_32
    sums_32 = _mm256_add_epi32(
        sums_32,
        _mm256_add_epi32(
            _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
            _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));

    src += src_stride << 4;
    ref += ref_stride << 4;
    sec += 16 << 4;
  }
  return calc_final(sums_32);
}

unsigned int vpx_highbd_sad16x16_avg_avx2(const uint8_t *src_ptr,
                                          int src_stride,
                                          const uint8_t *ref_ptr,
                                          int ref_stride,
                                          const uint8_t *second_pred) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);
  __m256i sums_16 = _mm256_setzero_si256();

  highbd_sad16xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 16);

  {
    const __m256i sums_32 = _mm256_add_epi32(
        _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
        _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1)));
    return calc_final(sums_32);
  }
}

unsigned int vpx_highbd_sad16x8_avg_avx2(const uint8_t *src_ptr, int src_stride,
                                         const uint8_t *ref_ptr, int ref_stride,
                                         const uint8_t *second_pred) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);
  uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);
  uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);
  __m256i sums_16 = _mm256_setzero_si256();

  highbd_sad16xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 8);

  {
    const __m256i sums_32 = _mm256_add_epi32(
        _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),
        _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1)));
    return calc_final(sums_32);
  }
}