/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_ports/mem.h"

// SAD
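// Returns the horizontal sum of the eight 32-bit partial sums held in *v.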
static inline unsigned int get_sad_from_mm256_epi32(const __m256i *v) {
  // input 8 32-bit summation
  __m128i lo128, hi128;
  __m256i u = _mm256_srli_si256(*v, 8);
  u = _mm256_add_epi32(u, *v);

  // 4 32-bit summation
  hi128 = _mm256_extracti128_si256(u, 1);
  lo128 = _mm256_castsi256_si128(u);
  lo128 = _mm_add_epi32(hi128, lo128);

  // 2 32-bit summation
  hi128 = _mm_srli_si128(lo128, 4);
  lo128 = _mm_add_epi32(lo128, hi128);

  return (unsigned int)_mm_cvtsi128_si32(lo128);
}

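// Sums the absolute differences of four 16-pixel rows held in s[] and r[],
// widens the per-lane 16-bit totals to 32 bits and adds them to *sad_acc.
// Both s[] and r[] are used as scratch and are clobbered.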
static inline void highbd_sad16x4_core_avx2(__m256i *s, __m256i *r,
                                            __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  int i;
  for (i = 0; i < 4; i++) {
    s[i] = _mm256_sub_epi16(s[i], r[i]);
    s[i] = _mm256_abs_epi16(s[i]);
  }

  s[0] = _mm256_add_epi16(s[0], s[1]);
  s[0] = _mm256_add_epi16(s[0], s[2]);
  s[0] = _mm256_add_epi16(s[0], s[3]);

  r[0] = _mm256_unpacklo_epi16(s[0], zero);
  r[1] = _mm256_unpackhi_epi16(s[0], zero);

  r[0] = _mm256_add_epi32(r[0], r[1]);
  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
}

// If sec_ptr is NULL, compute the regular SAD. Otherwise, compute the SAD
// against the average of ref and sec.
static inline void sad16x4(const uint16_t *src_ptr, int src_stride,
                           const uint16_t *ref_ptr, int ref_stride,
                           const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  if (sec_ptr) {
    r[0] =
        _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
    r[1] = _mm256_avg_epu16(
        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r[2] = _mm256_avg_epu16(
        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r[3] = _mm256_avg_epu16(
        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
  }
  highbd_sad16x4_core_avx2(s, r, sad_acc);
}

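// SAD of a 16-wide block with N rows; N must be a multiple of 4.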
static AOM_FORCE_INLINE unsigned int aom_highbd_sad16xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int i;
  __m256i sad = _mm256_setzero_si256();
  for (i = 0; i < N; i += 4) {
    sad16x4(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return (unsigned int)get_sad_from_mm256_epi32(&sad);
}

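// Accumulates the SAD of a 32x4 block into *sad_acc, two rows of 32 pixels
// per iteration. When sec_ptr is non-NULL, the reference rows are first
// averaged with the corresponding rows of the second predictor.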
static void sad32x4(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int row_sections = 0;

  while (row_sections < 2) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));

    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 32 << 1;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);

    row_sections += 1;
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad32xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 2;
  int i;

  for (i = 0; i < N; i += 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

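// Accumulates the SAD of a 64x2 block into *sad_acc, one row of 64 pixels
// per iteration, optionally averaging the reference with sec_ptr first.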
static void sad64x2(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));

    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad64xN_avx2(int N,
                                                             const uint8_t *src,
                                                             int src_stride,
                                                             const uint8_t *ref,
                                                             int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 1;
  int i;
  for (i = 0; i < N; i += 2) {
    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
  }
  return get_sad_from_mm256_epi32(&sad);
}

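// Accumulates the SAD of a single 128-pixel row into *sad_acc, processed as
// two 64-pixel halves, optionally averaging the reference with sec_ptr.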
static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
                     const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[4], r[4];
  int i;
  for (i = 0; i < 2; i++) {
    s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
    s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
    s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
    r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
    r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
    if (sec_ptr) {
      r[0] =
          _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
      r[1] = _mm256_avg_epu16(
          r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r[2] = _mm256_avg_epu16(
          r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r[3] = _mm256_avg_epu16(
          r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;
    }
    highbd_sad16x4_core_avx2(s, r, sad_acc);
    src_ptr += 64;
    ref_ptr += 64;
  }
}

static AOM_FORCE_INLINE unsigned int aom_highbd_sad128xN_avx2(
    int N, const uint8_t *src, int src_stride, const uint8_t *ref,
    int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  int row = 0;
  while (row < N) {
    sad128x1(srcp, refp, NULL, &sad);
    srcp += src_stride;
    refp += ref_stride;
    row++;
  }
  return get_sad_from_mm256_epi32(&sad);
}

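// HIGHBD_SADMXN_AVX2 emits the public aom_highbd_sadMxN_avx2() entry point.
// HIGHBD_SAD_SKIP_MXN_AVX2 emits the "skip" variant, which measures only
// every other row (halved row count, doubled strides) and doubles the result
// to approximate the full SAD at roughly half the cost.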
#define HIGHBD_SADMXN_AVX2(m, n)                                            \
  unsigned int aom_highbd_sad##m##x##n##_avx2(                              \
      const uint8_t *src, int src_stride, const uint8_t *ref,               \
      int ref_stride) {                                                     \
    return aom_highbd_sad##m##xN_avx2(n, src, src_stride, ref, ref_stride); \
  }

#define HIGHBD_SAD_SKIP_MXN_AVX2(m, n)                                       \
  unsigned int aom_highbd_sad_skip_##m##x##n##_avx2(                         \
      const uint8_t *src, int src_stride, const uint8_t *ref,                \
      int ref_stride) {                                                      \
    return 2 * aom_highbd_sad##m##xN_avx2((n / 2), src, 2 * src_stride, ref, \
                                          2 * ref_stride);                   \
  }

HIGHBD_SADMXN_AVX2(16, 8)
HIGHBD_SADMXN_AVX2(16, 16)
HIGHBD_SADMXN_AVX2(16, 32)

HIGHBD_SADMXN_AVX2(32, 16)
HIGHBD_SADMXN_AVX2(32, 32)
HIGHBD_SADMXN_AVX2(32, 64)

HIGHBD_SADMXN_AVX2(64, 32)
HIGHBD_SADMXN_AVX2(64, 64)
HIGHBD_SADMXN_AVX2(64, 128)

HIGHBD_SADMXN_AVX2(128, 64)
HIGHBD_SADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SADMXN_AVX2(16, 4)
HIGHBD_SADMXN_AVX2(16, 64)
HIGHBD_SADMXN_AVX2(32, 8)
HIGHBD_SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_SKIP_MXN_AVX2(16, 8)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(16, 32)

HIGHBD_SAD_SKIP_MXN_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXN_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXN_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_SKIP_MXN_AVX2(16, 64)
HIGHBD_SAD_SKIP_MXN_AVX2(32, 8)
HIGHBD_SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

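// Average SAD: each aom_highbd_sadMxN_avg_avx2() below returns the SAD
// between the source block and the rounded pixel-wise average of the
// reference block and second_pred.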
#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);

  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  // Next 4 rows
  srcp += src_stride << 2;
  refp += ref_stride << 2;
  secp += 64;
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad16x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 3;
  uint32_t sum = aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                             second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                     second_pred);
  return sum;
}

unsigned int aom_highbd_sad16x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 2) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad32x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

#if !CONFIG_REALTIME_ONLY
unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 8) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
#endif  // !CONFIG_REALTIME_ONLY

unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 16) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x128_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  const int left_shift = 6;
  uint32_t sum = aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad128x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  int row = 0;
  while (row < 64) {
    sad128x1(srcp, refp, secp, &sad);
    srcp += src_stride;
    refp += ref_stride;
    secp += 16 << 3;
    row += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad128x128_avg_avx2(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            const uint8_t *second_pred) {
  unsigned int sum;
  const int left_shift = 6;

  sum = aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 128 << left_shift;
  sum += aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                       second_pred);
  return sum;
}

// SAD 4D
// Combine 4 __m256i input vectors v to uint32_t result[4]
static inline void get_4d_sad_from_mm256_epi32(const __m256i *v,
                                               uint32_t *res) {
  __m256i u0, u1, u2, u3;
  const __m256i mask = _mm256_set1_epi64x(~0u);
  __m128i sad;

  // 8 32-bit summation
  u0 = _mm256_srli_si256(v[0], 4);
  u1 = _mm256_srli_si256(v[1], 4);
  u2 = _mm256_srli_si256(v[2], 4);
  u3 = _mm256_srli_si256(v[3], 4);

  u0 = _mm256_add_epi32(u0, v[0]);
  u1 = _mm256_add_epi32(u1, v[1]);
  u2 = _mm256_add_epi32(u2, v[2]);
  u3 = _mm256_add_epi32(u3, v[3]);

  u0 = _mm256_and_si256(u0, mask);
  u1 = _mm256_and_si256(u1, mask);
  u2 = _mm256_and_si256(u2, mask);
  u3 = _mm256_and_si256(u3, mask);
  // 4 32-bit summation, evenly positioned

  u1 = _mm256_slli_si256(u1, 4);
  u3 = _mm256_slli_si256(u3, 4);

  u0 = _mm256_or_si256(u0, u1);
  u2 = _mm256_or_si256(u2, u3);
  // 8 32-bit summation, interleaved

  u1 = _mm256_unpacklo_epi64(u0, u2);
  u3 = _mm256_unpackhi_epi64(u0, u2);

  u0 = _mm256_add_epi32(u1, u3);
  sad = _mm_add_epi32(_mm256_extractf128_si256(u0, 1),
                      _mm256_castsi256_si128(u0));
  _mm_storeu_si128((__m128i *)res, sad);
}

static void convert_pointers(const uint8_t *const ref8[],
                             const uint16_t *ref[]) {
  ref[0] = CONVERT_TO_SHORTPTR(ref8[0]);
  ref[1] = CONVERT_TO_SHORTPTR(ref8[1]);
  ref[2] = CONVERT_TO_SHORTPTR(ref8[2]);
  ref[3] = CONVERT_TO_SHORTPTR(ref8[3]);
}

static void init_sad(__m256i *s) {
  s[0] = _mm256_setzero_si256();
  s[1] = _mm256_setzero_si256();
  s[2] = _mm256_setzero_si256();
  s[3] = _mm256_setzero_si256();
}

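// Computes the SAD of an MxN source block against each of the first D
// (3 or 4) reference blocks in ref_array and writes the results to sad_array.
// M selects the row-group kernel; each call consumes 1, 2 or 4 rows.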
static AOM_FORCE_INLINE void aom_highbd_sadMxNxD_avx2(
    int M, int N, int D, const uint8_t *src, int src_stride,
    const uint8_t *const ref_array[4], int ref_stride, uint32_t sad_array[4]) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_rows = (M < 128) + (M < 64);
  const int row_units = 1 << shift_for_rows;
  int i, r;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < D; ++i) {
    srcp = keep;
    for (r = 0; r < N; r += row_units) {
      if (M == 128) {
        sad128x1(srcp, refp[i], NULL, &sad_vec[i]);
      } else if (M == 64) {
        sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 32) {
        sad32x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else if (M == 16) {
        sad16x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      } else {
        assert(0);
      }
      srcp += src_stride << shift_for_rows;
      refp[i] += ref_stride << shift_for_rows;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

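// HIGHBD_SAD_MXNX4D_AVX2 / HIGHBD_SAD_MXNX3D_AVX2 emit the x4d/x3d entry
// points, which evaluate four (or three) reference candidates in one call.
// The skip variant halves the row count, doubles the strides and then doubles
// each result, mirroring the single-reference skip functions above.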
#define HIGHBD_SAD_MXNX4D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x4d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 4, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }
#define HIGHBD_SAD_SKIP_MXNX4D_AVX2(m, n)                                     \
  void aom_highbd_sad_skip_##m##x##n##x4d_avx2(                               \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, (n / 2), 4, src, 2 * src_stride, ref_array,   \
                             2 * ref_stride, sad_array);                      \
    sad_array[0] <<= 1;                                                       \
    sad_array[1] <<= 1;                                                       \
    sad_array[2] <<= 1;                                                       \
    sad_array[3] <<= 1;                                                       \
  }
#define HIGHBD_SAD_MXNX3D_AVX2(m, n)                                          \
  void aom_highbd_sad##m##x##n##x3d_avx2(                                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4],  \
      int ref_stride, uint32_t sad_array[4]) {                                \
    aom_highbd_sadMxNxD_avx2(m, n, 3, src, src_stride, ref_array, ref_stride, \
                             sad_array);                                      \
  }

HIGHBD_SAD_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_MXNX4D_AVX2(16, 32)

HIGHBD_SAD_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_MXNX4D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_MXNX4D_AVX2(16, 4)
HIGHBD_SAD_MXNX4D_AVX2(16, 64)
HIGHBD_SAD_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_MXNX4D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 8)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 32)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 16)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 64)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 32)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 128)

HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_SKIP_MXNX4D_AVX2(16, 64)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(32, 8)
HIGHBD_SAD_SKIP_MXNX4D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

HIGHBD_SAD_MXNX3D_AVX2(16, 8)
HIGHBD_SAD_MXNX3D_AVX2(16, 16)
HIGHBD_SAD_MXNX3D_AVX2(16, 32)

HIGHBD_SAD_MXNX3D_AVX2(32, 16)
HIGHBD_SAD_MXNX3D_AVX2(32, 32)
HIGHBD_SAD_MXNX3D_AVX2(32, 64)

HIGHBD_SAD_MXNX3D_AVX2(64, 32)
HIGHBD_SAD_MXNX3D_AVX2(64, 64)
HIGHBD_SAD_MXNX3D_AVX2(64, 128)

HIGHBD_SAD_MXNX3D_AVX2(128, 64)
HIGHBD_SAD_MXNX3D_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_SAD_MXNX3D_AVX2(16, 4)
HIGHBD_SAD_MXNX3D_AVX2(16, 64)
HIGHBD_SAD_MXNX3D_AVX2(32, 8)
HIGHBD_SAD_MXNX3D_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY