/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include <assert.h>     // assert() is used in the 16xN kernels below
#include <immintrin.h>  // AVX2

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
static AOM_FORCE_INLINE void aggregate_and_store_sum(uint32_t res[4],
                                                     const __m256i *sum_ref0,
                                                     const __m256i *sum_ref1,
                                                     const __m256i *sum_ref2,
                                                     const __m256i *sum_ref3) {
  // In each 64-bit element of sum_ref-i, the partial SAD sits in the low 4
  // bytes and the upper 4 bytes are zero.
  // Merge sum_ref0 with sum_ref1, and sum_ref2 with sum_ref3.
  // 0, 0, 1, 1
  __m256i sum_ref01 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref0), _mm256_castsi256_ps(*sum_ref1),
      _MM_SHUFFLE(2, 0, 2, 0)));
  // 2, 2, 3, 3
  __m256i sum_ref23 = _mm256_castps_si256(_mm256_shuffle_ps(
      _mm256_castsi256_ps(*sum_ref2), _mm256_castsi256_ps(*sum_ref3),
      _MM_SHUFFLE(2, 0, 2, 0)));
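  // Per 128-bit lane, with ai = the i-th 64-bit partial sum of sum_ref0
  // (and b/c/d likewise for sum_ref1-3), the merged registers now hold:
  //   sum_ref01 = { a0, a1, b0, b1 | a2, a3, b2, b3 }
  //   sum_ref23 = { c0, c1, d0, d1 | c2, c3, d2, d3 }
  // The hadd below yields { a0+a1, b0+b1, c0+c1, d0+d1 | a2+a3, b2+b3,
  // c2+c3, d2+d3 }, so adding the low half to the high half leaves
  // { SAD0, SAD1, SAD2, SAD3 } ready for the store.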

  // Sum adjacent 32-bit integers.
  __m256i sum_ref0123 = _mm256_hadd_epi32(sum_ref01, sum_ref23);

  // Add the low 128 bits to the high 128 bits.
  __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(sum_ref0123),
                              _mm256_extractf128_si256(sum_ref0123, 1));

  _mm_storeu_si128((__m128i *)(res), sum);
}

static AOM_FORCE_INLINE void aom_sadMxNx4d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2, *ref3;

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));
      ref3_reg = _mm256_loadu_si256((const __m256i *)(ref3 + j));

      // Sum of absolute differences between each ref-i and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);
      // Accumulate into the per-ref running sums.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
      sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
    ref3 += ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}
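
// For reference, the kernel above is a vectorized form of the following
// scalar computation (illustrative sketch only, not compiled):
//
//   for (int k = 0; k < 4; k++) {
//     uint32_t sad = 0;
//     for (int i = 0; i < N; i++)
//       for (int j = 0; j < M; j++)
//         sad += abs(src[i * src_stride + j] - ref[k][i * ref_stride + j]);
//     res[k] = sad;
//   }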

static AOM_FORCE_INLINE void aom_sadMxNx3d_avx2(
    int M, int N, const uint8_t *src, int src_stride,
    const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  int i, j;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();

  for (i = 0; i < N; i++) {
    for (j = 0; j < M; j += 32) {
      // Load 32 bytes of src and of each ref.
      src_reg = _mm256_loadu_si256((const __m256i *)(src + j));
      ref0_reg = _mm256_loadu_si256((const __m256i *)(ref0 + j));
      ref1_reg = _mm256_loadu_si256((const __m256i *)(ref1 + j));
      ref2_reg = _mm256_loadu_si256((const __m256i *)(ref2 + j));

      // Sum of absolute differences between each ref-i and src.
      ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
      ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
      ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
      // Accumulate into the per-ref running sums.
      sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
      sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
      sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    }
    src += src_stride;
    ref0 += ref_stride;
    ref1 += ref_stride;
    ref2 += ref_stride;
  }
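  // Only three refs were summed; passing zero as the fourth column makes
  // aggregate_and_store_sum() write 0 to res[3].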
  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

#define SADMXN_AVX2(m, n)                                                      \
  void aom_sad##m##x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx4d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }                                                                            \
  void aom_sad##m##x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                                  const uint8_t *const ref[4], int ref_stride, \
                                  uint32_t res[4]) {                           \
    aom_sadMxNx3d_avx2(m, n, src, src_stride, ref, ref_stride, res);           \
  }
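
// For example, SADMXN_AVX2(32, 16) defines aom_sad32x16x4d_avx2() and
// aom_sad32x16x3d_avx2(), which forward to the shared kernels above. A
// typical call looks like (hypothetical caller-side sketch):
//
//   uint32_t sad[4];
//   aom_sad32x16x4d_avx2(src, src_stride, refs, ref_stride, sad);
//   // sad[k] now holds the 32x16 SAD between src and refs[k].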

SADMXN_AVX2(32, 16)
SADMXN_AVX2(32, 32)
SADMXN_AVX2(32, 64)

SADMXN_AVX2(64, 32)
SADMXN_AVX2(64, 64)
SADMXN_AVX2(64, 128)

SADMXN_AVX2(128, 64)
SADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
SADMXN_AVX2(32, 8)
SADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

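// The _skip_ variants approximate the full SAD from the even rows only:
// halve the height, double both strides, then double the result. For a
// 32x16 block this computes a 32x8 SAD over rows 0, 2, ..., 14 and returns
// twice that sum.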
#define SAD_SKIP_MXN_AVX2(m, n)                                             \
  void aom_sad_skip_##m##x##n##x4d_avx2(const uint8_t *src, int src_stride, \
                                        const uint8_t *const ref[4],        \
                                        int ref_stride, uint32_t res[4]) {  \
    aom_sadMxNx4d_avx2(m, ((n) >> 1), src, 2 * src_stride, ref,             \
                       2 * ref_stride, res);                                \
    res[0] <<= 1;                                                           \
    res[1] <<= 1;                                                           \
    res[2] <<= 1;                                                           \
    res[3] <<= 1;                                                           \
  }

SAD_SKIP_MXN_AVX2(32, 16)
SAD_SKIP_MXN_AVX2(32, 32)
SAD_SKIP_MXN_AVX2(32, 64)

SAD_SKIP_MXN_AVX2(64, 32)
SAD_SKIP_MXN_AVX2(64, 64)
SAD_SKIP_MXN_AVX2(64, 128)

SAD_SKIP_MXN_AVX2(128, 64)
SAD_SKIP_MXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_MXN_AVX2(32, 8)
SAD_SKIP_MXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

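// For 16-wide blocks, a 32-byte row load would leave half the register
// unused, so the kernels below pack two consecutive rows into one __m256i
// with yy_loadu2_128() (row i in the low 128 bits, row i + 1 in the high
// 128 bits, following its definition in synonyms_avx2.h) and step two rows
// per iteration; this is why N must be even.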
static AOM_FORCE_INLINE void aom_sad16xNx3d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2;
  const uint8_t *ref0, *ref1, *ref2;
  const __m256i zero = _mm256_setzero_si256();
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);

    // Sum of absolute differences between each ref-i and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);

    // Accumulate into the per-ref running sums.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &zero);
}

static AOM_FORCE_INLINE void aom_sad16xNx4d_avx2(int N, const uint8_t *src,
                                                 int src_stride,
                                                 const uint8_t *const ref[4],
                                                 int ref_stride,
                                                 uint32_t res[4]) {
  __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg;
  __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3;
  const uint8_t *ref0, *ref1, *ref2, *ref3;
  assert(N % 2 == 0);

  ref0 = ref[0];
  ref1 = ref[1];
  ref2 = ref[2];
  ref3 = ref[3];

  sum_ref0 = _mm256_setzero_si256();
  sum_ref2 = _mm256_setzero_si256();
  sum_ref1 = _mm256_setzero_si256();
  sum_ref3 = _mm256_setzero_si256();

  for (int i = 0; i < N; i += 2) {
    // Load two rows of src and of each ref.
    src_reg = yy_loadu2_128(src + src_stride, src);
    ref0_reg = yy_loadu2_128(ref0 + ref_stride, ref0);
    ref1_reg = yy_loadu2_128(ref1 + ref_stride, ref1);
    ref2_reg = yy_loadu2_128(ref2 + ref_stride, ref2);
    ref3_reg = yy_loadu2_128(ref3 + ref_stride, ref3);

    // Sum of absolute differences between each ref-i and src.
    ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg);
    ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg);
    ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg);
    ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg);

    // Accumulate into the per-ref running sums.
    sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg);
    sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg);
    sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg);
    sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg);

    src += 2 * src_stride;
    ref0 += 2 * ref_stride;
    ref1 += 2 * ref_stride;
    ref2 += 2 * ref_stride;
    ref3 += 2 * ref_stride;
  }

  aggregate_and_store_sum(res, &sum_ref0, &sum_ref1, &sum_ref2, &sum_ref3);
}

#define SAD16XNX3_AVX2(n)                                                   \
  void aom_sad16x##n##x3d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx3d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }
#define SAD16XNX4_AVX2(n)                                                   \
  void aom_sad16x##n##x4d_avx2(const uint8_t *src, int src_stride,          \
                               const uint8_t *const ref[4], int ref_stride, \
                               uint32_t res[4]) {                           \
    aom_sad16xNx4d_avx2(n, src, src_stride, ref, ref_stride, res);          \
  }

SAD16XNX4_AVX2(32)
SAD16XNX4_AVX2(16)
SAD16XNX4_AVX2(8)

SAD16XNX3_AVX2(32)
SAD16XNX3_AVX2(16)
SAD16XNX3_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD16XNX3_AVX2(64)
SAD16XNX3_AVX2(4)

SAD16XNX4_AVX2(64)
SAD16XNX4_AVX2(4)

#endif  // !CONFIG_REALTIME_ONLY

#define SAD_SKIP_16XN_AVX2(n)                                                 \
  void aom_sad_skip_16x##n##x4d_avx2(const uint8_t *src, int src_stride,      \
                                     const uint8_t *const ref[4],             \
                                     int ref_stride, uint32_t res[4]) {       \
    aom_sad16xNx4d_avx2(((n) >> 1), src, 2 * src_stride, ref, 2 * ref_stride, \
                        res);                                                 \
    res[0] <<= 1;                                                             \
    res[1] <<= 1;                                                             \
    res[2] <<= 1;                                                             \
    res[3] <<= 1;                                                             \
  }

SAD_SKIP_16XN_AVX2(32)
SAD_SKIP_16XN_AVX2(16)
SAD_SKIP_16XN_AVX2(8)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_16XN_AVX2(64)
SAD_SKIP_16XN_AVX2(4)
#endif  // !CONFIG_REALTIME_ONLY