/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdio.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"

#include "aom_dsp/x86/masked_sad_intrin_ssse3.h"

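// A masked SAD blends two predictors with a per-pixel 6-bit mask and then
// accumulates absolute differences against the source (see aom_dsp/blend.h):
//   pred[x] = ROUND_POWER_OF_TWO(m[x] * a[x] +
//                                (AOM_BLEND_A64_MAX_ALPHA - m[x]) * b[x],
//                                AOM_BLEND_A64_ROUND_BITS)
//   sad    += abs(src[x] - pred[x])
// The kernels below vectorize this computation with SSSE3.
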
// For width a multiple of 16
static inline unsigned int masked_sad_ssse3(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *a_ptr, int a_stride,
                                            const uint8_t *b_ptr, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int width, int height);

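// When 'invert_mask' is set, the mask weights 'second_pred' rather than
// 'ref', so the two prediction arguments are simply swapped in the calls
// below.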
#define MASKSADMXN_SSSE3(m, n) \
  unsigned int aom_masked_sad##m##x##n##_ssse3( \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \
      int invert_mask) { \
    if (!invert_mask) \
      return masked_sad_ssse3(src, src_stride, ref, ref_stride, second_pred, \
                              m, msk, msk_stride, m, n); \
    else \
      return masked_sad_ssse3(src, src_stride, second_pred, m, ref, \
                              ref_stride, msk, msk_stride, m, n); \
  }

#define MASKSAD8XN_SSSE3(n) \
  unsigned int aom_masked_sad8x##n##_ssse3( \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \
      int invert_mask) { \
    if (!invert_mask) \
      return aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride, \
                                     second_pred, 8, msk, msk_stride, n); \
    else \
      return aom_masked_sad8xh_ssse3(src, src_stride, second_pred, 8, ref, \
                                     ref_stride, msk, msk_stride, n); \
  }

#define MASKSAD4XN_SSSE3(n) \
  unsigned int aom_masked_sad4x##n##_ssse3( \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \
      int invert_mask) { \
    if (!invert_mask) \
      return aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride, \
                                     second_pred, 4, msk, msk_stride, n); \
    else \
      return aom_masked_sad4xh_ssse3(src, src_stride, second_pred, 4, ref, \
                                     ref_stride, msk, msk_stride, n); \
  }

MASKSADMXN_SSSE3(128, 128)
MASKSADMXN_SSSE3(128, 64)
MASKSADMXN_SSSE3(64, 128)
MASKSADMXN_SSSE3(64, 64)
MASKSADMXN_SSSE3(64, 32)
MASKSADMXN_SSSE3(32, 64)
MASKSADMXN_SSSE3(32, 32)
MASKSADMXN_SSSE3(32, 16)
MASKSADMXN_SSSE3(16, 32)
MASKSADMXN_SSSE3(16, 16)
MASKSADMXN_SSSE3(16, 8)
MASKSAD8XN_SSSE3(16)
MASKSAD8XN_SSSE3(8)
MASKSAD8XN_SSSE3(4)
MASKSAD4XN_SSSE3(8)
MASKSAD4XN_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
MASKSAD4XN_SSSE3(16)
MASKSADMXN_SSSE3(16, 4)
MASKSAD8XN_SSSE3(32)
MASKSADMXN_SSSE3(32, 8)
MASKSADMXN_SSSE3(16, 64)
MASKSADMXN_SSSE3(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

static inline unsigned int masked_sad_ssse3(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *a_ptr, int a_stride,
                                            const uint8_t *b_ptr, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int width, int height) {
  int x, y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));
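  // Mask values lie in [0, 64]; 'mask_max' is AOM_BLEND_A64_MAX_ALPHA
  // (1 << AOM_BLEND_A64_ROUND_BITS), so 'm_inv' computed inside the loop
  // below is the complementary weight 64 - m.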

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      const __m128i m_inv = _mm_sub_epi8(mask_max, m);

      // Calculate 16 predicted pixels.
      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
      // is 64 * 255, so we have plenty of space to add rounding constants.
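      // Interleaving (a, b) with (m, 64 - m) lets _mm_maddubs_epi16 compute
      // m * a + (64 - m) * b for each pixel in a single instruction.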
      const __m128i data_l = _mm_unpacklo_epi8(a, b);
      const __m128i mask_l = _mm_unpacklo_epi8(m, m_inv);
      __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
      pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi8(a, b);
      const __m128i mask_r = _mm_unpackhi_epi8(m, m_inv);
      __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
      pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

      const __m128i pred = _mm_packus_epi16(pred_l, pred_r);
      res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have two 32-bit partial SADs in lanes 0 and 2 of 'res'.
  unsigned int sad = (unsigned int)(_mm_cvtsi128_si32(res) +
                                    _mm_cvtsi128_si32(_mm_srli_si128(res, 8)));
  return sad;
}

unsigned int aom_masked_sad8xh_ssse3(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *a_ptr, int a_stride,
                                     const uint8_t *b_ptr, int b_stride,
                                     const uint8_t *m_ptr, int m_stride,
                                     int height) {
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

  for (y = 0; y < height; y += 2) {
    const __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a0 = _mm_loadl_epi64((const __m128i *)a_ptr);
    const __m128i a1 = _mm_loadl_epi64((const __m128i *)&a_ptr[a_stride]);
    const __m128i b0 = _mm_loadl_epi64((const __m128i *)b_ptr);
    const __m128i b1 = _mm_loadl_epi64((const __m128i *)&b_ptr[b_stride]);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    const __m128i m_inv = _mm_sub_epi8(mask_max, m);

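    // Rows y and y + 1 are processed together: row 0 pairs (a0, b0) with the
    // low half of the combined mask (mask_l), while row 1 pairs (a1, b1) with
    // the high half (mask_r).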
    const __m128i data_l = _mm_unpacklo_epi8(a0, b0);
    const __m128i mask_l = _mm_unpacklo_epi8(m, m_inv);
    __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
    pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpacklo_epi8(a1, b1);
    const __m128i mask_r = _mm_unpackhi_epi8(m, m_inv);
    __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
    pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packus_epi16(pred_l, pred_r);
    res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  unsigned int sad = (unsigned int)(_mm_cvtsi128_si32(res) +
                                    _mm_cvtsi128_si32(_mm_srli_si128(res, 8)));
  return sad;
}

unsigned int aom_masked_sad4xh_ssse3(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *a_ptr, int a_stride,
                                     const uint8_t *b_ptr, int b_stride,
                                     const uint8_t *m_ptr, int m_stride,
                                     int height) {
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

  for (y = 0; y < height; y += 2) {
    // Load two rows at a time; this seems to be a bit faster
    // than four rows at a time in this case.
    const __m128i src =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)src_ptr),
                           _mm_cvtsi32_si128(*(int *)&src_ptr[src_stride]));
    const __m128i a =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)a_ptr),
                           _mm_cvtsi32_si128(*(int *)&a_ptr[a_stride]));
    const __m128i b =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)b_ptr),
                           _mm_cvtsi32_si128(*(int *)&b_ptr[b_stride]));
    const __m128i m =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(int *)m_ptr),
                           _mm_cvtsi32_si128(*(int *)&m_ptr[m_stride]));
    const __m128i m_inv = _mm_sub_epi8(mask_max, m);

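    // Both 4-pixel rows sit in the low 8 bytes of each vector, so a single
    // unpack/maddubs pass predicts all 8 pixels at once.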
    const __m128i data = _mm_unpacklo_epi8(a, b);
    const __m128i mask = _mm_unpacklo_epi8(m, m_inv);
    __m128i pred_16bit = _mm_maddubs_epi16(data, mask);
    pred_16bit = xx_roundn_epu16(pred_16bit, AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packus_epi16(pred_16bit, _mm_setzero_si128());
    res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  // At this point, the SAD is stored in lane 0 of 'res'.
  return (unsigned int)_mm_cvtsi128_si32(res);
}

#if CONFIG_AV1_HIGHBITDEPTH
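// In the high-bitdepth functions, the *8 pointers are CONVERT_TO_BYTEPTR()
// style handles to 16-bit pixel buffers; they are converted back with
// CONVERT_TO_SHORTPTR() before use.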
// For width a multiple of 8
static inline unsigned int highbd_masked_sad_ssse3(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height);

#define HIGHBD_MASKSADMXN_SSSE3(m, n) \
  unsigned int aom_highbd_masked_sad##m##x##n##_ssse3( \
      const uint8_t *src8, int src_stride, const uint8_t *ref8, \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk, \
      int msk_stride, int invert_mask) { \
    if (!invert_mask) \
      return highbd_masked_sad_ssse3(src8, src_stride, ref8, ref_stride, \
                                     second_pred8, m, msk, msk_stride, m, n); \
    else \
      return highbd_masked_sad_ssse3(src8, src_stride, second_pred8, m, ref8, \
                                     ref_stride, msk, msk_stride, m, n); \
  }

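// Unlike the 8-bit path, width-8 blocks are handled by the general routine
// above (it only requires the width to be a multiple of 8), so only 4-wide
// blocks get a dedicated kernel here.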
#define HIGHBD_MASKSAD4XN_SSSE3(n) \
  unsigned int aom_highbd_masked_sad4x##n##_ssse3( \
      const uint8_t *src8, int src_stride, const uint8_t *ref8, \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk, \
      int msk_stride, int invert_mask) { \
    if (!invert_mask) \
      return aom_highbd_masked_sad4xh_ssse3(src8, src_stride, ref8, \
                                            ref_stride, second_pred8, 4, msk, \
                                            msk_stride, n); \
    else \
      return aom_highbd_masked_sad4xh_ssse3(src8, src_stride, second_pred8, 4, \
                                            ref8, ref_stride, msk, msk_stride, \
                                            n); \
  }

HIGHBD_MASKSADMXN_SSSE3(128, 128)
HIGHBD_MASKSADMXN_SSSE3(128, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 128)
HIGHBD_MASKSADMXN_SSSE3(64, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 64)
HIGHBD_MASKSADMXN_SSSE3(32, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 16)
HIGHBD_MASKSADMXN_SSSE3(16, 32)
HIGHBD_MASKSADMXN_SSSE3(16, 16)
HIGHBD_MASKSADMXN_SSSE3(16, 8)
HIGHBD_MASKSADMXN_SSSE3(8, 16)
HIGHBD_MASKSADMXN_SSSE3(8, 8)
HIGHBD_MASKSADMXN_SSSE3(8, 4)
HIGHBD_MASKSAD4XN_SSSE3(8)
HIGHBD_MASKSAD4XN_SSSE3(4)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASKSAD4XN_SSSE3(16)
HIGHBD_MASKSADMXN_SSSE3(16, 4)
HIGHBD_MASKSADMXN_SSSE3(8, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 8)
HIGHBD_MASKSADMXN_SSSE3(16, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

static inline unsigned int highbd_masked_sad_ssse3(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i one = _mm_set1_epi16(1);
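  // 'round_const' (= 32) implements round-to-nearest before the shift by
  // AOM_BLEND_A64_ROUND_BITS; 'one' is used with _mm_madd_epi16 to sum pairs
  // of 16-bit absolute differences into 32-bit lanes.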

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      // Zero-extend mask to 16 bits
      const __m128i m = _mm_unpacklo_epi8(
          _mm_loadl_epi64((const __m128i *)&m_ptr[x]), _mm_setzero_si128());
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

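      // Pixels are 16 bits here, so _mm_madd_epi16 (16x16 -> 32-bit products,
      // summed in pairs) replaces the 8-bit _mm_maddubs_epi16 used in the
      // low-bitdepth kernels.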
      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
      // so it is safe to do signed saturation here.
      const __m128i pred = _mm_packs_epi32(pred_l, pred_r);
      // There is no 16-bit SAD instruction, so we have to synthesize
      // an 8-element SAD. We do this by storing 4 32-bit partial SADs,
      // and accumulating them at the end.
      const __m128i diff = _mm_abs_epi16(_mm_sub_epi16(pred, src));
      res = _mm_add_epi32(res, _mm_madd_epi16(diff, one));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have four 32-bit partial SADs stored in 'res'.
  res = _mm_hadd_epi32(res, res);
  res = _mm_hadd_epi32(res, res);
  int sad = _mm_cvtsi128_si32(res);
  return sad;
}

unsigned int aom_highbd_masked_sad4xh_ssse3(const uint8_t *src8, int src_stride,
                                            const uint8_t *a8, int a_stride,
                                            const uint8_t *b8, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i one = _mm_set1_epi16(1);

  for (y = 0; y < height; y += 2) {
    const __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)a_ptr),
                           _mm_loadl_epi64((const __m128i *)&a_ptr[a_stride]));
    const __m128i b =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)b_ptr),
                           _mm_loadl_epi64((const __m128i *)&b_ptr[b_stride]));
    // Zero-extend mask to 16 bits
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(const int *)m_ptr),
                           _mm_cvtsi32_si128(*(const int *)&m_ptr[m_stride])),
        _mm_setzero_si128());
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packs_epi32(pred_l, pred_r);
    const __m128i diff = _mm_abs_epi16(_mm_sub_epi16(pred, src));
    res = _mm_add_epi32(res, _mm_madd_epi16(diff, one));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  res = _mm_hadd_epi32(res, res);
  res = _mm_hadd_epi32(res, res);
  int sad = _mm_cvtsi128_si32(res);
  return sad;
}
#endif  // CONFIG_AV1_HIGHBITDEPTH