/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_ports/mem.h"

// SAD
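// Horizontally reduces the eight 32-bit partial sums in *v to a single
// scalar, collapsing 8 -> 4 -> 2 -> 1 lanes with shift-and-add steps.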
static inline unsigned int get_sad_from_mm256_epi32(const __m256i *v) {
  // Input: eight 32-bit partial sums.
  __m128i lo128, hi128;
  __m256i u = _mm256_srli_si256(*v, 8);
  u = _mm256_add_epi32(u, *v);

  // Four 32-bit sums remain.
  hi128 = _mm256_extracti128_si256(u, 1);
  lo128 = _mm256_castsi256_si128(u);
  lo128 = _mm_add_epi32(hi128, lo128);

  // Two 32-bit sums remain.
  hi128 = _mm_srli_si128(lo128, 4);
  lo128 = _mm_add_epi32(lo128, hi128);

  return (unsigned int)_mm_cvtsi128_si32(lo128);
}

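// Accumulates the SAD of four 16-pixel row pairs held in s[0..3]/r[0..3]
// into *sad_acc. The 16-bit adds cannot overflow: with at most 12-bit
// samples each absolute difference is <= 4095, so the sum of four is
// <= 16380, which still fits in 16 bits before the widening unpack.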
static inline void highbd_sad16x4_core_avx2(__m256i *s, __m256i *r,
                                            __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  int i;
  for (i = 0; i < 4; i++) {
    s[i] = _mm256_sub_epi16(s[i], r[i]);
    s[i] = _mm256_abs_epi16(s[i]);
  }

  s[0] = _mm256_add_epi16(s[0], s[1]);
  s[0] = _mm256_add_epi16(s[0], s[2]);
  s[0] = _mm256_add_epi16(s[0], s[3]);

  r[0] = _mm256_unpacklo_epi16(s[0], zero);
  r[1] = _mm256_unpackhi_epi16(s[0], zero);

  r[0] = _mm256_add_epi32(r[0], r[1]);
  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
}

// If sec_ptr is NULL, calculate a regular SAD; otherwise calculate the SAD
// against the average of the reference and the second prediction.
static inline void sad16x4(const uint16_t *src_ptr, int src_stride,
                           const uint16_t *ref_ptr, int ref_stride,
                           const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  if (sec_ptr) {
    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
    r[1] = _mm256_avg_epu16(
        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r[2] = _mm256_avg_epu16(
        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r[3] = _mm256_avg_epu16(
        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
  }
  highbd_sad16x4_core_avx2(s, r, sad_acc);
}

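// Computes the SAD of a 16-wide block; N must be a multiple of 4.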
static AOM_FORCE_INLINE unsigned int aom_highbd_sad16xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int i;
  __m256i sad = _mm256_setzero_si256();
  for (i = 0; i < N; i += 4) {
    sad16x4(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return (unsigned int)get_sad_from_mm256_epi32(&sad);
}

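// Accumulates the SAD of a 32x4 tile, processed as two 32x2 row sections.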
static void sad32x4(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int row_sections = 0;

  while (row_sections < 2) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));

    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 32 << 1;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);

    row_sections += 1;
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad32xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 2;
  int i;

  for (i = 0; i < N; i += 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

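// Accumulates the SAD of a 64x2 tile, one 64-pixel row at a time.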
static void sad64x2(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad64xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 1;
  int i;
  for (i = 0; i < N; i += 2) {
    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

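// Accumulates the SAD of a single 128-pixel row, split into two 64-pixel
// halves.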
static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
                     const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += 64;
    ref_ptr += 64;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad128xN_avx2(
    int N, const uint8_t *src, int src_stride, const uint8_t *ref,
    int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  int row = 0;
  while (row < N) {
    sad128x1(srcp, refp, NULL, &sad);
    srcp += src_stride;
    refp += ref_stride;
    row++;
  }
  return get_sad_from_mm256_epi32(&sad);
}

#define HIGHBD_SADMXN_AVX2(m, n)                                            \
  unsigned int aom_highbd_sad##m##x##n##_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *ref,               \
      int ref_stride) {                                                     \
    return aom_highbd_sad##m##xN_avx2(n, src, src_stride, ref, ref_stride); \
  }

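// The "skip" variants estimate the full SAD from every other row: they sum
// n / 2 rows at twice the stride and double the result.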
#define HIGHBD_SAD_SKIP_MXN_AVX2(m, n)                                       \
  unsigned int aom_highbd_sad_skip_##m##x##n##_avx2(                         \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * aom_highbd_sad##m##xN_avx2((n / 2), src, 2 * src_stride, ref, \
                                          2 * ref_stride);                   \
  }

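// For example, HIGHBD_SADMXN_AVX2(16, 8) below expands to
//   unsigned int aom_highbd_sad16x8_avx2(const uint8_t *src, int src_stride,
//                                        const uint8_t *ref, int ref_stride);
// which callers typically reach through the aom_dsp_rtcd dispatch table when
// AVX2 is available.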
HIGHBD_SADMXN_AVX2(16, 8)
HIGHBD_SADMXN_AVX2(16, 16)
HIGHBD_SADMXN_AVX2(16, 32)

HIGHBD_SADMXN_AVX2(32, 16)
HIGHBD_SADMXN_AVX2(32, 32)
HIGHBD_SADMXN_AVX2(32, 64)

HIGHBD_SADMXN_AVX2(64, 32)
HIGHBD_SADMXN_AVX2(64, 64)
HIGHBD_SADMXN_AVX2(64, 128)

HIGHBD_SADMXN_AVX2(128, 64)
HIGHBD_SADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SADMXN_AVX2(16, 4)
HIGHBD_SADMXN_AVX2(16, 64)
HIGHBD_SADMXN_AVX2(32, 8)
HIGHBD_SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_SKIP_MXN_AVX2(16, 8)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 32)

HIGHBD_SAD_SKIP_MXN_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXN_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXN_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_SKIP_MXN_AVX2(16, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 8)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);

  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  // Next 4 rows
  srcp += src_stride << 2;
  refp += ref_stride << 2;
  secp += 64;
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
  return get_sad_from_mm256_epi32(&sad);
}

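// The larger average-SAD kernels are composed from two half-height calls,
// advancing src/ref by (rows * stride) and second_pred by (rows * width).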
unsigned int aom_highbd_sad16x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 3;
  uint32_t sum = aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                             second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                     second_pred);
  return sum;
}

unsigned int aom_highbd_sad16x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 2) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad32x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 8) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 16) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x128_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  const int left_shift = 6;
  uint32_t sum = aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad128x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  int row = 0;
  while (row < 64) {
    sad128x1(srcp, refp, secp, &sad);
    srcp += src_stride;
    refp += ref_stride;
    secp += 128;
    row += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad128x128_avg_avx2(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            const uint8_t *second_pred) {
  unsigned int sum;
  const int left_shift = 6;

  sum = aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 128 << left_shift;
  sum += aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                       second_pred);
  return sum;
}

// SAD 4D
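// The x4d kernels evaluate one source block against four candidate
// reference blocks in a single pass, writing the four SADs to sad_array.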
// Combine the four __m256i accumulators in v into uint32_t res[4].
static inline void get_4d_sad_from_mm256_epi32(const __m256i *v,
                                               uint32_t *res) {
  __m256i u0, u1, u2, u3;
  const __m256i mask = _mm256_set1_epi64x(~0u);  // Low 32 bits of each 64.
  __m128i sad;

  // Each v[i] holds eight 32-bit partial sums; add adjacent pairs.
  u0 = _mm256_srli_si256(v[0], 4);
  u1 = _mm256_srli_si256(v[1], 4);
  u2 = _mm256_srli_si256(v[2], 4);
  u3 = _mm256_srli_si256(v[3], 4);

  u0 = _mm256_add_epi32(u0, v[0]);
  u1 = _mm256_add_epi32(u1, v[1]);
  u2 = _mm256_add_epi32(u2, v[2]);
  u3 = _mm256_add_epi32(u3, v[3]);

  u0 = _mm256_and_si256(u0, mask);
  u1 = _mm256_and_si256(u1, mask);
  u2 = _mm256_and_si256(u2, mask);
  u3 = _mm256_and_si256(u3, mask);
  // Four 32-bit sums per vector, in the even positions.

  u1 = _mm256_slli_si256(u1, 4);
  u3 = _mm256_slli_si256(u3, 4);

  u0 = _mm256_or_si256(u0, u1);
  u2 = _mm256_or_si256(u2, u3);
  // Eight 32-bit sums per vector, interleaved from two accumulators.

  u1 = _mm256_unpacklo_epi64(u0, u2);
  u3 = _mm256_unpackhi_epi64(u0, u2);

  u0 = _mm256_add_epi32(u1, u3);
  sad = _mm_add_epi32(_mm256_extractf128_si256(u0, 1),
                      _mm256_castsi256_si128(u0));
  _mm_storeu_si128((__m128i *)res, sad);
}

static void convert_pointers(const uint8_t *const ref8[],
                             const uint16_t *ref[]) {
  ref[0] = CONVERT_TO_SHORTPTR(ref8[0]);
  ref[1] = CONVERT_TO_SHORTPTR(ref8[1]);
  ref[2] = CONVERT_TO_SHORTPTR(ref8[2]);
  ref[3] = CONVERT_TO_SHORTPTR(ref8[3]);
}

static void init_sad(__m256i *s) {
  s[0] = _mm256_setzero_si256();
  s[1] = _mm256_setzero_si256();
  s[2] = _mm256_setzero_si256();
  s[3] = _mm256_setzero_si256();
}

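// Computes the SAD of an MxN source block against the first D (3 or 4)
// reference blocks in ref_array. row_units (1, 2, or 4) is how many rows
// each helper call consumes: sad128x1, sad64x2, and sad32x4/sad16x4
// respectively.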
static AOM_FORCE_INLINE void aom_highbd_sadMxNxD_avx2(
    int M, int N, int D, const uint8_t *src, int src_stride,
    const uint8_t *const ref_array[4], int ref_stride, uint32_t sad_array[4]) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_rows = (M < 128) + (M < 64);
  const int row_units = 1 << shift_for_rows;
  int i, r;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < D; ++i) {
    srcp = keep;
    for (r = 0; r < N; r += row_units) {
      if (M == 128) {
        sad128x1(srcp, refp[i], NULL, &sad_vec[i]);
      } else if (M == 64) {
        sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 32) {
        sad32x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 16) {
        sad16x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else {
        assert(0);
      }
      srcp += src_stride << shift_for_rows;
      refp[i] += ref_stride << shift_for_rows;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

#define HIGHBD_SAD_MXNX4D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x4d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 4, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }
#define HIGHBD_SAD_SKIP_MXNX4D_AVX2(m, n)                                    \
  void aom_highbd_sad_skip_##m##x##n##x4d_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    aom_highbd_sadMxNxD_avx2(m, (n / 2), 4, src, 2 * src_stride, ref_array,  \
                             2 * ref_stride, sad_array);                     \
    sad_array[0] <<= 1;                                                      \
    sad_array[1] <<= 1;                                                      \
    sad_array[2] <<= 1;                                                      \
    sad_array[3] <<= 1;                                                      \
  }
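// The x3d variants accumulate only three references; the fourth accumulator
// stays zero, so sad_array[3] is written as 0.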
#define HIGHBD_SAD_MXNX3D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x3d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 3, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }

HIGHBD_SAD_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_MXNX4D_AVX2(16, 32)

HIGHBD_SAD_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_MXNX4D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_MXNX4D_AVX2(16, 4)
HIGHBD_SAD_MXNX4D_AVX2(16, 64)
HIGHBD_SAD_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_MXNX4D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 32)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_MXNX3D_AVX2(16, 8)
HIGHBD_SAD_MXNX3D_AVX2(16, 16)
HIGHBD_SAD_MXNX3D_AVX2(16, 32)

HIGHBD_SAD_MXNX3D_AVX2(32, 16)
HIGHBD_SAD_MXNX3D_AVX2(32, 32)
HIGHBD_SAD_MXNX3D_AVX2(32, 64)

HIGHBD_SAD_MXNX3D_AVX2(64, 32)
HIGHBD_SAD_MXNX3D_AVX2(64, 64)
HIGHBD_SAD_MXNX3D_AVX2(64, 128)

HIGHBD_SAD_MXNX3D_AVX2(128, 64)
HIGHBD_SAD_MXNX3D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_MXNX3D_AVX2(16, 4)
HIGHBD_SAD_MXNX3D_AVX2(16, 64)
HIGHBD_SAD_MXNX3D_AVX2(32, 8)
HIGHBD_SAD_MXNX3D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY