/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdlib.h>
#include <string.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"

// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h);

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h);

// For width a multiple of 16
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_);

static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_);

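// Each MASK_SUBPIX_VAR*_SSSE3 macro below expands to a function that
// bilinearly filters the source block into 'temp', blends 'temp' with the
// second predictor under the (possibly inverted) mask, and returns
//   variance = sse - sum^2 / (W * H).
// For example, with sum = -64 on a 16x16 block, the correction term is
// 4096 / 256 = 16.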
#define MASK_SUBPIX_VAR_SSSE3(W, H) \
  unsigned int aom_masked_sub_pixel_variance##W##x##H##_ssse3( \
      const uint8_t *src, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \
      const uint8_t *msk, int msk_stride, int invert_mask, \
      unsigned int *sse) { \
    int sum; \
    uint8_t temp[(H + 1) * W]; \
    \
    bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H); \
    \
    if (!invert_mask) \
      masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                      msk_stride, W, H, sse, &sum); \
    else \
      masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                      msk_stride, W, H, sse, &sum); \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H)); \
  }

#define MASK_SUBPIX_VAR8XH_SSSE3(H) \
  unsigned int aom_masked_sub_pixel_variance8x##H##_ssse3( \
      const uint8_t *src, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \
      const uint8_t *msk, int msk_stride, int invert_mask, \
      unsigned int *sse) { \
    int sum; \
    uint8_t temp[(H + 1) * 8]; \
    \
    bilinear_filter8xh(src, src_stride, xoffset, yoffset, temp, H); \
    \
    if (!invert_mask) \
      masked_variance8xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum); \
    else \
      masked_variance8xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum); \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (8 * H)); \
  }

#define MASK_SUBPIX_VAR4XH_SSSE3(H) \
  unsigned int aom_masked_sub_pixel_variance4x##H##_ssse3( \
      const uint8_t *src, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \
      const uint8_t *msk, int msk_stride, int invert_mask, \
      unsigned int *sse) { \
    int sum; \
    uint8_t temp[(H + 1) * 4]; \
    \
    bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H); \
    \
    if (!invert_mask) \
      masked_variance4xh(ref, ref_stride, temp, second_pred, msk, msk_stride, \
                         H, sse, &sum); \
    else \
      masked_variance4xh(ref, ref_stride, second_pred, temp, msk, msk_stride, \
                         H, sse, &sum); \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H)); \
  }

MASK_SUBPIX_VAR_SSSE3(128, 128)
MASK_SUBPIX_VAR_SSSE3(128, 64)
MASK_SUBPIX_VAR_SSSE3(64, 128)
MASK_SUBPIX_VAR_SSSE3(64, 64)
MASK_SUBPIX_VAR_SSSE3(64, 32)
MASK_SUBPIX_VAR_SSSE3(32, 64)
MASK_SUBPIX_VAR_SSSE3(32, 32)
MASK_SUBPIX_VAR_SSSE3(32, 16)
MASK_SUBPIX_VAR_SSSE3(16, 32)
MASK_SUBPIX_VAR_SSSE3(16, 16)
MASK_SUBPIX_VAR_SSSE3(16, 8)
MASK_SUBPIX_VAR8XH_SSSE3(16)
MASK_SUBPIX_VAR8XH_SSSE3(8)
MASK_SUBPIX_VAR8XH_SSSE3(4)
MASK_SUBPIX_VAR4XH_SSSE3(8)
MASK_SUBPIX_VAR4XH_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
MASK_SUBPIX_VAR4XH_SSSE3(16)
MASK_SUBPIX_VAR_SSSE3(16, 4)
MASK_SUBPIX_VAR8XH_SSSE3(32)
MASK_SUBPIX_VAR_SSSE3(32, 8)
MASK_SUBPIX_VAR_SSSE3(64, 16)
MASK_SUBPIX_VAR_SSSE3(16, 64)
#endif  // !CONFIG_REALTIME_ONLY

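// Apply the 2-tap bilinear kernel (f0, f1) to 16 pixel pairs:
// out[i] = ROUND_POWER_OF_TWO(a[i] * f0 + b[i] * f1, FILTER_BITS).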
static inline __m128i filter_block(const __m128i a, const __m128i b,
                                   const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi8(a, b);
  v0 = _mm_maddubs_epi16(v0, filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi8(a, b);
  v1 = _mm_maddubs_epi16(v1, filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

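// Two-pass bilinear filter for widths that are a multiple of 16: the
// horizontal pass writes (h + 1) rows into 'dst', then the vertical pass
// filters 'dst' in place down to h rows. xoffset/yoffset values of 0 and 4
// take the copy and pixel-average fast paths.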
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
                            int yoffset, uint8_t *dst, int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        __m128i z = _mm_alignr_epi8(y, x, 1);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu8(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 16]);
        const __m128i z = _mm_alignr_epi8(y, x, 1);
        const __m128i res = filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu8(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 16) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

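// Same 2-tap filter as filter_block(), but applied to the low 8 bytes of two
// independent (a, b) pairs, so two 8-wide (or four 4-wide) rows can be
// produced per call.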
static inline __m128i filter_block_2rows(const __m128i *a0, const __m128i *b0,
                                         const __m128i *a1, const __m128i *b1,
                                         const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi8(*a0, *b0);
  v0 = _mm_maddubs_epi16(v0, *filter);
  v0 = xx_roundn_epu16(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi8(*a1, *b1);
  v1 = _mm_maddubs_epi16(v1, *filter);
  v1 = xx_roundn_epu16(v1, FILTER_BITS);

  return _mm_packus_epi16(v0, v1);
}

static void bilinear_filter8xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 8;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 8;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i res = filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x0 = _mm_loadu_si128((__m128i *)src);
    const __m128i z0 = _mm_srli_si128(x0, 1);

    __m128i v0 = _mm_unpacklo_epi8(x0, z0);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu8(x, y));
      dst += 8;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[16]);
      const __m128i res = filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

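// 4-wide variant: the generic filter paths produce four rows per iteration
// (with the last of the h + 1 horizontal rows handled on its own).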
static void bilinear_filter4xh(const uint8_t *src, int src_stride, int xoffset,
                               int yoffset, uint8_t *dst, int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = xx_loadl_32((__m128i *)src);
      xx_storel_32(b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint8_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 1);
      xx_storel_32(b, _mm_avg_epu8(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint8_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi16(hfilter[0] | (hfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i x0 = _mm_loadl_epi64((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 1);
      const __m128i x1 = _mm_loadl_epi64((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 1);
      const __m128i x2 = _mm_loadl_epi64((__m128i *)&src[src_stride * 2]);
      const __m128i z2 = _mm_srli_si128(x2, 1);
      const __m128i x3 = _mm_loadl_epi64((__m128i *)&src[src_stride * 3]);
      const __m128i z3 = _mm_srli_si128(x3, 1);

      const __m128i a0 = _mm_unpacklo_epi32(x0, x1);
      const __m128i b0 = _mm_unpacklo_epi32(z0, z1);
      const __m128i a1 = _mm_unpacklo_epi32(x2, x3);
      const __m128i b1 = _mm_unpacklo_epi32(z2, z3);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 4;
      b += 16;
    }
    // Handle i = h separately
    const __m128i x = _mm_loadl_epi64((__m128i *)src);
    const __m128i z = _mm_srli_si128(x, 1);

    __m128i v0 = _mm_unpacklo_epi8(x, z);
    v0 = _mm_maddubs_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu16(v0, FILTER_BITS);

    xx_storel_32(b, _mm_packus_epi16(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = xx_loadl_32((__m128i *)dst);
      __m128i y = xx_loadl_32((__m128i *)&dst[4]);
      xx_storel_32(dst, _mm_avg_epu8(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi16(vfilter[0] | (vfilter[1] << 8));
    for (i = 0; i < h; i += 4) {
      const __m128i a = xx_loadl_32((__m128i *)dst);
      const __m128i b = xx_loadl_32((__m128i *)&dst[4]);
      const __m128i c = xx_loadl_32((__m128i *)&dst[8]);
      const __m128i d = xx_loadl_32((__m128i *)&dst[12]);
      const __m128i e = xx_loadl_32((__m128i *)&dst[16]);

      const __m128i a0 = _mm_unpacklo_epi32(a, b);
      const __m128i b0 = _mm_unpacklo_epi32(b, c);
      const __m128i a1 = _mm_unpacklo_epi32(c, d);
      const __m128i b1 = _mm_unpacklo_epi32(d, e);
      const __m128i res = filter_block_2rows(&a0, &b0, &a1, &b1, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 16;
    }
  }
}

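// Blend 16 pixels of 'a' and 'b' under the 6-bit mask 'm'
// (pred = (m * a + (64 - m) * b + 32) >> 6), then accumulate the differences
// against 'src' into the partial sum and sum-of-squares vectors.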
static inline void accumulate_block(const __m128i *src, const __m128i *a,
                                    const __m128i *b, const __m128i *m,
                                    __m128i *sum, __m128i *sum_sq) {
  const __m128i zero = _mm_setzero_si128();
  const __m128i one = _mm_set1_epi16(1);
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i m_inv = _mm_sub_epi8(mask_max, *m);

  // Calculate 16 predicted pixels.
  // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
  // is 64 * 255, so we have plenty of space to add rounding constants.
  const __m128i data_l = _mm_unpacklo_epi8(*a, *b);
  const __m128i mask_l = _mm_unpacklo_epi8(*m, m_inv);
  __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
  pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

  const __m128i data_r = _mm_unpackhi_epi8(*a, *b);
  const __m128i mask_r = _mm_unpackhi_epi8(*m, m_inv);
  __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
  pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

  const __m128i src_l = _mm_unpacklo_epi8(*src, zero);
  const __m128i src_r = _mm_unpackhi_epi8(*src, zero);
  const __m128i diff_l = _mm_sub_epi16(pred_l, src_l);
  const __m128i diff_r = _mm_sub_epi16(pred_r, src_r);

  // Update partial sums and partial sums of squares
  *sum =
      _mm_add_epi32(*sum, _mm_madd_epi16(_mm_add_epi16(diff_l, diff_r), one));
  *sum_sq =
      _mm_add_epi32(*sum_sq, _mm_add_epi32(_mm_madd_epi16(diff_l, diff_l),
                                           _mm_madd_epi16(diff_r, diff_r)));
}

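// Masked variance for widths that are a multiple of 16: accumulates 32-bit
// partial sums per lane, then reduces them to a single sum and sum of squares
// with horizontal adds.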
static void masked_variance(const uint8_t *src_ptr, int src_stride,
                            const uint8_t *a_ptr, int a_stride,
                            const uint8_t *b_ptr, int b_stride,
                            const uint8_t *m_ptr, int m_stride, int width,
                            int height, unsigned int *sse, int *sum_) {
  int x, y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

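// 8-wide variant: processes two rows per iteration, so 'a_ptr' and 'b_ptr'
// are expected to be contiguous 8-wide blocks (stride 8).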
static void masked_variance8xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 2;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

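// 4-wide variant: gathers four rows of source and mask into a single 16-byte
// vector per iteration; 'a_ptr' and 'b_ptr' again use a stride of 4.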
static void masked_variance4xh(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *a_ptr, const uint8_t *b_ptr,
                               const uint8_t *m_ptr, int m_stride, int height,
                               unsigned int *sse, int *sum_) {
  int y;
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();

  for (y = 0; y < height; y += 4) {
    // Load four rows at a time
    __m128i src = _mm_setr_epi32(*(int *)src_ptr, *(int *)&src_ptr[src_stride],
                                 *(int *)&src_ptr[src_stride * 2],
                                 *(int *)&src_ptr[src_stride * 3]);
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_setr_epi32(*(int *)m_ptr, *(int *)&m_ptr[m_stride],
                                     *(int *)&m_ptr[m_stride * 2],
                                     *(int *)&m_ptr[m_stride * 3]);
    accumulate_block(&src, &a, &b, &m, &sum, &sum_sq);

    src_ptr += src_stride * 4;
    a_ptr += 16;
    b_ptr += 16;
    m_ptr += m_stride * 4;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, sum);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}

#if CONFIG_AV1_HIGHBITDEPTH
// For width a multiple of 8
static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h);

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h);

// For width a multiple of 8
static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_);

static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_);

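// Each high-bitdepth macro below expands to 8-, 10- and 12-bit functions. The
// 10- and 12-bit versions scale the accumulated sum down by 2 and 4 bits per
// sample respectively (and sse by twice that), so the returned variance stays
// on the same scale as the 8-bit path, and they clamp the result at zero.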
#define HIGHBD_MASK_SUBPIX_VAR_SSSE3(W, H) \
  unsigned int aom_highbd_8_masked_sub_pixel_variance##W##x##H##_ssse3( \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64; \
    int sum; \
    uint16_t temp[(H + 1) * W]; \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8); \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8); \
    \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H); \
    \
    if (!invert_mask) \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum); \
    else \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum); \
    *sse = (uint32_t)sse64; \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (W * H)); \
  } \
  unsigned int aom_highbd_10_masked_sub_pixel_variance##W##x##H##_ssse3( \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64; \
    int sum; \
    int64_t var; \
    uint16_t temp[(H + 1) * W]; \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8); \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8); \
    \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H); \
    \
    if (!invert_mask) \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum); \
    else \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum); \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 4); \
    sum = ROUND_POWER_OF_TWO(sum, 2); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H)); \
    return (var >= 0) ? (uint32_t)var : 0; \
  } \
  unsigned int aom_highbd_12_masked_sub_pixel_variance##W##x##H##_ssse3( \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    uint64_t sse64; \
    int sum; \
    int64_t var; \
    uint16_t temp[(H + 1) * W]; \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8); \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8); \
    \
    highbd_bilinear_filter(src, src_stride, xoffset, yoffset, temp, W, H); \
    \
    if (!invert_mask) \
      highbd_masked_variance(ref, ref_stride, temp, W, second_pred, W, msk, \
                             msk_stride, W, H, &sse64, &sum); \
    else \
      highbd_masked_variance(ref, ref_stride, second_pred, W, temp, W, msk, \
                             msk_stride, W, H, &sse64, &sum); \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse64, 8); \
    sum = ROUND_POWER_OF_TWO(sum, 4); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (W * H)); \
    return (var >= 0) ? (uint32_t)var : 0; \
  }

#define HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(H) \
  unsigned int aom_highbd_8_masked_sub_pixel_variance4x##H##_ssse3( \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_; \
    int sum; \
    uint16_t temp[(H + 1) * 4]; \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8); \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8); \
    \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H); \
    \
    if (!invert_mask) \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk, \
                                msk_stride, H, &sse_, &sum); \
    else \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk, \
                                msk_stride, H, &sse_, &sum); \
    *sse = (uint32_t)sse_; \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (4 * H)); \
  } \
  unsigned int aom_highbd_10_masked_sub_pixel_variance4x##H##_ssse3( \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_; \
    int sum; \
    int64_t var; \
    uint16_t temp[(H + 1) * 4]; \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8); \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8); \
    \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H); \
    \
    if (!invert_mask) \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk, \
                                msk_stride, H, &sse_, &sum); \
    else \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk, \
                                msk_stride, H, &sse_, &sum); \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 4); \
    sum = ROUND_POWER_OF_TWO(sum, 2); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H)); \
    return (var >= 0) ? (uint32_t)var : 0; \
  } \
  unsigned int aom_highbd_12_masked_sub_pixel_variance4x##H##_ssse3( \
      const uint8_t *src8, int src_stride, int xoffset, int yoffset, \
      const uint8_t *ref8, int ref_stride, const uint8_t *second_pred8, \
      const uint8_t *msk, int msk_stride, int invert_mask, uint32_t *sse) { \
    int sse_; \
    int sum; \
    int64_t var; \
    uint16_t temp[(H + 1) * 4]; \
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8); \
    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); \
    const uint16_t *second_pred = CONVERT_TO_SHORTPTR(second_pred8); \
    \
    highbd_bilinear_filter4xh(src, src_stride, xoffset, yoffset, temp, H); \
    \
    if (!invert_mask) \
      highbd_masked_variance4xh(ref, ref_stride, temp, second_pred, msk, \
                                msk_stride, H, &sse_, &sum); \
    else \
      highbd_masked_variance4xh(ref, ref_stride, second_pred, temp, msk, \
                                msk_stride, H, &sse_, &sum); \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_, 8); \
    sum = ROUND_POWER_OF_TWO(sum, 4); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (4 * H)); \
    return (var >= 0) ? (uint32_t)var : 0; \
  }

HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(128, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 128)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 4)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(8)
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASK_SUBPIX_VAR4XH_SSSE3(16)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 4)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(8, 32)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(32, 8)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(16, 64)
HIGHBD_MASK_SUBPIX_VAR_SSSE3(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

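// 16-bit analogue of filter_block(): 2-tap filter over four pixel pairs per
// unpack, using 32-bit intermediates before packing back to 16 bits.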
static inline __m128i highbd_filter_block(const __m128i a, const __m128i b,
                                          const __m128i filter) {
  __m128i v0 = _mm_unpacklo_epi16(a, b);
  v0 = _mm_madd_epi16(v0, filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpackhi_epi16(a, b);
  v1 = _mm_madd_epi16(v1, filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

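// High-bitdepth counterpart of bilinear_filter() for widths that are a
// multiple of 8: horizontal pass into (h + 1) rows of 'dst', then an in-place
// vertical pass.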
static void highbd_bilinear_filter(const uint16_t *src, int src_stride,
                                   int xoffset, int yoffset, uint16_t *dst,
                                   int w, int h) {
  int i, j;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        _mm_storeu_si128((__m128i *)&b[j], x);
      }
      src += src_stride;
      b += w;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        __m128i z = _mm_alignr_epi8(y, x, 2);
        _mm_storeu_si128((__m128i *)&b[j], _mm_avg_epu16(x, z));
      }
      src += src_stride;
      b += w;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h + 1; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&src[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&src[j + 8]);
        const __m128i z = _mm_alignr_epi8(y, x, 2);
        const __m128i res = highbd_filter_block(x, z, hfilter_vec);
        _mm_storeu_si128((__m128i *)&b[j], res);
      }

      src += src_stride;
      b += w;
    }
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        _mm_storeu_si128((__m128i *)&dst[j], _mm_avg_epu16(x, y));
      }
      dst += w;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; ++i) {
      for (j = 0; j < w; j += 8) {
        const __m128i x = _mm_loadu_si128((__m128i *)&dst[j]);
        const __m128i y = _mm_loadu_si128((__m128i *)&dst[j + w]);
        const __m128i res = highbd_filter_block(x, y, vfilter_vec);
        _mm_storeu_si128((__m128i *)&dst[j], res);
      }

      dst += w;
    }
  }
}

static inline __m128i highbd_filter_block_2rows(const __m128i *a0,
                                                const __m128i *b0,
                                                const __m128i *a1,
                                                const __m128i *b1,
                                                const __m128i *filter) {
  __m128i v0 = _mm_unpacklo_epi16(*a0, *b0);
  v0 = _mm_madd_epi16(v0, *filter);
  v0 = xx_roundn_epu32(v0, FILTER_BITS);

  __m128i v1 = _mm_unpacklo_epi16(*a1, *b1);
  v1 = _mm_madd_epi16(v1, *filter);
  v1 = xx_roundn_epu32(v1, FILTER_BITS);

  return _mm_packs_epi32(v0, v1);
}

static void highbd_bilinear_filter4xh(const uint16_t *src, int src_stride,
                                      int xoffset, int yoffset, uint16_t *dst,
                                      int h) {
  int i;
  // Horizontal filter
  if (xoffset == 0) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)src);
      _mm_storel_epi64((__m128i *)b, x);
      src += src_stride;
      b += 4;
    }
  } else if (xoffset == 4) {
    uint16_t *b = dst;
    for (i = 0; i < h + 1; ++i) {
      __m128i x = _mm_loadu_si128((__m128i *)src);
      __m128i z = _mm_srli_si128(x, 2);
      _mm_storel_epi64((__m128i *)b, _mm_avg_epu16(x, z));
      src += src_stride;
      b += 4;
    }
  } else {
    uint16_t *b = dst;
    const uint8_t *hfilter = bilinear_filters_2t[xoffset];
    const __m128i hfilter_vec = _mm_set1_epi32(hfilter[0] | (hfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x0 = _mm_loadu_si128((__m128i *)src);
      const __m128i z0 = _mm_srli_si128(x0, 2);
      const __m128i x1 = _mm_loadu_si128((__m128i *)&src[src_stride]);
      const __m128i z1 = _mm_srli_si128(x1, 2);
      const __m128i res =
          highbd_filter_block_2rows(&x0, &z0, &x1, &z1, &hfilter_vec);
      _mm_storeu_si128((__m128i *)b, res);

      src += src_stride * 2;
      b += 8;
    }
    // Process i = h separately
    __m128i x = _mm_loadu_si128((__m128i *)src);
    __m128i z = _mm_srli_si128(x, 2);

    __m128i v0 = _mm_unpacklo_epi16(x, z);
    v0 = _mm_madd_epi16(v0, hfilter_vec);
    v0 = xx_roundn_epu32(v0, FILTER_BITS);

    _mm_storel_epi64((__m128i *)b, _mm_packs_epi32(v0, v0));
  }

  // Vertical filter
  if (yoffset == 0) {
    // The data is already in 'dst', so no need to filter
  } else if (yoffset == 4) {
    for (i = 0; i < h; ++i) {
      __m128i x = _mm_loadl_epi64((__m128i *)dst);
      __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      _mm_storel_epi64((__m128i *)dst, _mm_avg_epu16(x, y));
      dst += 4;
    }
  } else {
    const uint8_t *vfilter = bilinear_filters_2t[yoffset];
    const __m128i vfilter_vec = _mm_set1_epi32(vfilter[0] | (vfilter[1] << 16));
    for (i = 0; i < h; i += 2) {
      const __m128i x = _mm_loadl_epi64((__m128i *)dst);
      const __m128i y = _mm_loadl_epi64((__m128i *)&dst[4]);
      const __m128i z = _mm_loadl_epi64((__m128i *)&dst[8]);
      const __m128i res =
          highbd_filter_block_2rows(&x, &y, &y, &z, &vfilter_vec);
      _mm_storeu_si128((__m128i *)dst, res);

      dst += 8;
    }
  }
}

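// High-bitdepth masked variance for widths that are a multiple of 8. The
// blend uses explicit 32-bit rounding; see the bit-width notes below for why
// 'sum' stays in 32-bit lanes while 'sum_sq' is accumulated in 64-bit lanes.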
static void highbd_masked_variance(const uint16_t *src_ptr, int src_stride,
                                   const uint16_t *a_ptr, int a_stride,
                                   const uint16_t *b_ptr, int b_stride,
                                   const uint8_t *m_ptr, int m_stride,
                                   int width, int height, uint64_t *sse,
                                   int *sum_) {
  int x, y;
  // Note on bit widths:
  // The maximum value of 'sum' is (2^12 - 1) * 128 * 128 =~ 2^26,
  // so this can be kept as four 32-bit values.
  // But the maximum value of 'sum_sq' is (2^12 - 1)^2 * 128 * 128 =~ 2^38,
  // so this must be stored as two 64-bit values.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m =
          _mm_unpacklo_epi8(_mm_loadl_epi64((const __m128i *)&m_ptr[x]), zero);
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      // Calculate 8 predicted pixels.
      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i src_l = _mm_unpacklo_epi16(src, zero);
      const __m128i src_r = _mm_unpackhi_epi16(src, zero);
      __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
      __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

      // Update partial sums and partial sums of squares
      sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
      // A trick: Now each entry of diff_l and diff_r is stored in a 32-bit
      // field, but the range of values is only [-(2^12 - 1), 2^12 - 1].
      // So we can re-pack into 16-bit fields and use _mm_madd_epi16
      // to calculate the squares and partially sum them.
      const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
      const __m128i prod = _mm_madd_epi16(tmp, tmp);
      // Then we want to sign-extend to 64 bits and accumulate
      const __m128i sign = _mm_srai_epi32(prod, 31);
      const __m128i tmp_0 = _mm_unpacklo_epi32(prod, sign);
      const __m128i tmp_1 = _mm_unpackhi_epi32(prod, sign);
      sum_sq = _mm_add_epi64(sum_sq, _mm_add_epi64(tmp_0, tmp_1));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, zero);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  sum_sq = _mm_add_epi64(sum_sq, _mm_srli_si128(sum_sq, 8));
  _mm_storel_epi64((__m128i *)sse, sum_sq);
}

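// 4-wide high-bitdepth variant: processes two rows per iteration and keeps
// 'sum_sq' in 32-bit lanes, which the bit-width note below shows is safe for
// the block heights used here.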
static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
                                      const uint16_t *a_ptr,
                                      const uint16_t *b_ptr,
                                      const uint8_t *m_ptr, int m_stride,
                                      int height, int *sse, int *sum_) {
  int y;
  // Note: For this function, h <= 8 (or maybe 16 if we add 4:1 partitions).
  // So the maximum value of sum is (2^12 - 1) * 4 * 16 =~ 2^18
  // and the maximum value of sum_sq is (2^12 - 1)^2 * 4 * 16 =~ 2^30.
  // So we can safely pack sum_sq into 32-bit fields, which is slightly more
  // convenient.
  __m128i sum = _mm_setzero_si128(), sum_sq = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i zero = _mm_setzero_si128();

  for (y = 0; y < height; y += 2) {
    __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a = _mm_loadu_si128((const __m128i *)a_ptr);
    const __m128i b = _mm_loadu_si128((const __m128i *)b_ptr);
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        zero);
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i src_l = _mm_unpacklo_epi16(src, zero);
    const __m128i src_r = _mm_unpackhi_epi16(src, zero);
    __m128i diff_l = _mm_sub_epi32(pred_l, src_l);
    __m128i diff_r = _mm_sub_epi32(pred_r, src_r);

    // Update partial sums and partial sums of squares
    sum = _mm_add_epi32(sum, _mm_add_epi32(diff_l, diff_r));
    const __m128i tmp = _mm_packs_epi32(diff_l, diff_r);
    const __m128i prod = _mm_madd_epi16(tmp, tmp);
    sum_sq = _mm_add_epi32(sum_sq, prod);

    src_ptr += src_stride * 2;
    a_ptr += 8;
    b_ptr += 8;
    m_ptr += m_stride * 2;
  }
  // Reduce down to a single sum and sum of squares
  sum = _mm_hadd_epi32(sum, sum_sq);
  sum = _mm_hadd_epi32(sum, zero);
  *sum_ = _mm_cvtsi128_si32(sum);
  *sse = (unsigned int)_mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

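// Build a mask-blended predictor from 'pred' and 'ref': width 8 and width 16
// have dedicated helpers, and wider blocks are processed 32 pixels at a time.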
void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
                              int width, int height, const uint8_t *ref,
                              int ref_stride, const uint8_t *mask,
                              int mask_stride, int invert_mask) {
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  assert(height % 2 == 0);
  int i = 0;
  if (width == 8) {
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    do {
      comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
      comp_mask_pred_16_ssse3(src0 + stride0, src1 + stride1,
                              mask + mask_stride, comp_pred + width);
      comp_pred += (width << 1);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      i += 2;
    } while (i < height);
  } else {
    do {
      for (int x = 0; x < width; x += 32) {
        comp_mask_pred_16_ssse3(src0 + x, src1 + x, mask + x, comp_pred);
        comp_mask_pred_16_ssse3(src0 + x + 16, src1 + x + 16, mask + x + 16,
                                comp_pred + 16);
        comp_pred += 32;
      }
      src0 += (stride0);
      src1 += (stride1);
      mask += (mask_stride);
      i += 1;
    } while (i < height);
  }
}