/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include <assert.h>     // assert() is used by the 16xN kernels below
#include <immintrin.h>  // AVX2

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"

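// Overview: this file provides AVX2 versions of the multi-reference SAD
// kernels. The x4d variants compute the SAD between one source block and
// four candidate reference blocks in a single pass, the x3d variants handle
// three candidates, and the sad_skip variants estimate the SAD from every
// other row. Results are written to res[4].
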
static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4],
                                                     const __m256i *sum_ref0,
                                                     const __m256i *sum_ref1,
                                                     const __m256i *sum_ref2,
                                                     const __m256i *sum_ref3) {
  // In each 64-bit lane of sum_ref_i the accumulated SAD sits in the low 4
  // bytes; the high 4 bytes are zero.
  // Merge sum_ref0 with sum_ref1, and sum_ref2 with sum_ref3, keeping only
  // the live 32-bit words (elements 0 and 2 of each 128-bit half).
  // 0, 0, 1, 1
  __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1),
      _MM_SHUFFLE(2, 0, 2, 0)));
  // 2, 2, 3, 3
  __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3),
      _MM_SHUFFLE(2, 0, 2, 0)));

  // Sum adjacent 32-bit integers.
  __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23);

  // Add the low 128 bits to the high 128 bits.
  __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123),
                              _mm256_extractf128_si256(sum_ref0123, 1));

  _mm_storeu_si128((__m128i *)(res), sum);
}
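
// For reference, the reduction above matches this scalar sketch (not
// compiled; `lanes` is a hypothetical view of one accumulator's four
// 64-bit lanes):
//   uint64_t lanes[4];
//   memcpy(lanes, sum_ref_i, sizeof(lanes));
//   res[i] = (uint32_t)(lanes[0] + lanes[1] + lanes[2] + lanes[3]);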

static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2, *ref3;

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
      ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j));

      // Sum of absolute differences between each ref_i and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
      // Accumulate each ref_i.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
      sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
    ref3 += ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));

      // Sum of absolute differences between each ref_i and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      // Accumulate each ref_i.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
  }
  // Reuse the 4-way reduction with a zeroed fourth accumulator; res[3] is
  // written as 0.
  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

#define SADMXN_AVX2(m, n)                                                  \
  void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride,      \
                                  const uint8_t *const ref[4],             \
                                  int ref_stride, uint32_t res[4]) {       \
    aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res);       \
  }                                                                        \
  void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride,      \
                                  const uint8_t *const ref[4],             \
                                  int ref_stride, uint32_t res[4]) {       \
    aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res);       \
  }
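
// Usage sketch for the generated kernels (hypothetical buffers; all loads
// are unaligned, so any stride at least as wide as the block works):
//   uint32_t sad[4];
//   const uint8_t *refs[4] = { ref0, ref1, ref2, ref3 };
//   aom_sad64x64x4d_avx2(src, src_stride, refs, ref_stride, sad);
//   // sad[i] now holds the 64x64 SAD between src and refs[i].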

SADMXN_AVX2(32, 16)
SADMXN_AVX2(32, 32)
SADMXN_AVX2(32, 64)

SADMXN_AVX2(64, 32)
SADMXN_AVX2(64, 64)
SADMXN_AVX2(64, 128)

SADMXN_AVX2(128, 64)
SADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
SADMXN_AVX2(32, 8)
SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

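// The sad_skip kernels downsample vertically: they compute the SAD over the
// even rows only (half the rows at twice the stride) and then double the
// result to approximate the full-block SAD.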
#define SAD_SKIP_MXN_AVX2(m, n)                                             \
  void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \
                                        const uint8_t *const ref[4],        \
                                        int ref_stride, uint32_t res[4]) {  \
    aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref,             \
                       2 * ref_stride, res);                                \
    res[0] <<= 1;                                                           \
    res[1] <<= 1;                                                           \
    res[2] <<= 1;                                                           \
    res[3] <<= 1;                                                           \
  }

SAD_SKIP_MXN_AVX2(32, 16)
SAD_SKIP_MXN_AVX2(32, 32)
SAD_SKIP_MXN_AVX2(32, 64)

SAD_SKIP_MXN_AVX2(64, 32)
SAD_SKIP_MXN_AVX2(64, 64)
SAD_SKIP_MXN_AVX2(64, 128)

SAD_SKIP_MXN_AVX2(128, 64)
SAD_SKIP_MXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_MXN_AVX2(32, 8)
SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

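// For 16-pixel-wide blocks a single row no longer fills a 256-bit register,
// so the kernels below process two rows per iteration: yy_loadu2_128 packs
// two 16-byte rows into one __m256i before the SAD.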
static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);

    // Sum of absolute differences between each ref_i and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);

    // Accumulate each ref_i.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
  }

  // Reuse the 4-way reduction with a zeroed fourth accumulator; res[3] is
  // written as 0.
  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

static AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  const uint8_t *ref0, *ref1, *ref2, *ref3;
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];

  sum_ref0 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3);

    // Sum of absolute differences between each ref_i and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);

    // Accumulate each ref_i.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
    ref3 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

#define SAD16XNX3_AVX2(n)                                              \
  void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride,     \
                               const uint8_t *const ref[4],            \
                               int ref_stride, uint32_t res[4]) {      \
    aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res);     \
  }
#define SAD16XNX4_AVX2(n)                                              \
  void aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride,     \
                               const uint8_t *const ref[4],            \
                               int ref_stride, uint32_t res[4]) {      \
    aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res);     \
  }

SAD16XNX4_AVX2(32)
SAD16XNX4_AVX2(16)
SAD16XNX4_AVX2(8)

SAD16XNX3_AVX2(32)
SAD16XNX3_AVX2(16)
SAD16XNX3_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD16XNX3_AVX2(64)
SAD16XNX3_AVX2(4)

SAD16XNX4_AVX2(64)
SAD16XNX4_AVX2(4)

#endif  // !CONFIG_REALTIME_ONLY

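// As above, the skip kernels for 16-wide blocks halve the row count, double
// the strides, and scale the result by 2.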
#define SAD_SKIP_16XN_AVX2(n)                                             \
  void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride,  \
                                     const uint8_t *const ref[4],         \
                                     int ref_stride, uint32_t res[4]) {   \
    aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref,             \
                        2 * ref_stride, res);                             \
    res[0] <<= 1;                                                         \
    res[1] <<= 1;                                                         \
    res[2] <<= 1;                                                         \
    res[3] <<= 1;                                                         \
  }

SAD_SKIP_16XN_AVX2(32)
SAD_SKIP_16XN_AVX2(16)
SAD_SKIP_16XN_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_16XN_AVX2(64)
SAD_SKIP_16XN_AVX2(4)
#endif  // !CONFIG_REALTIME_ONLY