/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_ports/mem.h"
#include "aom/aom_integer.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/x86/obmc_intrinsic_sse4.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/variance_impl_ssse3.h"

////////////////////////////////////////////////////////////////////////////////
// 8 bit
////////////////////////////////////////////////////////////////////////////////

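// Accumulates the OBMC prediction error for a w x h block, eight pixels per
// iteration. Register suffixes follow the usual convention in this file:
// _b for bytes, _w for 16-bit words, _d for 32-bit dwords. A rough scalar
// sketch of the per-pixel computation (assuming xx_roundn_epi32 behaves like
// a signed rounding shift; this shows the shape of the math, not a drop-in
// reference):
//
//   diff = round_shift(wsrc[i] - mask[i] * pre[i], 12);
//   sum += diff;
//   sse += diff * diff;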
static inline void obmc_variance_w8n(const uint8_t *pre, const int pre_stride,
                                     const int32_t *wsrc, const int32_t *mask,
                                     unsigned int *const sse, int *const sum,
                                     const int w, const int h) {
  const int pre_step = pre_stride - w;
  int n = 0;
  __m128i v_sum_d = _mm_setzero_si128();
  __m128i v_sse_d = _mm_setzero_si128();

  assert(w >= 8);
  assert(IS_POWER_OF_TWO(w));
  assert(IS_POWER_OF_TWO(h));

  do {
    const __m128i v_p1_b = xx_loadl_32(pre + n + 4);
    const __m128i v_m1_d = xx_load_128(mask + n + 4);
    const __m128i v_w1_d = xx_load_128(wsrc + n + 4);
    const __m128i v_p0_b = xx_loadl_32(pre + n);
    const __m128i v_m0_d = xx_load_128(mask + n);
    const __m128i v_w0_d = xx_load_128(wsrc + n);

    const __m128i v_p0_d = _mm_cvtepu8_epi32(v_p0_b);
    const __m128i v_p1_d = _mm_cvtepu8_epi32(v_p1_b);

    // Values in both pre and mask fit in 15 bits, and are packed at 32 bit
    // boundaries. We use pmaddwd, as it has lower latency on Haswell
    // than pmulld but produces the same result with these inputs.
    const __m128i v_pm0_d = _mm_madd_epi16(v_p0_d, v_m0_d);
    const __m128i v_pm1_d = _mm_madd_epi16(v_p1_d, v_m1_d);

    const __m128i v_diff0_d = _mm_sub_epi32(v_w0_d, v_pm0_d);
    const __m128i v_diff1_d = _mm_sub_epi32(v_w1_d, v_pm1_d);

    const __m128i v_rdiff0_d = xx_roundn_epi32(v_diff0_d, 12);
    const __m128i v_rdiff1_d = xx_roundn_epi32(v_diff1_d, 12);
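    // The rounded differences fit comfortably in 16 bits here, so packing
    // them with signed saturation and squaring via pmaddwd yields the sum of
    // two squared differences per 32-bit lane.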
    const __m128i v_rdiff01_w = _mm_packs_epi32(v_rdiff0_d, v_rdiff1_d);
    const __m128i v_sqrdiff_d = _mm_madd_epi16(v_rdiff01_w, v_rdiff01_w);

    v_sum_d = _mm_add_epi32(v_sum_d, v_rdiff0_d);
    v_sum_d = _mm_add_epi32(v_sum_d, v_rdiff1_d);
    v_sse_d = _mm_add_epi32(v_sse_d, v_sqrdiff_d);

    n += 8;

    if (n % w == 0) pre += pre_step;
  } while (n < w * h);

  *sum = xx_hsum_epi32_si32(v_sum_d);
  *sse = xx_hsum_epi32_si32(v_sse_d);
}

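// Wrappers: pick the 4-wide or 8-wide kernel, then form the block variance as
// sse - sum^2 / (W * H).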
#define OBMCVARWXH(W, H)                                               \
  unsigned int aom_obmc_variance##W##x##H##_sse4_1(                    \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,         \
      const int32_t *mask, unsigned int *sse) {                        \
    int sum;                                                           \
    if (W == 4) {                                                      \
      obmc_variance_w4(pre, pre_stride, wsrc, mask, sse, &sum, H);     \
    } else {                                                           \
      obmc_variance_w8n(pre, pre_stride, wsrc, mask, sse, &sum, W, H); \
    }                                                                  \
    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));      \
  }

OBMCVARWXH(128, 128)
OBMCVARWXH(128, 64)
OBMCVARWXH(64, 128)
OBMCVARWXH(64, 64)
OBMCVARWXH(64, 32)
OBMCVARWXH(32, 64)
OBMCVARWXH(32, 32)
OBMCVARWXH(32, 16)
OBMCVARWXH(16, 32)
OBMCVARWXH(16, 16)
OBMCVARWXH(16, 8)
OBMCVARWXH(8, 16)
OBMCVARWXH(8, 8)
OBMCVARWXH(8, 4)
OBMCVARWXH(4, 8)
OBMCVARWXH(4, 4)
OBMCVARWXH(4, 16)
OBMCVARWXH(16, 4)
OBMCVARWXH(8, 32)
OBMCVARWXH(32, 8)
OBMCVARWXH(16, 64)
OBMCVARWXH(64, 16)

#include "config/aom_dsp_rtcd.h"

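// Sub-pixel OBMC variance: run the two-tap bilinear filter selected by
// xoffset (horizontal pass into fdata3) and yoffset (vertical pass into
// temp2), then measure the full-pel OBMC variance of the filtered block.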
#define OBMC_SUBPIX_VAR(W, H)                                                \
  uint32_t aom_obmc_sub_pixel_variance##W##x##H##_sse4_1(                    \
      const uint8_t *pre, int pre_stride, int xoffset, int yoffset,          \
      const int32_t *wsrc, const int32_t *mask, unsigned int *sse) {         \
    uint16_t fdata3[(H + 1) * W];                                            \
    uint8_t temp2[H * W];                                                    \
                                                                             \
    aom_var_filter_block2d_bil_first_pass_ssse3(                             \
        pre, fdata3, pre_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]); \
    aom_var_filter_block2d_bil_second_pass_ssse3(                            \
        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);            \
                                                                             \
    return aom_obmc_variance##W##x##H##_sse4_1(temp2, W, wsrc, mask, sse);   \
  }

OBMC_SUBPIX_VAR(128, 128)
OBMC_SUBPIX_VAR(128, 64)
OBMC_SUBPIX_VAR(64, 128)
OBMC_SUBPIX_VAR(64, 64)
OBMC_SUBPIX_VAR(64, 32)
OBMC_SUBPIX_VAR(32, 64)
OBMC_SUBPIX_VAR(32, 32)
OBMC_SUBPIX_VAR(32, 16)
OBMC_SUBPIX_VAR(16, 32)
OBMC_SUBPIX_VAR(16, 16)
OBMC_SUBPIX_VAR(16, 8)
OBMC_SUBPIX_VAR(8, 16)
OBMC_SUBPIX_VAR(8, 8)
OBMC_SUBPIX_VAR(8, 4)
OBMC_SUBPIX_VAR(4, 8)
OBMC_SUBPIX_VAR(4, 4)
OBMC_SUBPIX_VAR(4, 16)
OBMC_SUBPIX_VAR(16, 4)
OBMC_SUBPIX_VAR(8, 32)
OBMC_SUBPIX_VAR(32, 8)
OBMC_SUBPIX_VAR(16, 64)
OBMC_SUBPIX_VAR(64, 16)

////////////////////////////////////////////////////////////////////////////////
// High bit-depth
////////////////////////////////////////////////////////////////////////////////
#if CONFIG_AV1_HIGHBITDEPTH
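// High bit-depth frames are passed in as uint8_t pointers; CONVERT_TO_SHORTPTR
// recovers the underlying 16-bit sample pointer. The kernels below mirror the
// 8-bit path but report sum and sse through 64-bit outputs so the bit-depth
// wrappers can rescale them.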
static inline void hbd_obmc_variance_w4(
    const uint8_t *pre8, const int pre_stride, const int32_t *wsrc,
    const int32_t *mask, uint64_t *const sse, int64_t *const sum, const int h) {
  const uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
  const int pre_step = pre_stride - 4;
  int n = 0;
  __m128i v_sum_d = _mm_setzero_si128();
  __m128i v_sse_d = _mm_setzero_si128();

  assert(IS_POWER_OF_TWO(h));

  do {
    const __m128i v_p_w = xx_loadl_64(pre + n);
    const __m128i v_m_d = xx_load_128(mask + n);
    const __m128i v_w_d = xx_load_128(wsrc + n);

    const __m128i v_p_d = _mm_cvtepu16_epi32(v_p_w);

    // Values in both pre and mask fit in 15 bits, and are packed at 32 bit
    // boundaries. We use pmaddwd, as it has lower latency on Haswell
    // than pmulld but produces the same result with these inputs.
    const __m128i v_pm_d = _mm_madd_epi16(v_p_d, v_m_d);

    const __m128i v_diff_d = _mm_sub_epi32(v_w_d, v_pm_d);
    const __m128i v_rdiff_d = xx_roundn_epi32(v_diff_d, 12);
    const __m128i v_sqrdiff_d = _mm_mullo_epi32(v_rdiff_d, v_rdiff_d);

    v_sum_d = _mm_add_epi32(v_sum_d, v_rdiff_d);
    v_sse_d = _mm_add_epi32(v_sse_d, v_sqrdiff_d);

    n += 4;

    if (n % 4 == 0) pre += pre_step;
  } while (n < 4 * h);

  *sum = xx_hsum_epi32_si32(v_sum_d);
  *sse = xx_hsum_epi32_si32(v_sse_d);
}

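// 8-wide high bit-depth kernel; same structure as obmc_variance_w8n above, but
// it loads 16-bit samples and adds its horizontal sums into the caller's
// 64-bit accumulators so it can be called repeatedly on slices of a block.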
static inline void hbd_obmc_variance_w8n(
    const uint8_t *pre8, const int pre_stride, const int32_t *wsrc,
    const int32_t *mask, uint64_t *const sse, int64_t *const sum, const int w,
    const int h) {
  const uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
  const int pre_step = pre_stride - w;
  int n = 0;
  __m128i v_sum_d = _mm_setzero_si128();
  __m128i v_sse_d = _mm_setzero_si128();

  assert(w >= 8);
  assert(IS_POWER_OF_TWO(w));
  assert(IS_POWER_OF_TWO(h));

  do {
    const __m128i v_p1_w = xx_loadl_64(pre + n + 4);
    const __m128i v_m1_d = xx_load_128(mask + n + 4);
    const __m128i v_w1_d = xx_load_128(wsrc + n + 4);
    const __m128i v_p0_w = xx_loadl_64(pre + n);
    const __m128i v_m0_d = xx_load_128(mask + n);
    const __m128i v_w0_d = xx_load_128(wsrc + n);

    const __m128i v_p0_d = _mm_cvtepu16_epi32(v_p0_w);
    const __m128i v_p1_d = _mm_cvtepu16_epi32(v_p1_w);

    // Values in both pre and mask fit in 15 bits, and are packed at 32 bit
    // boundaries. We use pmaddwd, as it has lower latency on Haswell
    // than pmulld but produces the same result with these inputs.
    const __m128i v_pm0_d = _mm_madd_epi16(v_p0_d, v_m0_d);
    const __m128i v_pm1_d = _mm_madd_epi16(v_p1_d, v_m1_d);

    const __m128i v_diff0_d = _mm_sub_epi32(v_w0_d, v_pm0_d);
    const __m128i v_diff1_d = _mm_sub_epi32(v_w1_d, v_pm1_d);

    const __m128i v_rdiff0_d = xx_roundn_epi32(v_diff0_d, 12);
    const __m128i v_rdiff1_d = xx_roundn_epi32(v_diff1_d, 12);
    const __m128i v_rdiff01_w = _mm_packs_epi32(v_rdiff0_d, v_rdiff1_d);
    const __m128i v_sqrdiff_d = _mm_madd_epi16(v_rdiff01_w, v_rdiff01_w);

    v_sum_d = _mm_add_epi32(v_sum_d, v_rdiff0_d);
    v_sum_d = _mm_add_epi32(v_sum_d, v_rdiff1_d);
    v_sse_d = _mm_add_epi32(v_sse_d, v_sqrdiff_d);

    n += 8;

    if (n % w == 0) pre += pre_step;
  } while (n < w * h);

  *sum += xx_hsum_epi32_si64(v_sum_d);
  *sse += xx_hsum_epi32_si64(v_sse_d);
}

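// Bit-depth wrappers: run the 64-bit-accumulating kernels, then scale the
// results back down to 8-bit precision where the bit depth requires it.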
static inline void highbd_8_obmc_variance(const uint8_t *pre8, int pre_stride,
                                          const int32_t *wsrc,
                                          const int32_t *mask, int w, int h,
                                          unsigned int *sse, int *sum) {
  int64_t sum64 = 0;
  uint64_t sse64 = 0;
  if (w == 4) {
    hbd_obmc_variance_w4(pre8, pre_stride, wsrc, mask, &sse64, &sum64, h);
  } else {
    hbd_obmc_variance_w8n(pre8, pre_stride, wsrc, mask, &sse64, &sum64, w, h);
  }
  *sum = (int)sum64;
  *sse = (unsigned int)sse64;
}

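// For 10-bit input the squared differences are large enough that a full
// 128x128 block could overflow the kernel's 32-bit lane accumulators, so that
// one block size is processed in 64-row slices before the final rounding
// (sum by 2 bits, sse by 4 bits) back to 8-bit precision.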
static inline void highbd_10_obmc_variance(const uint8_t *pre8, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int w, int h,
                                           unsigned int *sse, int *sum) {
  int64_t sum64 = 0;
  uint64_t sse64 = 0;
  if (w == 4) {
    hbd_obmc_variance_w4(pre8, pre_stride, wsrc, mask, &sse64, &sum64, h);
  } else if (w < 128 || h < 128) {
    hbd_obmc_variance_w8n(pre8, pre_stride, wsrc, mask, &sse64, &sum64, w, h);
  } else {
    assert(w == 128 && h == 128);

    do {
      hbd_obmc_variance_w8n(pre8, pre_stride, wsrc, mask, &sse64, &sum64, w,
                            64);
      pre8 += 64 * pre_stride;
      wsrc += 64 * w;
      mask += 64 * w;
      h -= 64;
    } while (h > 0);
  }
  *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);
  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4);
}

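// For 12-bit input the overflow margin is tighter still, so at most
// max_pel_allowed_per_ovf (512) pixels are fed to the 32-bit kernel per call
// before spilling into the 64-bit accumulators; sum and sse are then rounded
// by 4 and 8 bits respectively.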
static inline void highbd_12_obmc_variance(const uint8_t *pre8, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int w, int h,
                                           unsigned int *sse, int *sum) {
  int64_t sum64 = 0;
  uint64_t sse64 = 0;
  int max_pel_allowed_per_ovf = 512;
  if (w == 4) {
    hbd_obmc_variance_w4(pre8, pre_stride, wsrc, mask, &sse64, &sum64, h);
  } else if (w * h <= max_pel_allowed_per_ovf) {
    hbd_obmc_variance_w8n(pre8, pre_stride, wsrc, mask, &sse64, &sum64, w, h);
  } else {
    int h_per_ovf = max_pel_allowed_per_ovf / w;

    assert(max_pel_allowed_per_ovf % w == 0);
    do {
      hbd_obmc_variance_w8n(pre8, pre_stride, wsrc, mask, &sse64, &sum64, w,
                            h_per_ovf);
      pre8 += h_per_ovf * pre_stride;
      wsrc += h_per_ovf * w;
      mask += h_per_ovf * w;
      h -= h_per_ovf;
    } while (h > 0);
  }
  *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);
  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
}

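// As in the 8-bit macro, the variance is sse - sum^2 / (W * H). For 10- and
// 12-bit input the subtraction is done in 64 bits and clamped at zero, so a
// negative intermediate (possible once sum and sse have been rounded
// separately) is not returned as a large unsigned value.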
#define HBD_OBMCVARWXH(W, H)                                               \
  unsigned int aom_highbd_8_obmc_variance##W##x##H##_sse4_1(               \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
      const int32_t *mask, unsigned int *sse) {                            \
    int sum;                                                               \
    highbd_8_obmc_variance(pre, pre_stride, wsrc, mask, W, H, sse, &sum);  \
    return *sse - (unsigned int)(((int64_t)sum * sum) / (W * H));          \
  }                                                                        \
                                                                           \
  unsigned int aom_highbd_10_obmc_variance##W##x##H##_sse4_1(              \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
      const int32_t *mask, unsigned int *sse) {                            \
    int sum;                                                               \
    int64_t var;                                                           \
    highbd_10_obmc_variance(pre, pre_stride, wsrc, mask, W, H, sse, &sum); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));              \
    return (var >= 0) ? (uint32_t)var : 0;                                 \
  }                                                                        \
                                                                           \
  unsigned int aom_highbd_12_obmc_variance##W##x##H##_sse4_1(              \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
      const int32_t *mask, unsigned int *sse) {                            \
    int sum;                                                               \
    int64_t var;                                                           \
    highbd_12_obmc_variance(pre, pre_stride, wsrc, mask, W, H, sse, &sum); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));              \
    return (var >= 0) ? (uint32_t)var : 0;                                 \
  }

HBD_OBMCVARWXH(128, 128)
HBD_OBMCVARWXH(128, 64)
HBD_OBMCVARWXH(64, 128)
HBD_OBMCVARWXH(64, 64)
HBD_OBMCVARWXH(64, 32)
HBD_OBMCVARWXH(32, 64)
HBD_OBMCVARWXH(32, 32)
HBD_OBMCVARWXH(32, 16)
HBD_OBMCVARWXH(16, 32)
HBD_OBMCVARWXH(16, 16)
HBD_OBMCVARWXH(16, 8)
HBD_OBMCVARWXH(8, 16)
HBD_OBMCVARWXH(8, 8)
HBD_OBMCVARWXH(8, 4)
HBD_OBMCVARWXH(4, 8)
HBD_OBMCVARWXH(4, 4)
HBD_OBMCVARWXH(4, 16)
HBD_OBMCVARWXH(16, 4)
HBD_OBMCVARWXH(8, 32)
HBD_OBMCVARWXH(32, 8)
HBD_OBMCVARWXH(16, 64)
HBD_OBMCVARWXH(64, 16)
#endif  // CONFIG_AV1_HIGHBITDEPTH