xref: /aosp_15_r20/external/libvpx/vpx_dsp/x86/sad_avx2.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include <immintrin.h>
11 #include "./vpx_dsp_rtcd.h"
12 #include "vpx_ports/mem.h"
13 
sad64xh_avx2(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride,int h)14 static INLINE unsigned int sad64xh_avx2(const uint8_t *src_ptr, int src_stride,
15                                         const uint8_t *ref_ptr, int ref_stride,
16                                         int h) {
17   int i, res;
18   __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
19   __m256i sum_sad = _mm256_setzero_si256();
20   __m256i sum_sad_h;
21   __m128i sum_sad128;
22   for (i = 0; i < h; i++) {
23     ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
24     ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));
25     sad1_reg =
26         _mm256_sad_epu8(ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
27     sad2_reg = _mm256_sad_epu8(
28         ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));
29     sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
30     ref_ptr += ref_stride;
31     src_ptr += src_stride;
32   }
33   sum_sad_h = _mm256_srli_si256(sum_sad, 8);
34   sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
35   sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
36   sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
37   res = _mm_cvtsi128_si32(sum_sad128);
38   return res;
39 }
40 
sad32xh_avx2(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride,int h)41 static INLINE unsigned int sad32xh_avx2(const uint8_t *src_ptr, int src_stride,
42                                         const uint8_t *ref_ptr, int ref_stride,
43                                         int h) {
44   int i, res;
45   __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
46   __m256i sum_sad = _mm256_setzero_si256();
47   __m256i sum_sad_h;
48   __m128i sum_sad128;
49   const int ref2_stride = ref_stride << 1;
50   const int src2_stride = src_stride << 1;
51   const int max = h >> 1;
52   for (i = 0; i < max; i++) {
53     ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
54     ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride));
55     sad1_reg =
56         _mm256_sad_epu8(ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
57     sad2_reg = _mm256_sad_epu8(
58         ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));
59     sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
60     ref_ptr += ref2_stride;
61     src_ptr += src2_stride;
62   }
63   sum_sad_h = _mm256_srli_si256(sum_sad, 8);
64   sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
65   sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
66   sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
67   res = _mm_cvtsi128_si32(sum_sad128);
68   return res;
69 }
70 
71 #define FSAD64_H(h)                                                           \
72   unsigned int vpx_sad64x##h##_avx2(const uint8_t *src_ptr, int src_stride,   \
73                                     const uint8_t *ref_ptr, int ref_stride) { \
74     return sad64xh_avx2(src_ptr, src_stride, ref_ptr, ref_stride, h);         \
75   }
76 
77 #define FSADS64_H(h)                                                          \
78   unsigned int vpx_sad_skip_64x##h##_avx2(                                    \
79       const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
80       int ref_stride) {                                                       \
81     return 2 * sad64xh_avx2(src_ptr, src_stride * 2, ref_ptr, ref_stride * 2, \
82                             h / 2);                                           \
83   }
84 
85 #define FSAD32_H(h)                                                           \
86   unsigned int vpx_sad32x##h##_avx2(const uint8_t *src_ptr, int src_stride,   \
87                                     const uint8_t *ref_ptr, int ref_stride) { \
88     return sad32xh_avx2(src_ptr, src_stride, ref_ptr, ref_stride, h);         \
89   }
90 
91 #define FSADS32_H(h)                                                          \
92   unsigned int vpx_sad_skip_32x##h##_avx2(                                    \
93       const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
94       int ref_stride) {                                                       \
95     return 2 * sad32xh_avx2(src_ptr, src_stride * 2, ref_ptr, ref_stride * 2, \
96                             h / 2);                                           \
97   }
98 
99 #define FSAD64  \
100   FSAD64_H(64)  \
101   FSAD64_H(32)  \
102   FSADS64_H(64) \
103   FSADS64_H(32)
104 
105 #define FSAD32  \
106   FSAD32_H(64)  \
107   FSAD32_H(32)  \
108   FSAD32_H(16)  \
109   FSADS32_H(64) \
110   FSADS32_H(32) \
111   FSADS32_H(16)
112 
113 FSAD64
114 FSAD32
115 
116 #undef FSAD64
117 #undef FSAD32
118 #undef FSAD64_H
119 #undef FSAD32_H
120 #undef FSADS64_H
121 #undef FSADS32_H
122 
123 #define FSADAVG64_H(h)                                                        \
124   unsigned int vpx_sad64x##h##_avg_avx2(                                      \
125       const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
126       int ref_stride, const uint8_t *second_pred) {                           \
127     int i;                                                                    \
128     __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;                           \
129     __m256i sum_sad = _mm256_setzero_si256();                                 \
130     __m256i sum_sad_h;                                                        \
131     __m128i sum_sad128;                                                       \
132     for (i = 0; i < h; i++) {                                                 \
133       ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);                \
134       ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));         \
135       ref1_reg = _mm256_avg_epu8(                                             \
136           ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));        \
137       ref2_reg = _mm256_avg_epu8(                                             \
138           ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32))); \
139       sad1_reg = _mm256_sad_epu8(                                             \
140           ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));            \
141       sad2_reg = _mm256_sad_epu8(                                             \
142           ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));     \
143       sum_sad =                                                               \
144           _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));    \
145       ref_ptr += ref_stride;                                                  \
146       src_ptr += src_stride;                                                  \
147       second_pred += 64;                                                      \
148     }                                                                         \
149     sum_sad_h = _mm256_srli_si256(sum_sad, 8);                                \
150     sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);                           \
151     sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);                        \
152     sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);  \
153     return (unsigned int)_mm_cvtsi128_si32(sum_sad128);                       \
154   }
155 
156 #define FSADAVG32_H(h)                                                        \
157   unsigned int vpx_sad32x##h##_avg_avx2(                                      \
158       const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
159       int ref_stride, const uint8_t *second_pred) {                           \
160     int i;                                                                    \
161     __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;                           \
162     __m256i sum_sad = _mm256_setzero_si256();                                 \
163     __m256i sum_sad_h;                                                        \
164     __m128i sum_sad128;                                                       \
165     int ref2_stride = ref_stride << 1;                                        \
166     int src2_stride = src_stride << 1;                                        \
167     int max = h >> 1;                                                         \
168     for (i = 0; i < max; i++) {                                               \
169       ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);                \
170       ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride)); \
171       ref1_reg = _mm256_avg_epu8(                                             \
172           ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));        \
173       ref2_reg = _mm256_avg_epu8(                                             \
174           ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32))); \
175       sad1_reg = _mm256_sad_epu8(                                             \
176           ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));            \
177       sad2_reg = _mm256_sad_epu8(                                             \
178           ref2_reg,                                                           \
179           _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));       \
180       sum_sad =                                                               \
181           _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));    \
182       ref_ptr += ref2_stride;                                                 \
183       src_ptr += src2_stride;                                                 \
184       second_pred += 64;                                                      \
185     }                                                                         \
186     sum_sad_h = _mm256_srli_si256(sum_sad, 8);                                \
187     sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);                           \
188     sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);                        \
189     sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);  \
190     return (unsigned int)_mm_cvtsi128_si32(sum_sad128);                       \
191   }
192 
193 #define FSADAVG64 \
194   FSADAVG64_H(64) \
195   FSADAVG64_H(32)
196 
197 #define FSADAVG32 \
198   FSADAVG32_H(64) \
199   FSADAVG32_H(32) \
200   FSADAVG32_H(16)
201 
202 FSADAVG64
203 FSADAVG32
204 
205 #undef FSADAVG64
206 #undef FSADAVG32
207 #undef FSADAVG64_H
208 #undef FSADAVG32_H
209