/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_dsp/x86/masked_sad_intrin_ssse3.h"

static inline unsigned int masked_sad32xh_avx2(
    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  int x, y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
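  // mask_max is 64 == AOM_BLEND_A64_MAX_ALPHA: each pixel pair is blended
  // with weights m : (64 - m).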
  const __m256i round_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 32) {
      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
      const __m256i m = _mm256_lddqu_si256((const __m256i *)&m_ptr[x]);
      const __m256i m_inv = _mm256_sub_epi8(mask_max, m);

      // Calculate 32 predicted pixels.
      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
      // is 64 * 255, so we have plenty of space to add rounding constants.
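      // _mm256_mulhrs_epi16(x, 1 << (15 - AOM_BLEND_A64_ROUND_BITS)) computes
      // (x * (1 << 9) + (1 << 14)) >> 15 == (x + 32) >> 6 for non-negative x,
      // i.e. the rounding shift of the AOM_BLEND_A64 blend:
      //   pred = (m * a + (64 - m) * b + 32) >> 6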
      const __m256i data_l = _mm256_unpacklo_epi8(a, b);
      const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
      __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
      pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);

      const __m256i data_r = _mm256_unpackhi_epi8(a, b);
      const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
      __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
      pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);

      const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
      res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have four 32-bit partial SADs in 32-bit lanes
  // 0, 2, 4 and 6 of 'res'. Gather them into the low 128 bits, then
  // horizontally add them down to a single total.
  res = _mm256_shuffle_epi32(res, 0xd8);
  res = _mm256_permute4x64_epi64(res, 0xd8);
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int32_t sad = _mm256_extract_epi32(res, 0);
  return sad;
}

static inline unsigned int masked_sad16xh_avx2(
    const uint8_t *src_ptr, int src_stride, const uint8_t *a_ptr, int a_stride,
    const uint8_t *b_ptr, int b_stride, const uint8_t *m_ptr, int m_stride,
    int height) {
  int y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  for (y = 0; y < height; y += 2) {
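    // Process two rows per iteration: yy_loadu2_128() packs the current row
    // into the low 128-bit lane and the next row into the high lane.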
    const __m256i src = yy_loadu2_128(src_ptr + src_stride, src_ptr);
    const __m256i a = yy_loadu2_128(a_ptr + a_stride, a_ptr);
    const __m256i b = yy_loadu2_128(b_ptr + b_stride, b_ptr);
    const __m256i m = yy_loadu2_128(m_ptr + m_stride, m_ptr);
    const __m256i m_inv = _mm256_sub_epi8(mask_max, m);

    // Calculate 32 predicted pixels (16 from each of the two rows).
    // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
    // is 64 * 255, so we have plenty of space to add rounding constants.
    const __m256i data_l = _mm256_unpacklo_epi8(a, b);
    const __m256i mask_l = _mm256_unpacklo_epi8(m, m_inv);
    __m256i pred_l = _mm256_maddubs_epi16(data_l, mask_l);
    pred_l = _mm256_mulhrs_epi16(pred_l, round_scale);

    const __m256i data_r = _mm256_unpackhi_epi8(a, b);
    const __m256i mask_r = _mm256_unpackhi_epi8(m, m_inv);
    __m256i pred_r = _mm256_maddubs_epi16(data_r, mask_r);
    pred_r = _mm256_mulhrs_epi16(pred_r, round_scale);

    const __m256i pred = _mm256_packus_epi16(pred_l, pred_r);
    res = _mm256_add_epi32(res, _mm256_sad_epu8(pred, src));

    src_ptr += src_stride << 1;
    a_ptr += a_stride << 1;
    b_ptr += b_stride << 1;
    m_ptr += m_stride << 1;
  }
  // At this point, we have four 32-bit partial SADs in 32-bit lanes
  // 0, 2, 4 and 6 of 'res'. Gather them into the low 128 bits, then
  // horizontally add them down to a single total.
  res = _mm256_shuffle_epi32(res, 0xd8);
  res = _mm256_permute4x64_epi64(res, 0xd8);
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int32_t sad = _mm256_extract_epi32(res, 0);
  return sad;
}

static inline unsigned int aom_masked_sad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask, int m, int n) {
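  // The mask weights the first prediction operand: each output pixel is
  // (m * first + (64 - m) * second + 32) >> 6. Inverting the mask is
  // therefore equivalent to swapping 'ref' and 'second_pred'.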
  unsigned int sad;
  if (!invert_mask) {
    switch (m) {
      case 4:
        sad = aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
                                      second_pred, m, msk, msk_stride, n);
        break;
      case 8:
        sad = aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride,
                                      second_pred, m, msk, msk_stride, n);
        break;
      case 16:
        sad = masked_sad16xh_avx2(src, src_stride, ref, ref_stride, second_pred,
                                  m, msk, msk_stride, n);
        break;
      default:
        sad = masked_sad32xh_avx2(src, src_stride, ref, ref_stride, second_pred,
                                  m, msk, msk_stride, m, n);
        break;
    }
  } else {
    switch (m) {
      case 4:
        sad = aom_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
                                      ref_stride, msk, msk_stride, n);
        break;
      case 8:
        sad = aom_masked_sad8xh_ssse3(src, src_stride, second_pred, m, ref,
                                      ref_stride, msk, msk_stride, n);
        break;
      case 16:
        sad = masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
                                  ref_stride, msk, msk_stride, n);
        break;
      default:
        sad = masked_sad32xh_avx2(src, src_stride, second_pred, m, ref,
                                  ref_stride, msk, msk_stride, m, n);
        break;
    }
  }
  return sad;
}

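// Instantiate an aom_masked_sad<m>x<n>_avx2() entry point for each supported
// block size.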
#define MASKSADMXN_AVX2(m, n)                                                 \
  unsigned int aom_masked_sad##m##x##n##_avx2(                                \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                      \
    return aom_masked_sad_avx2(src, src_stride, ref, ref_stride, second_pred, \
                               msk, msk_stride, invert_mask, m, n);           \
  }

MASKSADMXN_AVX2(4, 4)
MASKSADMXN_AVX2(4, 8)
MASKSADMXN_AVX2(8, 4)
MASKSADMXN_AVX2(8, 8)
MASKSADMXN_AVX2(8, 16)
MASKSADMXN_AVX2(16, 8)
MASKSADMXN_AVX2(16, 16)
MASKSADMXN_AVX2(16, 32)
MASKSADMXN_AVX2(32, 16)
MASKSADMXN_AVX2(32, 32)
MASKSADMXN_AVX2(32, 64)
MASKSADMXN_AVX2(64, 32)
MASKSADMXN_AVX2(64, 64)
MASKSADMXN_AVX2(64, 128)
MASKSADMXN_AVX2(128, 64)
MASKSADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
MASKSADMXN_AVX2(4, 16)
MASKSADMXN_AVX2(16, 4)
MASKSADMXN_AVX2(8, 32)
MASKSADMXN_AVX2(32, 8)
MASKSADMXN_AVX2(16, 64)
MASKSADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#if CONFIG_AV1_HIGHBITDEPTH
static inline unsigned int highbd_masked_sad8xh_avx2(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i one = _mm256_set1_epi16(1);

  for (y = 0; y < height; y += 2) {
    const __m256i src = yy_loadu2_128(src_ptr + src_stride, src_ptr);
    const __m256i a = yy_loadu2_128(a_ptr + a_stride, a_ptr);
    const __m256i b = yy_loadu2_128(b_ptr + b_stride, b_ptr);
    // Zero-extend mask to 16 bits
    const __m256i m = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)(m_ptr)),
        _mm_loadl_epi64((const __m128i *)(m_ptr + m_stride))));
    const __m256i m_inv = _mm256_sub_epi16(mask_max, m);

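    // Interleaving (a, b) with (m, 64 - m) lets _mm256_madd_epi16() compute
    // m * a + (64 - m) * b in each 32-bit lane; adding round_const = 32 and
    // shifting right by AOM_BLEND_A64_ROUND_BITS completes the A64 blend.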
    const __m256i data_l = _mm256_unpacklo_epi16(a, b);
    const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
    __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
    pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
                               AOM_BLEND_A64_ROUND_BITS);

    const __m256i data_r = _mm256_unpackhi_epi16(a, b);
    const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
    __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
    pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
                               AOM_BLEND_A64_ROUND_BITS);

    // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
    // so it is safe to do signed saturation here.
    const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
    // There is no 16-bit SAD instruction, so we have to synthesize one:
    // madd against a vector of ones sums adjacent absolute differences
    // into eight 32-bit partial SADs, which are accumulated at the end.
    const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
    res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));

    src_ptr += src_stride << 1;
    a_ptr += a_stride << 1;
    b_ptr += b_stride << 1;
    m_ptr += m_stride << 1;
  }
  // At this point, we have eight 32-bit partial SADs stored in 'res'. Two
  // rounds of hadd reduce each 128-bit lane to a total in its low dword;
  // the extracts then sum the totals of the two lanes.
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
  return sad;
}

static inline unsigned int highbd_masked_sad16xh_avx2(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  __m256i res = _mm256_setzero_si256();
  const __m256i mask_max = _mm256_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m256i round_const =
      _mm256_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m256i one = _mm256_set1_epi16(1);

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m256i src = _mm256_lddqu_si256((const __m256i *)&src_ptr[x]);
      const __m256i a = _mm256_lddqu_si256((const __m256i *)&a_ptr[x]);
      const __m256i b = _mm256_lddqu_si256((const __m256i *)&b_ptr[x]);
      // Zero-extend mask to 16 bits
      const __m256i m =
          _mm256_cvtepu8_epi16(_mm_lddqu_si128((const __m128i *)&m_ptr[x]));
      const __m256i m_inv = _mm256_sub_epi16(mask_max, m);

      const __m256i data_l = _mm256_unpacklo_epi16(a, b);
      const __m256i mask_l = _mm256_unpacklo_epi16(m, m_inv);
      __m256i pred_l = _mm256_madd_epi16(data_l, mask_l);
      pred_l = _mm256_srai_epi32(_mm256_add_epi32(pred_l, round_const),
                                 AOM_BLEND_A64_ROUND_BITS);

      const __m256i data_r = _mm256_unpackhi_epi16(a, b);
      const __m256i mask_r = _mm256_unpackhi_epi16(m, m_inv);
      __m256i pred_r = _mm256_madd_epi16(data_r, mask_r);
      pred_r = _mm256_srai_epi32(_mm256_add_epi32(pred_r, round_const),
                                 AOM_BLEND_A64_ROUND_BITS);

      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
      // so it is safe to do signed saturation here.
      const __m256i pred = _mm256_packs_epi32(pred_l, pred_r);
      // There is no 16-bit SAD instruction, so we have to synthesize one:
      // madd against a vector of ones sums adjacent absolute differences
      // into eight 32-bit partial SADs, which are accumulated at the end.
      const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(pred, src));
      res = _mm256_add_epi32(res, _mm256_madd_epi16(diff, one));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have eight 32-bit partial SADs stored in 'res'. Two
  // rounds of hadd reduce each 128-bit lane to a total in its low dword;
  // the extracts then sum the totals of the two lanes.
  res = _mm256_hadd_epi32(res, res);
  res = _mm256_hadd_epi32(res, res);
  int sad = _mm256_extract_epi32(res, 0) + _mm256_extract_epi32(res, 4);
  return sad;
}

static inline unsigned int aom_highbd_masked_sad_avx2(
    const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,
    const uint8_t *second_pred, const uint8_t *msk, int msk_stride,
    int invert_mask, int m, int n) {
  unsigned int sad;
  if (!invert_mask) {
    switch (m) {
      case 4:
        sad =
            aom_highbd_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,
                                           second_pred, m, msk, msk_stride, n);
        break;
      case 8:
        sad = highbd_masked_sad8xh_avx2(src, src_stride, ref, ref_stride,
                                        second_pred, m, msk, msk_stride, n);
        break;
      default:
        sad = highbd_masked_sad16xh_avx2(src, src_stride, ref, ref_stride,
                                         second_pred, m, msk, msk_stride, m, n);
        break;
    }
  } else {
    switch (m) {
      case 4:
        sad =
            aom_highbd_masked_sad4xh_ssse3(src, src_stride, second_pred, m, ref,
                                           ref_stride, msk, msk_stride, n);
        break;
      case 8:
        sad = highbd_masked_sad8xh_avx2(src, src_stride, second_pred, m, ref,
                                        ref_stride, msk, msk_stride, n);
        break;
      default:
        sad = highbd_masked_sad16xh_avx2(src, src_stride, second_pred, m, ref,
                                         ref_stride, msk, msk_stride, m, n);
        break;
    }
  }
  return sad;
}

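// Instantiate the high-bit-depth wrappers. The uint8_t buffer pointers are
// aliases of 16-bit pixel data (see the CONVERT_TO_SHORTPTR() calls above).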
#define HIGHBD_MASKSADMXN_AVX2(m, n)                                      \
  unsigned int aom_highbd_masked_sad##m##x##n##_avx2(                     \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,           \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,    \
      int msk_stride, int invert_mask) {                                  \
    return aom_highbd_masked_sad_avx2(src8, src_stride, ref8, ref_stride, \
                                      second_pred8, msk, msk_stride,      \
                                      invert_mask, m, n);                 \
  }

HIGHBD_MASKSADMXN_AVX2(4, 4)
HIGHBD_MASKSADMXN_AVX2(4, 8)
HIGHBD_MASKSADMXN_AVX2(8, 4)
HIGHBD_MASKSADMXN_AVX2(8, 8)
HIGHBD_MASKSADMXN_AVX2(8, 16)
HIGHBD_MASKSADMXN_AVX2(16, 8)
HIGHBD_MASKSADMXN_AVX2(16, 16)
HIGHBD_MASKSADMXN_AVX2(16, 32)
HIGHBD_MASKSADMXN_AVX2(32, 16)
HIGHBD_MASKSADMXN_AVX2(32, 32)
HIGHBD_MASKSADMXN_AVX2(32, 64)
HIGHBD_MASKSADMXN_AVX2(64, 32)
HIGHBD_MASKSADMXN_AVX2(64, 64)
HIGHBD_MASKSADMXN_AVX2(64, 128)
HIGHBD_MASKSADMXN_AVX2(128, 64)
HIGHBD_MASKSADMXN_AVX2(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASKSADMXN_AVX2(4, 16)
HIGHBD_MASKSADMXN_AVX2(16, 4)
HIGHBD_MASKSADMXN_AVX2(8, 32)
HIGHBD_MASKSADMXN_AVX2(32, 8)
HIGHBD_MASKSADMXN_AVX2(16, 64)
HIGHBD_MASKSADMXN_AVX2(64, 16)
#endif  // !CONFIG_REALTIME_ONLY
#endif  // CONFIG_AV1_HIGHBITDEPTH