/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdlib.h>
#include <string.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"

// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h);

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

// For width a multiple of 16
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_);

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

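// Each MASK_SUBPIX_VAR*_SSSE3 macro below expands to one variance kernel:
// the source block is bilinearly filtered into 'temp', 'temp' is blended
// with 'second_pred' under the AOM_BLEND_A64 mask 'msk' (operand order
// swapped when 'invert_mask' is set), and the masked variance of that blend
// against 'ref' is returned as sse - sum^2 / (W * H).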
#define MASK_SUBPIX_VAR_SSSE3(W, H)                                   \
  unsigned int aom_masked_sub_pixel_variance##W##x##H##_ssse3(        \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,   \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \
      const uint8_t *msk, int msk_stride, int invert_mask,            \
      unsigned int *sse) {                                            \
    int sum;                                                          \
    uint8_t temp[(H + 1) * W];                                        \
                                                                      \
    bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);   \
                                                                      \
    if (!invert_mask)                                                 \
      masked_variance(ref, ref_stride, temp, W, second_pred, W, msk,  \
                      msk_stride, W, H, sse, &sum);                   \
    else                                                              \
      masked_variance(ref, ref_stride, second_pred, W, temp, W, msk,  \
                      msk_stride, W, H, sse, &sum);                   \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));         \
  }

#define MASK_SUBPIX_VAR8XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance8x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 8];                                                \
                                                                              \
    bilinear_filter8xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance8xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance8xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (8 * H));                 \
  }

#define MASK_SUBPIX_VAR4XH_SSSE3(H)                                           \
  unsigned int aom_masked_sub_pixel_variance4x##H##_ssse3(                    \
      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred,         \
      const uint8_t *msk, int msk_stride, int invert_mask,                    \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    uint8_t temp[(H + 1) * 4];                                                \
                                                                              \
    bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);           \
                                                                              \
    if (!invert_mask)                                                         \
      masked_variance4xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum);                                       \
    else                                                                      \
      masked_variance4xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum);                                       \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));                 \
  }

MASK_SUBPIX_VAR_SSSE3(128, 128)
MASK_SUBPIX_VAR_SSSE3(128, 64)
MASK_SUBPIX_VAR_SSSE3(64, 128)
MASK_SUBPIX_VAR_SSSE3(64, 64)
MASK_SUBPIX_VAR_SSSE3(64, 32)
MASK_SUBPIX_VAR_SSSE3(32, 64)
MASK_SUBPIX_VAR_SSSE3(32, 32)
MASK_SUBPIX_VAR_SSSE3(32, 16)
MASK_SUBPIX_VAR_SSSE3(16, 32)
MASK_SUBPIX_VAR_SSSE3(16, 16)
MASK_SUBPIX_VAR_SSSE3(16, 8)
MASK_SUBPIX_VAR8XH_SSSE3(16)
MASK_SUBPIX_VAR8XH_SSSE3(8)
MASK_SUBPIX_VAR8XH_SSSE3(4)
MASK_SUBPIX_VAR4XH_SSSE3(8)
MASK_SUBPIX_VAR4XH_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
MASK_SUBPIX_VAR4XH_SSSE3(16)
MASK_SUBPIX_VAR_SSSE3(16, 4)
MASK_SUBPIX_VAR8XH_SSSE3(32)
MASK_SUBPIX_VAR_SSSE3(32, 8)
MASK_SUBPIX_VAR_SSSE3(64, 16)
MASK_SUBPIX_VAR_SSSE3(16, 64)
#endif  // !CONFIG_REALTIME_ONLY

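// Apply a 2-tap filter to 16 pixels at once: interleave the two source
// vectors byte-wise, multiply-add the (pixel, filter tap) pairs with
// _mm_maddubs_epi16, round by FILTER_BITS, and repack the results to bytes.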
static inline __m128i filter_block(const __m128i a, const __m128i b,
                                   const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi8(a, b);
  v0 = _mm_maddubs_epi16(v0, filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi8(a, b);
  v1 = _mm_maddubs_epi16(v1, filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        __m128i z = _mm_alignr_epi8(y, x, 1);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu8(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        const __m128i z = _mm_alignr_epi8(y, x, 1);
        const __m128i res = filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu8(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

static inline __m128i filter_block_2rows(const __m128i *a0, const __m128i *b0,
                                         const __m128i *a1, const __m128i *b1,
                                         const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi8(*a0, *b0);
  v0 = _mm_maddubs_epi16(v0, *filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi8(*a1, *b1);
  v1 = _mm_maddubs_epi16(v1, *filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 8;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 8;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i res = filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x0 = _mm_loadu_si128((__m128i *)src);
    const __m128i z0 = _mm_srli_si128(x0, 1);

    __m128i v0 = _mm_unpacklo_epi8(x0, z0);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu8(x, y));
      dst += 8;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[16]);
      const __m128i res = filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = xx_loadl_32((__m128i *)src);
      xx_storel_32(b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      xx_storel_32(b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i x0 = _mm_loadl_epi64((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadl_epi64((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i x2 = _mm_loadl_epi64((__m128i *)&src[src_stride * 2]);
      const __m128i z2 = _mm_srli_si128(x2, 1);
      const __m128i x3 = _mm_loadl_epi64((__m128i *)&src[src_stride * 3]);
      const __m128i z3 = _mm_srli_si128(x3, 1);

      const __m128i a0 = _mm_unpacklo_epi32(x0, x1);
      const __m128i b0 = _mm_unpacklo_epi32(z0, z1);
      const __m128i a1 = _mm_unpacklo_epi32(x2, x3);
      const __m128i b1 = _mm_unpacklo_epi32(z2, z3);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 4;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x = _mm_loadl_epi64((__m128i *)src);
    const __m128i z = _mm_srli_si128(x, 1);

    __m128i v0 = _mm_unpacklo_epi8(x, z);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    xx_storel_32(b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = xx_loadl_32((__m128i *)dst);
      __m128i y = xx_loadl_32((__m128i *)&dst[4]);
      xx_storel_32(dst, _mm_avg_epu8(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i a = xx_loadl_32((__m128i *)dst);
      const __m128i b = xx_loadl_32((__m128i *)&dst[4]);
      const __m128i c = xx_loadl_32((__m128i *)&dst[8]);
      const __m128i d = xx_loadl_32((__m128i *)&dst[12]);
      const __m128i e = xx_loadl_32((__m128i *)&dst[16]);

      const __m128i a0 = _mm_unpacklo_epi32(a, b);
      const __m128i b0 = _mm_unpacklo_epi32(b, c);
      const __m128i a1 = _mm_unpacklo_epi32(c, d);
      const __m128i b1 = _mm_unpacklo_epi32(d, e);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

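// Blend 16 pixels of 'a' and 'b' under mask 'm' and accumulate the error
// against 'src': pred = ROUND_POWER_OF_TWO(m * a + (mask_max - m) * b,
// AOM_BLEND_A64_ROUND_BITS), with mask_max = 1 << AOM_BLEND_A64_ROUND_BITS
// (i.e. 64). The signed differences are added into *sum and their squares
// into *sum_sq.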
accumulate_block(const __m128i * src,const __m128i * a,const __m128i * b,const __m128i * m,__m128i * sum,__m128i * sum_sq)398 static inline void accumulate_block(const __m128i *src, const __m128i *a,
399                                     const __m128i *b, const __m128i *m,
400                                     __m128i *sum, __m128i *sum_sq) {
401   const __m128i zero = _mm_setzero_si128();
402   const __m128i one = _mm_set1_epi16(1);
403   const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
404   const __m128i m_inv = _mm_sub_epi8(mask_max, *m);
405 
406   // Calculate 16 predicted pixels.
407   // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
408   // is 64 * 255, so we have plenty of space to add rounding constants.
409   const __m128i data_l = _mm_unpacklo_epi8(*a, *b);
410   const __m128i mask_l = _mm_unpacklo_epi8(*m, m_inv);
411   __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
412   pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);
413 
414   const __m128i data_r = _mm_unpackhi_epi8(*a, *b);
415   const __m128i mask_r = _mm_unpackhi_epi8(*m, m_inv);
416   __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
417   pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);
418 
419   const __m128i src_l = _mm_unpacklo_epi8(*src, zero);
420   const __m128i src_r = _mm_unpackhi_epi8(*src, zero);
421   const __m128i diff_l = _mm_sub_epi16(pred_l, src_l);
422   const __m128i diff_r = _mm_sub_epi16(pred_r, src_r);
423 
424   // Update partial sums and partial sums of squares
425   *sum =
426       _mm_add_epi32(*sum, _mm_madd_epi16(_mm_add_epi16(diff_l, diff_r), one));
427   *sum_sq =
428       _mm_add_epi32(*sum_sq, _mm_add_epi32(_mm_madd_epi16(diff_l, diff_l),
429                                            _mm_madd_epi16(diff_r, diff_r)));
430 }
431 
masked_variance(const uint8_t * src_ptr,int src_stride,const uint8_t * a_ptr,int a_stride,const uint8_t * b_ptr,int b_stride,const uint8_t * m_ptr,int m_stride,int width,int height,unsigned int * sse,int * sum_)432 static void masked_variance(const uint8_t *src_ptr, int src_stride,
433                             const uint8_t *a_ptr, int a_stride,
434                             const uint8_t *b_ptr, int b_stride,
435                             const uint8_t *m_ptr, int m_stride, int width,
436                             int height, unsigned int *sse, int *sum_) {
437   int x, y;
438   __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
439 
440   for (y = 0; y < height; y++) {
441     for (x = 0; x < width; x += 16) {
442       const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
443       const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
444       const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
445       const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
446       accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);
447     }
448 
449     src_ptr += src_stride;
450     a_ptr += a_stride;
451     b_ptr += b_stride;
452     m_ptr += m_stride;
453   }
454   // Reduce down to a single sum and sum of squares
455   sum = _mm_hadd_epi32(sum, sum_sq);
456   sum = _mm_hadd_epi32(sum, sum);
457   *sum_ = _mm_cvtsi128_si32(sum);
458   *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
459 }
460 
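// The 8xh and 4xh variants below pack two (or four) rows into one 128-bit
// register per iteration. The filtered block 'a_ptr' and the second
// prediction 'b_ptr' are stored contiguously (stride 8 or 4), so they advance
// by a full 16 bytes per loop, while 'src_ptr' and 'm_ptr' keep their own
// strides.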
static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 2;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 4) {
    // Load four rows at a time
    __m128i src = _mm_setr_epi32(*(int *)src_ptr, *(int *)&src_ptr[src_stride],
                                 *(int *)&src_ptr[src_stride * 2],
                                 *(int *)&src_ptr[src_stride * 3]);
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_setr_epi32(*(int *)m_ptr, *(int *)&m_ptr[m_stride],
                                     *(int *)&m_ptr[m_stride * 2],
                                     *(int *)&m_ptr[m_stride * 3]);
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 4;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 4;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

#if CONFIG_AV1_HIGHBITDEPTH
// For width a multiple of 8
static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h);

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h);

// For width a multiple of 8
static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_);

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_);

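// The high-bitdepth kernels mirror the 8-bit flow above but operate on
// uint16_t pixels. For 10- and 12-bit inputs the accumulated sse and sum are
// rescaled to 8-bit precision (ROUND_POWER_OF_TWO by 4/2 and 8/4
// respectively) before the variance is formed, and the result is clamped at
// zero.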
#define HIGHBD_MASK_SUBPIX_VAR_SSSE3(W, H)                                  \
  unsigned int aom_highbd_8_masked_sub_pixel_variance##W##x##H##_ssse3(     \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)sse64;                                                 \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H));               \
  }                                                                         \
  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_ssse3(    \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4);                          \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }                                                                         \
  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_ssse3(    \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64;                                                         \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * W];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    else                                                                    \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8);                          \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }

#define HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(H)                                  \
  unsigned int aom_highbd_8_masked_sub_pixel_variance4x##H##_ssse3(         \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)sse_;                                                  \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H));               \
  }                                                                         \
  unsigned int aom_highbd_10_masked_sub_pixel_variance4x##H##_ssse3(        \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 4);                           \
    sum = ROUND_POWER_OF_TWO(sum, 2);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }                                                                         \
  unsigned int aom_highbd_12_masked_sub_pixel_variance4x##H##_ssse3(        \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset,        \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8,     \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_;                                                               \
    int sum;                                                                \
    int64_t var;                                                            \
    uint16_t temp[(H + 1) * 4];                                             \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);                        \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                        \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8);        \
                                                                            \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H);  \
                                                                            \
    if (!invert_mask)                                                       \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    else                                                                    \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk,    \
                                msk_stride, H, &sse_, &sum);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 8);                           \
    sum = ROUND_POWER_OF_TWO(sum, 4);                                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H));               \
    return (var >= 0) ? (uint32_t)var : 0;                                  \
  }

HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(8)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 4)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

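// 16-bit analogue of filter_block: pairs are interleaved as 16-bit values,
// multiply-added with _mm_madd_epi16 into 32-bit lanes, rounded by
// FILTER_BITS, and packed back down to 16 bits.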
static inline __m128i highbd_filter_block(const __m128i a, const __m128i b,
                                          const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi16(a, b);
  v0 = _mm_madd_epi16(v0, filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi16(a, b);
  v1 = _mm_madd_epi16(v1, filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        __m128i z = _mm_alignr_epi8(y, x, 2);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu16(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        const __m128i z = _mm_alignr_epi8(y, x, 2);
        const __m128i res = highbd_filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu16(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = highbd_filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

static inline __m128i highbd_filter_block_2rows(const __m128i *a0,
                                                const __m128i *b0,
                                                const __m128i *a1,
                                                const __m128i *b1,
                                                const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi16(*a0, *b0);
  v0 = _mm_madd_epi16(v0, *filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi16(*a1, *b1);
  v1 = _mm_madd_epi16(v1, *filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 2);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu16(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 2);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 2);
      const __m128i res =
          highbd_filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 8;
    }
    // Process i = h separately
    __m128i x = _mm_loadu_si128((__m128i *)src);
    __m128i z = _mm_srli_si128(x, 2);

    __m128i v0 = _mm_unpacklo_epi16(x, z);
    v0 = _mm_madd_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu32(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packs_epi32(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu16(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i res =
          highbd_filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 8;
    }
  }
}

static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_) {
  int x, y;
  // Note on bit widths:
  // The maximum value of 'sum' is (2^12 - 1) * 128 * 128 =~ 2^26,
  // so this can be kept as four 32-bit values.
  // But the maximum value of 'sum_sq' is (2^12 - 1)^2 * 128 * 128 =~ 2^38,
  // so this must be stored as two 64-bit values.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m =
          _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)&m_ptr[x]), zero);
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      // Calculate 8 predicted pixels.
      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i src_l = _mm_unpacklo_epi16(src, zero);
      const __m128i src_r = _mm_unpackhi_epi16(src, zero);
      __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
      __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

      // Update partial sums and partial sums of squares
      sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
      // A trick: Now each entry of diff_l and diff_r is stored in a 32-bit
      // field, but the range of values is only [-(2^12 - 1), 2^12 - 1].
      // So we can re-pack into 16-bit fields and use _mm_madd_epi16
      // to calculate the squares and partially sum them.
      const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
      const __m128i prod = _mm_madd_epi16(tmp, tmp);
      // Then we want to sign-extend to 64 bits and accumulate
      const __m128i sign = _mm_srai_epi32(prod, 31);
      const __m128i tmp_0 = _mm_unpacklo_epi32(prod, sign);
      const __m128i tmp_1 = _mm_unpackhi_epi32(prod, sign);
      sum_sq = _mm_add_epi64(sum_sq, _mm_add_epi64(tmp_0, tmp_1));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, zero);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  sum_sq = _mm_add_epi64(sum_sq, _mm_srli_si128(sum_sq, 8));
  _mm_storel_epi64((__m128i *)sse, sum_sq);
}

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_) {
  int y;
  // Note: For this function, h <= 8 (or maybe 16 if we add 4:1 partitions).
  // So the maximum value of sum is (2^12 - 1) * 4 * 16 =~ 2^18
  // and the maximum value of sum_sq is (2^12 - 1)^2 * 4 * 16 =~ 2^30.
  // So we can safely pack sum_sq into 32-bit fields, which is slightly more
  // convenient.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        zero);
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i src_l = _mm_unpacklo_epi16(src, zero);
    const __m128i src_r = _mm_unpackhi_epi16(src, zero);
    __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
    __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

    // Update partial sums and partial sums of squares
    sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
    const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
    const __m128i prod = _mm_madd_epi16(tmp, tmp);
    sum_sq = _mm_add_epi32(sum_sq, prod);

    src_ptr += src_stride * 2;
    a_ptr += 8;
    b_ptr += 8;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

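// Masked compound predictor: each output pixel is the AOM_BLEND_A64 blend of
// 'ref' and 'pred', with 'invert_mask' selecting which operand the mask
// weights. The comp_mask_pred_8/16_ssse3 helpers (see
// masked_variance_intrin_ssse3.h) process 8 or 16 pixels per call.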
void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
                              int width, int height, const uint8_t *ref,
                              int ref_stride, const uint8_t *mask,
                              int mask_stride, int invert_mask) {
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  assert(height % 2 == 0);
  int i = 0;
  if (width == 8) {
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    do {
      comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
      comp_mask_pred_16_ssse3(src0 + stride0, src1 + stride1,
                              mask + mask_stride, comp_pred + width);
      comp_pred += (width << 1);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      i += 2;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        comp_mask_pred_16_ssse3(src0 + x, src1 + x, mask + x, comp_pred);
        comp_mask_pred_16_ssse3(src0 + x + 16, src1 + x + 16, mask + x + 16,
                                comp_pred + 16);
        comp_pred += 32;
      }
      src0 += (stride0);
      src1 += (stride1);
      mask += (mask_stride);
      i += 1;
    } while (i < height);
  }
}