/*
 *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <assert.h>
#include <smmintrin.h>

#include "./vpx_config.h"
#include "./vpx_dsp_rtcd.h"

#include "vpx_ports/mem.h"
#include "vpx/vpx_integer.h"
#include "vpx_dsp/x86/mem_sse2.h"
summary_all_sse4(const __m128i * sum_all)21 static INLINE int64_t summary_all_sse4(const __m128i *sum_all) {
22   int64_t sum;
23   const __m128i sum0 = _mm_cvtepu32_epi64(*sum_all);
24   const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum_all, 8));
25   const __m128i sum_2x64 = _mm_add_epi64(sum0, sum1);
26   const __m128i sum_1x64 = _mm_add_epi64(sum_2x64, _mm_srli_si128(sum_2x64, 8));
27   _mm_storel_epi64((__m128i *)&sum, sum_1x64);
28   return sum;
29 }

#if CONFIG_VP9_HIGHBITDEPTH
summary_32_sse4(const __m128i * sum32,__m128i * sum64)32 static INLINE void summary_32_sse4(const __m128i *sum32, __m128i *sum64) {
33   const __m128i sum0 = _mm_cvtepu32_epi64(*sum32);
34   const __m128i sum1 = _mm_cvtepu32_epi64(_mm_srli_si128(*sum32, 8));
35   *sum64 = _mm_add_epi64(sum0, *sum64);
36   *sum64 = _mm_add_epi64(sum1, *sum64);
37 }
#endif

sse_w16_sse4_1(__m128i * sum,const uint8_t * a,const uint8_t * b)40 static INLINE void sse_w16_sse4_1(__m128i *sum, const uint8_t *a,
41                                   const uint8_t *b) {
42   const __m128i v_a0 = _mm_loadu_si128((const __m128i *)a);
43   const __m128i v_b0 = _mm_loadu_si128((const __m128i *)b);
44   const __m128i v_a00_w = _mm_cvtepu8_epi16(v_a0);
45   const __m128i v_a01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_a0, 8));
46   const __m128i v_b00_w = _mm_cvtepu8_epi16(v_b0);
47   const __m128i v_b01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_b0, 8));
48   const __m128i v_d00_w = _mm_sub_epi16(v_a00_w, v_b00_w);
49   const __m128i v_d01_w = _mm_sub_epi16(v_a01_w, v_b01_w);
50   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d00_w, v_d00_w));
51   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d01_w, v_d01_w));
52 }
sse4x2_sse4_1(const uint8_t * a,int a_stride,const uint8_t * b,int b_stride,__m128i * sum)54 static INLINE void sse4x2_sse4_1(const uint8_t *a, int a_stride,
55                                  const uint8_t *b, int b_stride, __m128i *sum) {
56   const __m128i v_a0 = load_unaligned_u32(a);
57   const __m128i v_a1 = load_unaligned_u32(a + a_stride);
58   const __m128i v_b0 = load_unaligned_u32(b);
59   const __m128i v_b1 = load_unaligned_u32(b + b_stride);
60   const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1));
61   const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1));
62   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
63   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
64 }
sse8_sse4_1(const uint8_t * a,const uint8_t * b,__m128i * sum)66 static INLINE void sse8_sse4_1(const uint8_t *a, const uint8_t *b,
67                                __m128i *sum) {
68   const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
69   const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
70   const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0);
71   const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0);
72   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
73   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
74 }
// Compute the sum of squared differences between two 8-bit pixel blocks of
// width x height, using SSE4.1.  Specialized paths exist for the common
// block widths; the default path handles other multiples of 4.
int64_t vpx_sse_sse4_1(const uint8_t *a, int a_stride, const uint8_t *b,
                       int b_stride, int width, int height) {
  int row = 0;
  int64_t sse = 0;
  __m128i acc = _mm_setzero_si128();
  switch (width) {
    case 4:
      // Two rows per iteration.
      do {
        sse4x2_sse4_1(a, a_stride, b, b_stride, &acc);
        a += 2 * a_stride;
        b += 2 * b_stride;
        row += 2;
      } while (row < height);
      sse = summary_all_sse4(&acc);
      break;
    case 8:
      do {
        sse8_sse4_1(a, b, &acc);
        a += a_stride;
        b += b_stride;
        row += 1;
      } while (row < height);
      sse = summary_all_sse4(&acc);
      break;
    case 16:
      do {
        sse_w16_sse4_1(&acc, a, b);
        a += a_stride;
        b += b_stride;
        row += 1;
      } while (row < height);
      sse = summary_all_sse4(&acc);
      break;
    case 32:
      do {
        sse_w16_sse4_1(&acc, a, b);
        sse_w16_sse4_1(&acc, a + 16, b + 16);
        a += a_stride;
        b += b_stride;
        row += 1;
      } while (row < height);
      sse = summary_all_sse4(&acc);
      break;
    case 64:
      do {
        int col;
        for (col = 0; col < 64; col += 16) {
          sse_w16_sse4_1(&acc, a + col, b + col);
        }
        a += a_stride;
        b += b_stride;
        row += 1;
      } while (row < height);
      sse = summary_all_sse4(&acc);
      break;
    default:
      if (width & 0x07) {
        // Width is 4 mod 8: do 8-pixel chunks over two rows, then finish the
        // remaining 4 columns with a 4x2 step.
        do {
          int col = 0;
          do {
            sse8_sse4_1(a + col, b + col, &acc);
            sse8_sse4_1(a + col + a_stride, b + col + b_stride, &acc);
            col += 8;
          } while (col + 4 < width);
          sse4x2_sse4_1(a + col, a_stride, b + col, b_stride, &acc);
          a += 2 * a_stride;
          b += 2 * b_stride;
          row += 2;
        } while (row < height);
      } else {
        // Width is a multiple of 8: plain 8-pixel chunks, one row at a time.
        do {
          int col = 0;
          do {
            sse8_sse4_1(a + col, b + col, &acc);
            col += 8;
          } while (col < width);
          a += a_stride;
          b += b_stride;
          row += 1;
        } while (row < height);
      }
      sse = summary_all_sse4(&acc);
      break;
  }

  return sse;
}

#if CONFIG_VP9_HIGHBITDEPTH
highbd_sse_w4x2_sse4_1(__m128i * sum,const uint16_t * a,int a_stride,const uint16_t * b,int b_stride)165 static INLINE void highbd_sse_w4x2_sse4_1(__m128i *sum, const uint16_t *a,
166                                           int a_stride, const uint16_t *b,
167                                           int b_stride) {
168   const __m128i v_a0 = _mm_loadl_epi64((const __m128i *)a);
169   const __m128i v_a1 = _mm_loadl_epi64((const __m128i *)(a + a_stride));
170   const __m128i v_b0 = _mm_loadl_epi64((const __m128i *)b);
171   const __m128i v_b1 = _mm_loadl_epi64((const __m128i *)(b + b_stride));
172   const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1);
173   const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1);
174   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
175   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
176 }
highbd_sse_w8_sse4_1(__m128i * sum,const uint16_t * a,const uint16_t * b)178 static INLINE void highbd_sse_w8_sse4_1(__m128i *sum, const uint16_t *a,
179                                         const uint16_t *b) {
180   const __m128i v_a_w = _mm_loadu_si128((const __m128i *)a);
181   const __m128i v_b_w = _mm_loadu_si128((const __m128i *)b);
182   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
183   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
184 }
// Compute the sum of squared differences between two high-bitdepth pixel
// blocks of width x height, using SSE4.1.
//
// High-bitdepth buffers arrive as uint8_t* and are converted back to their
// real uint16_t* sample pointers via CONVERT_TO_SHORTPTR.
//
// Because differences of 16-bit samples squared can be large, the wide-width
// paths periodically widen the 32-bit lane accumulator (sum32) into a 64-bit
// accumulator (sum) via summary_32_sse4.  The row-batch limits (l < 64 for
// width 16, l < 32 for width 32, l < 16 for width 64, l < 8 in the default
// path) bound how much can accumulate in a 32-bit lane between widenings.
// NOTE(review): these bounds appear sized for samples of at most 12 bits
// (diff^2 < 2^24 per product) -- confirm against the supported bit depths.
int64_t vpx_highbd_sse_sse4_1(const uint8_t *a8, int a_stride,
                              const uint8_t *b8, int b_stride, int width,
                              int height) {
  int32_t y = 0;
  int64_t sse = 0;
  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  __m128i sum = _mm_setzero_si128();
  switch (width) {
    case 4:
      // Narrow blocks accumulate few products, so 32-bit lanes suffice for
      // the whole block; fold once at the end.
      do {
        highbd_sse_w4x2_sse4_1(&sum, a, a_stride, b, b_stride);
        a += a_stride << 1;
        b += b_stride << 1;
        y += 2;
      } while (y < height);
      sse = summary_all_sse4(&sum);
      break;
    case 8:
      do {
        highbd_sse_w8_sse4_1(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      sse = summary_all_sse4(&sum);
      break;
    case 16:
      do {
        // Accumulate at most 64 rows in 32-bit lanes before widening.
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8, b + 8);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 64 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 64;
      } while (y < height);
      // sum now holds two 64-bit partial sums; fold and store the low lane.
      _mm_storel_epi64((__m128i *)&sse,
                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    case 32:
      do {
        // Twice the per-row work of width 16, so widen every 32 rows.
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 32 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 32;
      } while (y < height);
      _mm_storel_epi64((__m128i *)&sse,
                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    case 64:
      do {
        // Four times the per-row work of width 16: widen every 16 rows.
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 16 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 16;
      } while (y < height);
      _mm_storel_epi64((__m128i *)&sse,
                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    default:
      if (width & 0x7) {
        // Width is 4 mod 8: 8-pixel chunks over two rows, then a 4x2 tail;
        // widen after every pair of rows.
        do {
          __m128i sum32 = _mm_setzero_si128();
          int i = 0;
          do {
            highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
            highbd_sse_w8_sse4_1(&sum32, a + i + a_stride, b + i + b_stride);
            i += 8;
          } while (i + 4 < width);
          highbd_sse_w4x2_sse4_1(&sum32, a + i, a_stride, b + i, b_stride);
          a += (a_stride << 1);
          b += (b_stride << 1);
          y += 2;
          summary_32_sse4(&sum32, &sum);
        } while (y < height);
      } else {
        // Width is a multiple of 8: full rows of 8-pixel chunks, widening
        // every 8 rows.
        do {
          int l = 0;
          __m128i sum32 = _mm_setzero_si128();
          do {
            int i = 0;
            do {
              highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
              i += 8;
            } while (i < width);
            a += a_stride;
            b += b_stride;
            l += 1;
          } while (l < 8 && l < (height - y));
          summary_32_sse4(&sum32, &sum);
          y += 8;
        } while (y < height);
      }
      _mm_storel_epi64((__m128i *)&sse,
                       _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
  }
  return sse;
}
#endif  // CONFIG_VP9_HIGHBITDEPTH