/*
 * Copyright (c) 2023 The WebM project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>
#include <smmintrin.h>
#include <stdint.h>

#include "./vpx_config.h"
#include "./vpx_dsp_rtcd.h"

#include "vpx_ports/mem.h"
#include "vpx_dsp/x86/mem_sse2.h"

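// Accumulates the squared error of 32 consecutive 8-bit pixels from a and b
// into the eight 32-bit lanes of *sum.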
static INLINE void sse_w32_avx2(__m256i *sum, const uint8_t *a,
                                const uint8_t *b) {
  const __m256i v_a0 = _mm256_loadu_si256((const __m256i *)a);
  const __m256i v_b0 = _mm256_loadu_si256((const __m256i *)b);
  const __m256i zero = _mm256_setzero_si256();
  const __m256i v_a00_w = _mm256_unpacklo_epi8(v_a0, zero);
  const __m256i v_a01_w = _mm256_unpackhi_epi8(v_a0, zero);
  const __m256i v_b00_w = _mm256_unpacklo_epi8(v_b0, zero);
  const __m256i v_b01_w = _mm256_unpackhi_epi8(v_b0, zero);
  const __m256i v_d00_w = _mm256_sub_epi16(v_a00_w, v_b00_w);
  const __m256i v_d01_w = _mm256_sub_epi16(v_a01_w, v_b01_w);
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d00_w, v_d00_w));
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d01_w, v_d01_w));
}

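// Widens the eight 32-bit partial sums in *sum_all to 64 bits and reduces
// them to a single scalar total.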
static INLINE int64_t summary_all_avx2(const __m256i *sum_all) {
  int64_t sum;
  __m256i zero = _mm256_setzero_si256();
  const __m256i sum0_4x64 = _mm256_unpacklo_epi32(*sum_all, zero);
  const __m256i sum1_4x64 = _mm256_unpackhi_epi32(*sum_all, zero);
  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
                                         _mm256_extracti128_si256(sum_4x64, 1));
  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
  _mm_storel_epi64((__m128i *)&sum, sum_1x64);
  return sum;
}

#if CONFIG_VP9_HIGHBITDEPTH
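// Folds the eight 32-bit partial sums in *sum32 into the running 64-bit
// accumulator *sum.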
static INLINE void summary_32_avx2(const __m256i *sum32, __m256i *sum) {
  const __m256i sum0_4x64 =
      _mm256_cvtepu32_epi64(_mm256_castsi256_si128(*sum32));
  const __m256i sum1_4x64 =
      _mm256_cvtepu32_epi64(_mm256_extracti128_si256(*sum32, 1));
  const __m256i sum_4x64 = _mm256_add_epi64(sum0_4x64, sum1_4x64);
  *sum = _mm256_add_epi64(*sum, sum_4x64);
}

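// Reduces four 64-bit lanes to a single scalar sum.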
static INLINE int64_t summary_4x64_avx2(const __m256i sum_4x64) {
  int64_t sum;
  const __m128i sum_2x64 = _mm_add_epi64(_mm256_castsi256_si128(sum_4x64),
                                         _mm256_extracti128_si256(sum_4x64, 1));
  const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));

  _mm_storel_epi64((__m128i *)&sum, sum_1x64);
  return sum;
}
#endif

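// Accumulates the squared error of a 4x4 block of 8-bit pixels into *sum.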
static INLINE void sse_w4x4_avx2(const uint8_t *a, int a_stride,
                                 const uint8_t *b, int b_stride, __m256i *sum) {
  const __m128i v_a0 = load_unaligned_u32(a);
  const __m128i v_a1 = load_unaligned_u32(a + a_stride);
  const __m128i v_a2 = load_unaligned_u32(a + a_stride * 2);
  const __m128i v_a3 = load_unaligned_u32(a + a_stride * 3);
  const __m128i v_b0 = load_unaligned_u32(b);
  const __m128i v_b1 = load_unaligned_u32(b + b_stride);
  const __m128i v_b2 = load_unaligned_u32(b + b_stride * 2);
  const __m128i v_b3 = load_unaligned_u32(b + b_stride * 3);
  const __m128i v_a0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_a0, v_a1),
                                             _mm_unpacklo_epi32(v_a2, v_a3));
  const __m128i v_b0123 = _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_b0, v_b1),
                                             _mm_unpacklo_epi32(v_b2, v_b3));
  const __m256i v_a_w = _mm256_cvtepu8_epi16(v_a0123);
  const __m256i v_b_w = _mm256_cvtepu8_epi16(v_b0123);
  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
}

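// Accumulates the squared error of an 8x2 block of 8-bit pixels into *sum.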
static INLINE void sse_w8x2_avx2(const uint8_t *a, int a_stride,
                                 const uint8_t *b, int b_stride, __m256i *sum) {
  const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
  const __m128i v_a1 = _mm_loadl_epi64((const __m128i *)(a + a_stride));
  const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
  const __m128i v_b1 = _mm_loadl_epi64((const __m128i *)(b + b_stride));
  const __m256i v_a_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_a0, v_a1));
  const __m256i v_b_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(v_b0, v_b1));
  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
}

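// Computes the sum of squared errors between two 8-bit buffers. Widths of 4,
// 8, 16, 32 and 64 take dedicated paths; other widths are processed in 8-wide
// strips, with a trailing 4-wide strip when the width is not a multiple of 8.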
int64_t vpx_sse_avx2(const uint8_t *a, int a_stride, const uint8_t *b,
                     int b_stride, int width, int height) {
  int32_t y = 0;
  int64_t sse = 0;
  __m256i sum = _mm256_setzero_si256();
  __m256i zero = _mm256_setzero_si256();
  switch (width) {
    case 4:
      do {
        sse_w4x4_avx2(a, a_stride, b, b_stride, &sum);
        a += a_stride << 2;
        b += b_stride << 2;
        y += 4;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 8:
      do {
        sse_w8x2_avx2(a, a_stride, b, b_stride, &sum);
        a += a_stride << 1;
        b += b_stride << 1;
        y += 2;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 16:
      do {
        const __m128i v_a0 = _mm_loadu_si128((const __m128i *)a);
        const __m128i v_a1 = _mm_loadu_si128((const __m128i *)(a + a_stride));
        const __m128i v_b0 = _mm_loadu_si128((const __m128i *)b);
        const __m128i v_b1 = _mm_loadu_si128((const __m128i *)(b + b_stride));
        const __m256i v_a =
            _mm256_insertf128_si256(_mm256_castsi128_si256(v_a0), v_a1, 0x01);
        const __m256i v_b =
            _mm256_insertf128_si256(_mm256_castsi128_si256(v_b0), v_b1, 0x01);
        const __m256i v_al = _mm256_unpacklo_epi8(v_a, zero);
        const __m256i v_au = _mm256_unpackhi_epi8(v_a, zero);
        const __m256i v_bl = _mm256_unpacklo_epi8(v_b, zero);
        const __m256i v_bu = _mm256_unpackhi_epi8(v_b, zero);
        const __m256i v_asub = _mm256_sub_epi16(v_al, v_bl);
        const __m256i v_bsub = _mm256_sub_epi16(v_au, v_bu);
        const __m256i temp =
            _mm256_add_epi32(_mm256_madd_epi16(v_asub, v_asub),
                             _mm256_madd_epi16(v_bsub, v_bsub));
        sum = _mm256_add_epi32(sum, temp);
        a += a_stride << 1;
        b += b_stride << 1;
        y += 2;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 32:
      do {
        sse_w32_avx2(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 64:
      do {
        sse_w32_avx2(&sum, a, b);
        sse_w32_avx2(&sum, a + 32, b + 32);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    default:
      if ((width & 0x07) == 0) {
        do {
          int i = 0;
          do {
            sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum);
            i += 8;
          } while (i < width);
          a += a_stride << 1;
          b += b_stride << 1;
          y += 2;
        } while (y < height);
      } else {
        do {
          int i = 0;
          do {
            const uint8_t *a2;
            const uint8_t *b2;
            sse_w8x2_avx2(a + i, a_stride, b + i, b_stride, &sum);
            a2 = a + i + (a_stride << 1);
            b2 = b + i + (b_stride << 1);
            sse_w8x2_avx2(a2, a_stride, b2, b_stride, &sum);
            i += 8;
          } while (i + 4 < width);
          sse_w4x4_avx2(a + i, a_stride, b + i, b_stride, &sum);
          a += a_stride << 2;
          b += b_stride << 2;
          y += 4;
        } while (y < height);
      }
      sse = summary_all_avx2(&sum);
      break;
  }

  return sse;
}

#if CONFIG_VP9_HIGHBITDEPTH
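// Accumulates the squared error of 16 consecutive high-bit-depth pixels from
// a and b into the eight 32-bit lanes of *sum.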
static INLINE void highbd_sse_w16_avx2(__m256i *sum, const uint16_t *a,
                                       const uint16_t *b) {
  const __m256i v_a_w = _mm256_loadu_si256((const __m256i *)a);
  const __m256i v_b_w = _mm256_loadu_si256((const __m256i *)b);
  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
}

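// Accumulates the squared error of a 4x4 block of high-bit-depth pixels into
// *sum.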
static INLINE void highbd_sse_w4x4_avx2(__m256i *sum, const uint16_t *a,
                                        int a_stride, const uint16_t *b,
                                        int b_stride) {
  const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
  const __m128i v_a1 = _mm_loadl_epi64((const __m128i *)(a + a_stride));
  const __m128i v_a2 = _mm_loadl_epi64((const __m128i *)(a + a_stride * 2));
  const __m128i v_a3 = _mm_loadl_epi64((const __m128i *)(a + a_stride * 3));
  const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
  const __m128i v_b1 = _mm_loadl_epi64((const __m128i *)(b + b_stride));
  const __m128i v_b2 = _mm_loadl_epi64((const __m128i *)(b + b_stride * 2));
  const __m128i v_b3 = _mm_loadl_epi64((const __m128i *)(b + b_stride * 3));
  const __m128i v_a_hi = _mm_unpacklo_epi64(v_a0, v_a1);
  const __m128i v_a_lo = _mm_unpacklo_epi64(v_a2, v_a3);
  const __m256i v_a_w =
      _mm256_insertf128_si256(_mm256_castsi128_si256(v_a_lo), v_a_hi, 1);
  const __m128i v_b_hi = _mm_unpacklo_epi64(v_b0, v_b1);
  const __m128i v_b_lo = _mm_unpacklo_epi64(v_b2, v_b3);
  const __m256i v_b_w =
      _mm256_insertf128_si256(_mm256_castsi128_si256(v_b_lo), v_b_hi, 1);
  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
}

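// Accumulates the squared error of an 8x2 block of high-bit-depth pixels into
// *sum.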
static INLINE void highbd_sse_w8x2_avx2(__m256i *sum, const uint16_t *a,
                                        int a_stride, const uint16_t *b,
                                        int b_stride) {
  const __m128i v_a_hi = _mm_loadu_si128((const __m128i *)(a + a_stride));
  const __m128i v_a_lo = _mm_loadu_si128((const __m128i *)a);
  const __m256i v_a_w =
      _mm256_insertf128_si256(_mm256_castsi128_si256(v_a_lo), v_a_hi, 1);
  const __m128i v_b_hi = _mm_loadu_si128((const __m128i *)(b + b_stride));
  const __m128i v_b_lo = _mm_loadu_si128((const __m128i *)b);
  const __m256i v_b_w =
      _mm256_insertf128_si256(_mm256_castsi128_si256(v_b_lo), v_b_hi, 1);
  const __m256i v_d_w = _mm256_sub_epi16(v_a_w, v_b_w);
  *sum = _mm256_add_epi32(*sum, _mm256_madd_epi16(v_d_w, v_d_w));
}

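// Computes the sum of squared errors between two high-bit-depth buffers
// (passed as CONVERT_TO_SHORTPTR pointers). Because high-bit-depth squared
// differences are large, the wider paths periodically flush the 32-bit
// partial sums into a 64-bit accumulator (every 64, 32, 8 or 4 rows,
// depending on the path) before they can overflow.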
int64_t vpx_highbd_sse_avx2(const uint8_t *a8, int a_stride, const uint8_t *b8,
                            int b_stride, int width, int height) {
  int32_t y = 0;
  int64_t sse = 0;
  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  __m256i sum = _mm256_setzero_si256();
  switch (width) {
    case 4:
      do {
        highbd_sse_w4x4_avx2(&sum, a, a_stride, b, b_stride);
        a += a_stride << 2;
        b += b_stride << 2;
        y += 4;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 8:
      do {
        highbd_sse_w8x2_avx2(&sum, a, a_stride, b, b_stride);
        a += a_stride << 1;
        b += b_stride << 1;
        y += 2;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 16:
      do {
        highbd_sse_w16_avx2(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      sse = summary_all_avx2(&sum);
      break;
    case 32:
      do {
        int l = 0;
        __m256i sum32 = _mm256_setzero_si256();
        do {
          highbd_sse_w16_avx2(&sum32, a, b);
          highbd_sse_w16_avx2(&sum32, a + 16, b + 16);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 64 && l < (height - y));
        summary_32_avx2(&sum32, &sum);
        y += 64;
      } while (y < height);
      sse = summary_4x64_avx2(sum);
      break;
    case 64:
      do {
        int l = 0;
        __m256i sum32 = _mm256_setzero_si256();
        do {
          highbd_sse_w16_avx2(&sum32, a, b);
          highbd_sse_w16_avx2(&sum32, a + 16 * 1, b + 16 * 1);
          highbd_sse_w16_avx2(&sum32, a + 16 * 2, b + 16 * 2);
          highbd_sse_w16_avx2(&sum32, a + 16 * 3, b + 16 * 3);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 32 && l < (height - y));
        summary_32_avx2(&sum32, &sum);
        y += 32;
      } while (y < height);
      sse = summary_4x64_avx2(sum);
      break;
    default:
      if (width & 0x7) {
        do {
          int i = 0;
          __m256i sum32 = _mm256_setzero_si256();
          do {
            const uint16_t *a2;
            const uint16_t *b2;
            highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride);
            a2 = a + i + (a_stride << 1);
            b2 = b + i + (b_stride << 1);
            highbd_sse_w8x2_avx2(&sum32, a2, a_stride, b2, b_stride);
            i += 8;
          } while (i + 4 < width);
          highbd_sse_w4x4_avx2(&sum32, a + i, a_stride, b + i, b_stride);
          summary_32_avx2(&sum32, &sum);
          a += a_stride << 2;
          b += b_stride << 2;
          y += 4;
        } while (y < height);
      } else {
        do {
          int l = 0;
          __m256i sum32 = _mm256_setzero_si256();
          do {
            int i = 0;
            do {
              highbd_sse_w8x2_avx2(&sum32, a + i, a_stride, b + i, b_stride);
              i += 8;
            } while (i < width);
            a += a_stride << 1;
            b += b_stride << 1;
            l += 2;
          } while (l < 8 && l < (height - y));
          summary_32_avx2(&sum32, &sum);
          y += 8;
        } while (y < height);
      }
      sse = summary_4x64_avx2(sum);
      break;
  }
  return sse;
}
#endif  // CONFIG_VP9_HIGHBITDEPTH